List of usage examples for weka.core Instance classValue
/**
 * Returns this instance's class value as a {@code double}.
 *
 * NOTE(review): the usage examples on this page cast the result to
 * {@code int} and use it as an array index when the class attribute is
 * nominal, and use it directly as the target when the class is numeric --
 * confirm against the Weka API documentation.
 *
 * @return the class value of the instance
 */
public double classValue();
From source file:machinelearning_cw.KNN.java
@Override public double classifyInstance(Instance instance) throws Exception { // Check that classifier has been trained if (trainingData == null) { throw new Exception("Classifier has not been trained." + " No call to buildClassifier() was made"); }//from ww w . j av a2 s. c o m if (useStandardisedAttributes) { if (!isMeanAndStdDevInitialised) { // throw exception } else { /* Standardise test instance */ for (int i = 0; i < instance.numAttributes() - 1; i++) { double value = instance.value(i); double standardisedValue = (value - mean[i]) / standardDeviation[i]; instance.setValue(i, standardisedValue); } } } if (!useWeightedVoting) { return super.classifyInstance(instance); } else { if (!useAcceleratedNNSearch) { /* Calculate euclidean distances */ double[] distances = Helpers.findEuclideanDistances(trainingData, instance); /* * Create a list of dictionaries where each dictionary contains * the keys "distance", "weight" and "id". * The distance key stores the euclidean distance for an * instance and the id key stores the hashcode for that * instance object. */ ArrayList<HashMap<String, Object>> table = Helpers.buildDistanceTable(trainingData, distances); /* Find the k smallest distances */ Object[] kClosestRows = new Object[k]; Object[] kClosestInstances = new Object[k]; double[] classValues = new double[k]; for (int i = 1; i <= k; i++) { ArrayList<Integer> tieIndices = new ArrayList<Integer>(); /* Find the positions in the table of the ith closest * neighbour. 
*/ int[] closestRowIndices = this.findNthClosestNeighbourByWeights(table, i); if (closestRowIndices.length > 0) { /* Keep track of distance ties */ for (int j = 0; j < closestRowIndices.length; j++) { tieIndices.add(closestRowIndices[j]); } /* Break ties (by choosing winner at random) */ Random rand = new Random(); int matchingNeighbourPosition = tieIndices.get(rand.nextInt(tieIndices.size())); HashMap<String, Object> matchingRow = table.get(matchingNeighbourPosition); kClosestRows[i - 1] = matchingRow; } } /* * Find the closestInstances from their rows in the table and * also get their class values. */ for (int i = 0; i < kClosestRows.length; i++) { /* Build up closestInstances array */ for (int j = 0; j < trainingData.numInstances(); j++) { Instance inst = trainingData.get(j); HashMap<String, Object> row = (HashMap<String, Object>) kClosestRows[i]; if (Integer.toHexString(inst.hashCode()).equals(row.get("id"))) { kClosestInstances[i] = inst; } } } /* Vote by weights */ /* Get max class value */ double[] possibleClassValues = trainingData.attributeToDoubleArray(trainingData.classIndex()); int maxClassIndex = Utils.maxIndex(possibleClassValues); double maxClassValue = possibleClassValues[maxClassIndex]; ArrayList<Double> weightedVotes = new ArrayList<Double>(); /* Calculate the sum of votes for each class */ for (double i = 0; i <= maxClassValue; i++) { double weightCount = 0; /* Calculate sum */ for (int j = 0; j < kClosestInstances.length; j++) { Instance candidateInstance = (Instance) kClosestInstances[j]; if (candidateInstance.classValue() == i) { // Get weight HashMap<String, Object> row = (HashMap<String, Object>) kClosestRows[(int) j]; weightCount += (double) row.get("weight"); } } weightedVotes.add(weightCount); } /* Select instance with highest vote */ Double[] votesArray = new Double[weightedVotes.size()]; weightedVotes.toArray(votesArray); double greatestSoFar = votesArray[0]; int greatestIndex = 0; for (int i = 0; i < votesArray.length; i++) { if 
(votesArray[i] > greatestSoFar) { greatestSoFar = votesArray[i]; greatestIndex = i; } } /* * Class value will be the index because classes are indexed * from 0 upwards. */ return greatestIndex; } /* Use Orchards algorithm to accelerate NN search */ else { // find k nearest neighbours ArrayList<Instance> nearestNeighbours = new ArrayList<Instance>(); for (int i = 0; i < k; i++) { nearestNeighbours.add(findNthClosestWithOrchards(instance, trainingData, i)); } // Find their class values double[] classValues = new double[nearestNeighbours.size()]; for (int i = 0; i < nearestNeighbours.size(); i++) { classValues[i] = nearestNeighbours.get(i).classValue(); } return Helpers.mode(Helpers.arrayToArrayList(classValues)); } } }
From source file:machinelearning_cw.KNN.java
/** * /*from www .j a v a 2 s.c om*/ * Estimate the accuracy of using a value as k by applying * Leave-One-Out-Cross-Validation(LOOCV). * * @param k value of k to be tested. * @param distanceMatrix A matrix containing the precomputed distances of * every instance of the training data from each other. * * @return Accuracy of the calling classifier using the given value of k. * @throws Exception */ private double estimateAccuracyByLOOCV(int k, ArrayList<ArrayList<HashMap<String, Object>>> distanceMatrix) throws Exception { ArrayList<Double> accuracies = new ArrayList<Double>(); int i = 0; for (ArrayList<HashMap<String, Object>> distancesRow : distanceMatrix) { Instance testInstance = trainingData.get(i); /* * For each test instance, classify by choosing the 2nd to the * (k+1)st in the sorted list. */ Instance[] closestDistances = new Instance[k]; double[] closestClassValues = new double[k]; for (int j = 0; j < closestDistances.length; j++) { closestDistances[j] = (Instance) distancesRow.get(j).get("trainingInstance"); closestClassValues[j] = closestDistances[j].classValue(); } /* Calculate accuracy */ double predictedClass = Helpers.mode(Helpers.arrayToArrayList(closestClassValues)); double actualClass = testInstance.classValue(); double accuracy = 0; if (predictedClass == actualClass) { accuracy = 1.0; } accuracies.add(accuracy); i++; } /* find average accuracy */ double count = accuracies.size(); double sum = 0; for (Double eachAccuracy : accuracies) { sum += eachAccuracy; } double averageAccuracy = sum / count; return averageAccuracy; }
From source file:machine_learing_clasifier.MyC45.java
/**
 * Finds the split threshold of a continuous attribute that yields the
 * highest information gain.
 *
 * <p>The data is sorted by the attribute (in place). Every position where
 * the class value changes between two adjacent instances produces a
 * candidate threshold halfway between their attribute values; the candidate
 * with the largest gain, as reported by
 * {@code computeInformationGainContinous}, is returned.
 *
 * @param i   the instances to examine; sorted by {@code att} as a side effect
 * @param att the continuous attribute to split on
 * @return the best threshold, or 0 when the data is empty or no candidate
 *         has a positive information gain
 */
public double BestContinousAttribute(Instances i, Attribute att) {
    i.sort(att);
    if (i.numInstances() == 0) {
        return 0;
    }

    double bestGain = 0;
    double bestThreshold = 0;
    double previousClass = i.get(0).classValue();
    double previousValue = i.get(0).value(att);

    for (int idx = 1; idx < i.numInstances(); idx++) {
        Instance inst = i.get(idx);
        double currentClass = inst.classValue();
        double currentValue = inst.value(att);

        if (currentClass != previousClass) {
            // Midpoint between the two ADJACENT instances. The previous value
            // used to be refreshed only on class changes, which placed the
            // threshold far from the actual class boundary.
            double threshold = previousValue + ((currentValue - previousValue) / 2);
            double gain = computeInformationGainContinous(i, att, threshold);
            if (bestGain < gain) {
                bestThreshold = threshold;
                bestGain = gain;
            }
        }

        previousClass = currentClass;
        previousValue = currentValue;
    }
    return bestThreshold;
}
From source file:main.NaiveBayes.java
License:Open Source License
/** * Updates the classifier with the given instance. * /*ww w . j a v a 2 s .c o m*/ * @param instance the new training instance to include in the model * @exception Exception if the instance could not be incorporated in the * model. */ public void updateClassifier(Instance instance) throws Exception { if (!instance.classIsMissing()) { Enumeration<Attribute> enumAtts = m_Instances.enumerateAttributes(); int attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (!instance.isMissing(attribute)) { m_Distributions[attIndex][(int) instance.classValue()].addValue(instance.value(attribute), instance.weight()); } attIndex++; } m_ClassDistribution.addValue(instance.classValue(), instance.weight()); } }
From source file:milk.classifiers.MINND.java
License:Open Source License
/**
 * Pre-processes one exemplar with respect to the other exemplars of the
 * data and updates the noise statistics kept for its position.
 *
 * <p>Exemplars whose class value is 0 are returned unchanged. For any other
 * exemplar each instance is voted on by the {@code m_Select} closest other
 * exemplars; instances whose own class disagrees with the vote are moved to
 * a noise exemplar, the rest are kept.
 *
 * @param data the whole exemplars
 * @param pos the position of given exemplar in data
 * @return the processed exemplar
 * @throws Exception declared for the helpers invoked during processing
 */
public Exemplar preprocess(Exemplars data, int pos) throws Exception {
    Exemplar before = data.exemplar(pos);

    if ((int) before.classValue() == 0) {
        m_NoiseM[pos] = null;
        m_NoiseV[pos] = null;
        return before;
    }

    Exemplar after = new Exemplar(before, 0);
    Exemplar noises = new Exemplar(before, 0);

    for (int g = 0; g < before.getInstances().numInstances(); g++) {
        Instance datum = before.getInstances().instance(g);

        // Distance from this instance to every other exemplar; the exemplar
        // itself is pushed to infinity so that it is never selected.
        double[] dists = new double[data.numExemplars()];
        for (int i = 0; i < data.numExemplars(); i++) {
            dists[i] = (i == pos) ? Double.POSITIVE_INFINITY : distance(datum, m_Mean[i], m_Variance[i], i);
        }

        // Vote counts per class (a new int[] starts out all zero).
        int[] pred = new int[m_NumClasses];
        for (int o = 0; o < m_Select; o++) {
            int nearest = Utils.minIndex(dists);
            pred[(int) m_Class[nearest]]++;
            dists[nearest] = Double.POSITIVE_INFINITY; // exclude from later rounds
        }

        int votedClass = Utils.maxIndex(pred);
        if ((int) datum.classValue() == votedClass) {
            after.add(datum);
        } else {
            noises.add(datum);
        }
    }

    if (!Utils.gr(noises.getInstances().sumOfWeights(), 0)) {
        // Nothing was classified as noise.
        m_NoiseM[pos] = null;
        m_NoiseV[pos] = null;
        return after;
    }

    m_NoiseM[pos] = noises.meanOrMode();
    m_NoiseV[pos] = noises.variance();
    for (int y = 0; y < m_NoiseV[pos].length; y++) {
        // Zero variances are replaced by the m_ZERO constant.
        if (Utils.eq(m_NoiseV[pos][y], 0.0)) {
            m_NoiseV[pos][y] = m_ZERO;
        }
    }
    return after;
}
From source file:milk.core.Exemplar.java
License:Open Source License
/** * Constructor using one instance to form an exemplar * //from w w w. ja v a 2 s. co m * @param instance the given instance * @param id the ID index */ public Exemplar(Instance inst, int id) { m_IdIndex = id; m_IdValue = inst.value(id); m_ClassIndex = inst.classIndex(); m_ClassValue = inst.classValue(); m_Instances = new Instances(inst.dataset(), 1); m_Instances.add(inst); }
From source file:milk.visualize.MIPlot2D.java
License:Open Source License
/** * Renders this component/*ww w. j a v a 2 s.c om*/ * @param gx the graphics context */ public void paintComponent(Graphics gx) { //if(!isEnabled()) // return; super.paintComponent(gx); if (plotExemplars != null) { gx.setColor(m_axisColour); // Draw the axis name String xname = plotExemplars.attribute(m_xIndex).name(), yname = plotExemplars.attribute(m_yIndex).name(); gx.drawString(yname, m_XaxisStart + m_labelMetrics.stringWidth("M"), m_YaxisStart + m_labelMetrics.getAscent() / 2 + m_tickSize); gx.drawString(xname, m_XaxisEnd - m_labelMetrics.stringWidth(yname) + m_tickSize, (int) (m_YaxisEnd - m_labelMetrics.getAscent() / 2)); // Draw points Attribute classAtt = plotExemplars.classAttribute(); for (int j = 0; j < m_plots.size(); j++) { PlotData2D temp_plot = (PlotData2D) (m_plots.elementAt(j)); Instances instances = temp_plot.getPlotInstances(); StringTokenizer st = new StringTokenizer( instances.firstInstance().stringValue(plotExemplars.idIndex()), "_"); //////////////////// TLD stuff ///////////////////////////////// /* double[] mu = new double[plotExemplars.numAttributes()], sgm = new double[plotExemplars.numAttributes()]; st.nextToken(); // Squeeze first element int p=0; while(p<mu.length){ if((p==plotExemplars.idIndex()) || (p==plotExemplars.classIndex())) p++; if(p<mu.length){ mu[p] = Double.parseDouble(st.nextToken()); sgm[p] = Double.parseDouble(st.nextToken()); p++; } } Instance ins = instances.firstInstance(); gx.setColor((Color)m_colorList.elementAt((int)ins.classValue())); double mux=mu[m_xIndex], muy=mu[m_yIndex], sgmx=sgm[m_xIndex], sgmy=sgm[m_yIndex]; double xs = convertToPanelX(mux-3*sgmx), xe = convertToPanelX(mux+3*sgmx), xleng = Math.abs(xe-xs); double ys = convertToPanelY(muy+3*sgmy), ye = convertToPanelY(muy-3*sgmy), yleng = Math.abs(ye-ys); // Draw oval gx.drawOval((int)xs,(int)ys,(int)xleng,(int)yleng); // Draw a dot gx.fillOval((int)convertToPanelX(mux)-2, (int)convertToPanelY(muy)-2, 4, 4); */ //////////////////// TLD stuff 
///////////////////////////////// //////////////////// instance-based stuff ///////////////////////////////// /* double[] core = new double[plotExemplars.numAttributes()], range=new double[plotExemplars.numAttributes()]; st.nextToken(); // Squeeze first element int p=0; while(p<range.length){ if((p==plotExemplars.idIndex()) || (p==plotExemplars.classIndex())) p++; if(p<range.length) range[p++] = Double.parseDouble(st.nextToken()); } p=0; while(st.hasMoreTokens()){ if((p==plotExemplars.idIndex()) || (p==plotExemplars.classIndex())) p++; core[p++] = Double.parseDouble(st.nextToken()); } Instance ins = instances.firstInstance(); gx.setColor((Color)m_colorList.elementAt((int)ins.classValue())); double rgx=range[m_xIndex], rgy=range[m_yIndex]; double x1 = convertToPanelX(core[m_xIndex]-rgx/2), y1 = convertToPanelY(core[m_yIndex]-rgy/2), x2 = convertToPanelX(core[m_xIndex]+rgx/2), y2 = convertToPanelY(core[m_yIndex]+rgy/2), x = convertToPanelX(core[m_xIndex]), y = convertToPanelY(core[m_yIndex]); // Draw a rectangle gx.drawLine((int)x1, (int)y1, (int)x2, (int)y1); gx.drawLine((int)x1, (int)y1, (int)x1, (int)y2); gx.drawLine((int)x2, (int)y1, (int)x2, (int)y2); gx.drawLine((int)x1, (int)y2, (int)x2, (int)y2); // Draw a dot gx.fillOval((int)x-3, (int)y-3, 6, 6); // Draw string StringBuffer text =new StringBuffer(temp_plot.getPlotName()+":"+instances.numInstances()); gx.drawString(text.toString(), (int)x1, (int)y2+m_labelMetrics.getHeight()); */ //////////////////// instance-based stuff ///////////////////////////////// //////////////////// normal graph ///////////////////////////////// // Paint numbers for (int i = 0; i < instances.numInstances(); i++) { Instance ins = instances.instance(i); if (!ins.isMissing(m_xIndex) && !ins.isMissing(m_yIndex)) { if (classAtt.isNominal()) gx.setColor((Color) m_colorList.elementAt((int) ins.classValue())); else { double r = (ins.classValue() - m_minC) / (m_maxC - m_minC); r = (r * 240) + 15; gx.setColor(new Color((int) r, 150, (int) 
(255 - r))); } double x = convertToPanelX(ins.value(m_xIndex)); double y = convertToPanelY(ins.value(m_yIndex)); String id = temp_plot.getPlotName(); gx.drawString(id, (int) (x - m_labelMetrics.stringWidth(id) / 2), (int) (y + m_labelMetrics.getHeight() / 2)); } } //////////////////// normal graph ///////////////////////////////// } } //////////////////// TLD stuff ///////////////////////////////// // Draw two Guassian contour with 3 stdDev // (-1, -1) with stdDev 1, 2 // (1, 1) with stdDev 2, 1 /*gx.setColor(Color.black); double mu=-1.5, sigmx, sigmy; // class 0 if(m_xIndex == 1) sigmx = 1; else sigmx = 2; if(m_yIndex == 1) sigmy = 1; else sigmy = 2; double x1 = convertToPanelX(mu-3*sigmx), x2 = convertToPanelX(mu+3*sigmx), xlen = Math.abs(x2-x1); double y1 = convertToPanelY(mu+3*sigmy), y2 = convertToPanelY(mu-3*sigmy), ylen = Math.abs(y2-y1); // Draw heavy oval gx.drawOval((int)x1,(int)y1,(int)xlen,(int)ylen); gx.drawOval((int)x1-1,(int)y1-1,(int)xlen+2,(int)ylen+2); gx.drawOval((int)x1+1,(int)y1+1,(int)xlen-2,(int)ylen-2); // Draw a dot gx.fillOval((int)convertToPanelX(mu)-3, (int)convertToPanelY(mu)-3, 6, 6); mu=1.5; // class 1 if(m_xIndex == 1) sigmx = 1; else sigmx = 2; if(m_yIndex == 1) sigmy = 1; else sigmy = 2; x1 = convertToPanelX(mu-3*sigmx); x2 = convertToPanelX(mu+3*sigmx); xlen = Math.abs(x2-x1); y1 = convertToPanelY(mu+3*sigmy); y2 = convertToPanelY(mu-3*sigmy); ylen = Math.abs(y2-y1); // Draw heavy oval gx.drawOval((int)x1,(int)y1,(int)xlen,(int)ylen); gx.drawOval((int)x1-1,(int)y1-1,(int)xlen+2,(int)ylen+2); gx.drawOval((int)x1+1,(int)y1+1,(int)xlen-2,(int)ylen-2); // Draw a dot gx.fillOval((int)convertToPanelX(mu)-3, (int)convertToPanelY(mu)-3, 6, 6); */ //////////////////// TLD stuff ///////////////////////////////// //////////////////// instance-based stuff ///////////////////////////////// /* // Paint a log-odds line: 1*x0+2*x1=0 double xstart, xend, ystart, yend, xCoeff, yCoeff; if(m_xIndex == 1) xCoeff = 1; else xCoeff = 2; if(m_yIndex == 1) 
yCoeff = 1; else yCoeff = 2; xstart = m_minX; ystart = -xstart*xCoeff/yCoeff; if(ystart > m_maxY){ ystart = m_maxY; xstart = -ystart*yCoeff/xCoeff; } yend = m_minY; xend = -yend*yCoeff/xCoeff; if(xend > m_maxX){ xend = m_maxX; yend = -xend*xCoeff/yCoeff; } // Draw a heavy line gx.setColor(Color.black); gx.drawLine((int)convertToPanelX(xstart), (int)convertToPanelY(ystart), (int)convertToPanelX(xend), (int)convertToPanelY(yend)); gx.drawLine((int)convertToPanelX(xstart)+1, (int)convertToPanelY(ystart)+1, (int)convertToPanelX(xend)+1, (int)convertToPanelY(yend)+1); gx.drawLine((int)convertToPanelX(xstart)-1, (int)convertToPanelY(ystart)-1, (int)convertToPanelX(xend)-1, (int)convertToPanelY(yend)-1); */ //////////////////// instance-based stuff ///////////////////////////////// }
From source file:ml.ann.BackPropagation.java
License:Open Source License
private void initInputAndTarget(Instance instance) { int classAttributeIdx = neuronTopology.instances.classIndex(); if (neuronTopology.instances.classAttribute().isNumeric()) { neuronTopology.target = new double[1]; neuronTopology.target[0] = instance.value(classAttributeIdx); } else if (neuronTopology.instances.classAttribute().isNominal()) { neuronTopology.target = new double[instance.numClasses()]; for (int i = 0; i < instance.numClasses(); i++) { neuronTopology.target[i] = 0; }/* ww w. j a va2s. c om*/ int idxClassValue = (int) instance.classValue(); neuronTopology.target[idxClassValue] = 1; } for (int i = 0; i < instance.numAttributes(); i++) { if (i == 0) { neuronTopology.input[i] = 1; neuronTopology.output[0][i] = 1; } else { neuronTopology.input[i] = instance.value(i - 1); neuronTopology.output[0][i] = instance.value(i - 1); } } // System.out.println(neuronTopology.originInstances.instance(instancesIdx).toString()); }
From source file:ml.ann.MultiClassPTR.java
@Override public void buildClassifier(Instances instances) throws Exception { initAttributes(instances);/* w ww .j av a2 s . com*/ // REMEMBER: only works if class index is in the last position for (int instanceIdx = 0; instanceIdx < instances.numInstances(); instanceIdx++) { Instance instance = instances.get(instanceIdx); double[] inputInstance = inputInstances[instanceIdx]; inputInstance[0] = 1.0; // initialize bias value for (int attrIdx = 0; attrIdx < instance.numAttributes() - 1; attrIdx++) { inputInstance[attrIdx + 1] = instance.value(attrIdx); // the first index of input instance is for bias } } // Initialize target values if (instances.classAttribute().isNominal()) { for (int instanceIdx = 0; instanceIdx < instances.numInstances(); instanceIdx++) { Instance instance = instances.instance(instanceIdx); for (int classIdx = 0; classIdx < instances.numClasses(); classIdx++) { targetInstances[instanceIdx][classIdx] = 0.0; } targetInstances[instanceIdx][(int) instance.classValue()] = 1.0; } } else { for (int instanceIdx = 0; instanceIdx < instances.numInstances(); instanceIdx++) { Instance instance = instances.instance(instanceIdx); targetInstances[instanceIdx][0] = instance.classValue(); } } if (algo == 1) { setActFunction(); buildClassifier(); } else if (algo == 2) { buildClassifier(); } else if (algo == 3) { buildClassifierBatch(); } }
From source file:ml.dataprocess.CorrelationAttributeEval.java
License:Open Source License
/**
 * Initializes the correlation attribute evaluator. Works on a copy of the
 * data: deletes instances with missing class values and replaces missing
 * values with the {@code ReplaceMissingValues} filter, then stores one
 * correlation score per attribute in {@code m_correlations}.
 *
 * <p>Nominal attributes (and a nominal class) are binarized into 0/1
 * indicator vectors, one per value. The score of a nominal attribute is the
 * average of the absolute correlations of its indicator vectors, weighted by
 * how often each value occurs; with a nominal class the correlations are
 * additionally averaged over the class indicators, weighted by class counts.
 *
 * @param data set of instances serving as training data
 * @throws Exception if the evaluator has not been generated successfully
 */
@Override
public void buildEvaluator(Instances data) throws Exception {
    // Copy first so the caller's data set is not modified by the clean-up.
    data = new Instances(data);
    data.deleteWithMissingClass();
    ReplaceMissingValues rmv = new ReplaceMissingValues();
    rmv.setInputFormat(data);
    data = Filter.useFilter(data, rmv);
    int numClasses = data.classAttribute().numValues();
    int classIndex = data.classIndex();
    int numInstances = data.numInstances();
    m_correlations = new double[data.numAttributes()];
    List<Integer> numericIndexes = new ArrayList<Integer>();
    List<Integer> nominalIndexes = new ArrayList<Integer>();
    if (m_detailedOutput) {
        m_detailedOutputBuff = new StringBuffer();
    }
    // TODO for instance weights (folded into computing weighted correlations)
    // add another dimension just before the last [2] (0 for 0/1 binary vector
    // and 1 for corresponding instance weights for the 1's)

    // nomAtts[a][v][n] is the 0/1 indicator "instance n has value v for
    // nominal attribute a"; entries for non-nominal attributes stay null.
    double[][][] nomAtts = new double[data.numAttributes()][][];
    for (int i = 0; i < data.numAttributes(); i++) {
        if (data.attribute(i).isNominal() && i != classIndex) {
            nomAtts[i] = new double[data.attribute(i).numValues()][data.numInstances()];
            // Start by assuming every instance has value 0; corrected below.
            Arrays.fill(nomAtts[i][0], 1.0); // set zero index for this att to all 1's
            nominalIndexes.add(i);
        } else if (data.attribute(i).isNumeric() && i != classIndex) {
            numericIndexes.add(i);
        }
    }
    // do the nominal attributes
    if (nominalIndexes.size() > 0) {
        for (int i = 0; i < data.numInstances(); i++) {
            Instance current = data.instance(i);
            // Iterates over the values the instance reports via
            // numValues()/index(j)/valueSparse(j).
            for (int j = 0; j < current.numValues(); j++) {
                if (current.attribute(current.index(j)).isNominal() && current.index(j) != classIndex) {
                    // Will need to check for zero in case this isn't a sparse
                    // instance (unless we add 1 and subtract 1).
                    // Adding 1 at the actual value and subtracting 1 at index 0
                    // leaves index 0 at 1 when the actual value is 0, and moves
                    // the 1 to the actual value otherwise.
                    nomAtts[current.index(j)][(int) current.valueSparse(j)][i] += 1;
                    nomAtts[current.index(j)][0][i] -= 1;
                }
            }
        }
    }
    if (data.classAttribute().isNumeric()) {
        double[] classVals = data.attributeToDoubleArray(classIndex);
        // do the numeric attributes
        for (Integer i : numericIndexes) {
            double[] numAttVals = data.attributeToDoubleArray(i);
            m_correlations[i] = Utils.correlation(numAttVals, classVals, numAttVals.length);
            if (m_correlations[i] == 1.0) {
                // check for zero variance (useless numeric attribute)
                if (Utils.variance(numAttVals) == 0) {
                    m_correlations[i] = 0;
                }
            }
        }
        // do the nominal attributes
        if (nominalIndexes.size() > 0) {
            // now compute the correlations for the binarized nominal attributes
            for (Integer i : nominalIndexes) {
                double sum = 0;
                double corr = 0;
                double sumCorr = 0;
                double sumForValue = 0;
                if (m_detailedOutput) {
                    m_detailedOutputBuff.append("\n\n").append(data.attribute(i).name());
                }
                for (int j = 0; j < data.attribute(i).numValues(); j++) {
                    // Number of instances that have value j for attribute i.
                    sumForValue = Utils.sum(nomAtts[i][j]);
                    corr = Utils.correlation(nomAtts[i][j], classVals, classVals.length);
                    // useless attribute - all instances have the same value
                    if (sumForValue == numInstances || sumForValue == 0) {
                        corr = 0;
                    }
                    // Only the magnitude of the correlation is used.
                    if (corr < 0.0) {
                        corr = -corr;
                    }
                    sumCorr += sumForValue * corr;
                    sum += sumForValue;
                    if (m_detailedOutput) {
                        m_detailedOutputBuff.append("\n\t").append(data.attribute(i).value(j)).append(": ");
                        m_detailedOutputBuff.append(Utils.doubleToString(corr, 6));
                    }
                }
                // Average over the values, weighted by value frequency.
                m_correlations[i] = (sum > 0) ? sumCorr / sum : 0;
            }
        }
    } else {
        // class is nominal
        // TODO extra dimension for storing instance weights too
        // binarizedClasses[c][n] is 1 when instance n belongs to class c.
        double[][] binarizedClasses = new double[data.classAttribute().numValues()][data.numInstances()];
        // this is equal to the number of instances for all inst weights = 1
        double[] classValCounts = new double[data.classAttribute().numValues()];
        for (int i = 0; i < data.numInstances(); i++) {
            Instance current = data.instance(i);
            binarizedClasses[(int) current.classValue()][i] = 1;
        }
        for (int i = 0; i < data.classAttribute().numValues(); i++) {
            classValCounts[i] = Utils.sum(binarizedClasses[i]);
        }
        double sumClass = Utils.sum(classValCounts);
        // do numeric attributes first
        if (numericIndexes.size() > 0) {
            for (Integer i : numericIndexes) {
                double[] numAttVals = data.attributeToDoubleArray(i);
                double corr = 0;
                double sumCorr = 0;
                for (int j = 0; j < data.classAttribute().numValues(); j++) {
                    corr = Utils.correlation(numAttVals, binarizedClasses[j], numAttVals.length);
                    if (corr < 0.0) {
                        corr = -corr;
                    }
                    if (corr == 1.0) {
                        // check for zero variance (useless numeric attribute)
                        if (Utils.variance(numAttVals) == 0) {
                            corr = 0;
                        }
                    }
                    sumCorr += classValCounts[j] * corr;
                }
                // Average over the classes, weighted by class counts.
                m_correlations[i] = sumCorr / sumClass;
            }
        }
        if (nominalIndexes.size() > 0) {
            for (Integer i : nominalIndexes) {
                if (m_detailedOutput) {
                    m_detailedOutputBuff.append("\n\n").append(data.attribute(i).name());
                }
                double sumForAtt = 0;
                double corrForAtt = 0;
                for (int j = 0; j < data.attribute(i).numValues(); j++) {
                    double sumForValue = Utils.sum(nomAtts[i][j]);
                    double corr = 0;
                    double sumCorr = 0;
                    double avgCorrForValue = 0;
                    sumForAtt += sumForValue;
                    for (int k = 0; k < numClasses; k++) {
                        // corr between value j and class k
                        corr = Utils.correlation(nomAtts[i][j], binarizedClasses[k], binarizedClasses[k].length);
                        // useless attribute - all instances have the same value
                        if (sumForValue == numInstances || sumForValue == 0) {
                            corr = 0;
                        }
                        if (corr < 0.0) {
                            corr = -corr;
                        }
                        sumCorr += classValCounts[k] * corr;
                    }
                    // Class-count-weighted average correlation for value j.
                    avgCorrForValue = sumCorr / sumClass;
                    corrForAtt += sumForValue * avgCorrForValue;
                    if (m_detailedOutput) {
                        m_detailedOutputBuff.append("\n\t").append(data.attribute(i).value(j)).append(": ");
                        m_detailedOutputBuff.append(Utils.doubleToString(avgCorrForValue, 6));
                    }
                }
                // the weighted average corr for att i as
                // a whole (weighted by value frequencies)
                m_correlations[i] = (sumForAtt > 0) ? corrForAtt / sumForAtt : 0;
            }
        }
    }
    // Terminate the detailed report with a newline when anything was written.
    if (m_detailedOutputBuff != null && m_detailedOutputBuff.length() > 0) {
        m_detailedOutputBuff.append("\n");
    }
}