List of usage examples for weka.filters.unsupervised.attribute.Remove#setAttributeIndicesArray
public void setAttributeIndicesArray(int[] attributes)
From source file: WrapperSubset.java
License: Open Source License
/**
 * Evaluates a subset of attributes by cross-validating the base classifier
 * on the data reduced to that subset (plus the class attribute).
 *
 * The evaluation is repeated up to 5 times (fresh CV each time) and the
 * repetitions are averaged; repeat() can stop early once the standard
 * deviation of the collected measures is small enough.
 *
 * @param subset a bitset representing the attribute subset to be evaluated
 * @return the evaluation metric (transformed so that higher is better for
 *         error-like measures; see the trailing switch)
 * @throws Exception if the subset could not be evaluated
 */
@Override
public double evaluateSubset(BitSet subset) throws Exception {
    double evalMetric = 0;
    double[] repError = new double[5]; // one slot per CV repetition
    int numAttributes = 0;
    int i, j;
    Random Rnd = new Random(m_seed);
    Remove delTransform = new Remove();
    // invert selection: the filter KEEPS the listed indices instead of removing them
    delTransform.setInvertSelection(true);
    // copy the instances so the member dataset is left untouched
    Instances trainCopy = new Instances(m_trainInstances);
    // count attributes set in the BitSet
    for (i = 0; i < m_numAttribs; i++) {
        if (subset.get(i)) {
            numAttributes++;
        }
    }
    // set up an array of attribute indexes for the filter (+1 for the class)
    int[] featArray = new int[numAttributes + 1];
    for (i = 0, j = 0; i < m_numAttribs; i++) {
        if (subset.get(i)) {
            featArray[j++] = i;
        }
    }
    featArray[j] = m_classIndex; // class always kept as last entry
    delTransform.setAttributeIndicesArray(featArray);
    delTransform.setInputFormat(trainCopy);
    trainCopy = Filter.useFilter(trainCopy, delTransform);
    // max of 5 repetitions of cross validation
    for (i = 0; i < 5; i++) {
        m_Evaluation = new Evaluation(trainCopy);
        m_Evaluation.crossValidateModel(m_BaseClassifier, trainCopy, m_folds, Rnd);
        switch (m_evaluationMeasure) {
        // NOTE(review): EVAL_DEFAULT and EVAL_ACCURACY are intentionally
        // identical here (both record the error rate); the trailing switch
        // converts them to accuracy via 1 - metric.
        case EVAL_DEFAULT:
            repError[i] = m_Evaluation.errorRate();
            break;
        case EVAL_ACCURACY:
            repError[i] = m_Evaluation.errorRate();
            break;
        case EVAL_RMSE:
            repError[i] = m_Evaluation.rootMeanSquaredError();
            break;
        case EVAL_MAE:
            repError[i] = m_Evaluation.meanAbsoluteError();
            break;
        case EVAL_FMEASURE:
            // m_IRClassVal < 0 means "no specific class of interest": use weighted average
            if (m_IRClassVal < 0) {
                repError[i] = m_Evaluation.weightedFMeasure();
            } else {
                repError[i] = m_Evaluation.fMeasure(m_IRClassVal);
            }
            break;
        case EVAL_AUC:
            if (m_IRClassVal < 0) {
                repError[i] = m_Evaluation.weightedAreaUnderROC();
            } else {
                repError[i] = m_Evaluation.areaUnderROC(m_IRClassVal);
            }
            break;
        case EVAL_AUPRC:
            if (m_IRClassVal < 0) {
                repError[i] = m_Evaluation.weightedAreaUnderPRC();
            } else {
                repError[i] = m_Evaluation.areaUnderPRC(m_IRClassVal);
            }
            break;
        case EVAL_NEW:
            // combined measure: accuracy plus weighted F-measure scaled by m_IRfactor
            repError[i] = (1.0 - m_Evaluation.errorRate()) + m_IRfactor * m_Evaluation.weightedFMeasure();
            break;
        }
        // check on the standard deviation: stop repeating once stable.
        // i is incremented BEFORE the break so the averaging below divides
        // by the number of repetitions actually performed (i).
        if (!repeat(repError, i + 1)) {
            i++;
            break;
        }
    }
    // average the performed repetitions
    for (j = 0; j < i; j++) {
        evalMetric += repError[j];
    }
    evalMetric /= i;
    m_Evaluation = null; // release the (potentially large) evaluation object
    // convert "lower is better" measures into "higher is better" scores
    switch (m_evaluationMeasure) {
    case EVAL_DEFAULT:
    case EVAL_ACCURACY:
    case EVAL_RMSE:
    case EVAL_MAE:
        if (m_trainInstances.classAttribute().isNominal()
                && (m_evaluationMeasure == EVAL_DEFAULT || m_evaluationMeasure == EVAL_ACCURACY)) {
            evalMetric = 1 - evalMetric; // error rate -> accuracy
        } else {
            evalMetric = -evalMetric; // maximize
        }
        break;
    }
    return evalMetric;
}
From source file: adams.data.instancesanalysis.FastICA.java
License: Open Source License
/**
 * Performs the actual analysis: optionally restricts the data to the
 * configured attribute range, then runs FastICA and stores the resulting
 * components and sources as spreadsheets.
 *
 * @param data the data to analyze
 * @return null if successful, otherwise error message
 * @throws Exception if analysis fails
 */
@Override
protected String doAnalyze(Instances data) throws Exception {
    m_Components = null;
    m_Sources = null;

    // reduce the dataset to the selected attribute range, unless "all" is selected
    if (!m_AttributeRange.isAllRange()) {
        if (isLoggingEnabled())
            getLogger().info("Filtering attribute range: " + m_AttributeRange.getRange());
        Remove keepRange = new Remove();
        keepRange.setAttributeIndicesArray(m_AttributeRange.getIntIndices());
        keepRange.setInvertSelection(true); // keep (rather than drop) the listed indices
        keepRange.setInputFormat(data);
        data = Filter.useFilter(data, keepRange);
    }

    if (isLoggingEnabled())
        getLogger().info("Performing ICA...");

    Matrix transformed = m_ICA.transform(MatrixHelper.wekaToMatrixAlgo(MatrixHelper.getAll(data)));
    if (transformed == null)
        return "Failed to transform data!";

    // expose components/sources as spreadsheets with prefixed column names
    m_Components = MatrixHelper.matrixToSpreadSheet(MatrixHelper.matrixAlgoToWeka(m_ICA.getComponents()),
            "Component-");
    m_Sources = MatrixHelper.matrixToSpreadSheet(MatrixHelper.matrixAlgoToWeka(m_ICA.getSources()), "Source-");
    return null;
}
From source file: adams.data.instancesanalysis.PCA.java
License: Open Source License
/**
 * Performs the actual analysis: optionally restricts the attribute range,
 * separates attributes that PCA cannot handle (plus the class and,
 * optionally, nominal ones), runs PCA on the supported subset and stores
 * the scores and loadings.
 *
 * @param data the data to analyze
 * @return null if successful, otherwise error message
 * @throws Exception if analysis fails
 */
@Override
protected String doAnalyze(Instances data) throws Exception {
    String result;
    Remove remove;
    PublicPrincipalComponents pca;
    int i;
    Capabilities caps;
    PartitionedMultiFilter2 part;
    Range rangeUnsupported;
    Range rangeSupported;
    TIntList listNominal;
    Range rangeNominal; // NOTE(review): declared but never used
    ArrayList<ArrayList<Double>> coeff;
    Instances filtered;
    SpreadSheet transformed;
    WekaInstancesToSpreadSheet conv;
    String colName;
    result = null;
    m_Loadings = null;
    m_Scores = null;
    // reduce to the configured attribute range, if not "all"
    if (!m_AttributeRange.isAllRange()) {
        if (isLoggingEnabled())
            getLogger().info("Filtering attribute range: " + m_AttributeRange.getRange());
        remove = new Remove();
        remove.setAttributeIndicesArray(m_AttributeRange.getIntIndices());
        remove.setInvertSelection(true); // keep the listed indices
        remove.setInputFormat(data);
        data = Filter.useFilter(data, remove);
    }
    if (isLoggingEnabled())
        getLogger().info("Performing PCA...");
    // collect nominal attributes to exclude (if configured), class excluded always
    listNominal = new TIntArrayList();
    if (m_SkipNominal) {
        for (i = 0; i < data.numAttributes(); i++) {
            if (i == data.classIndex())
                continue;
            if (data.attribute(i).isNominal())
                listNominal.add(i);
        }
    }
    // check for unsupported attributes
    caps = new PublicPrincipalComponents().getCapabilities();
    m_Supported = new TIntArrayList();
    m_Unsupported = new TIntArrayList();
    for (i = 0; i < data.numAttributes(); i++) {
        if (!caps.test(data.attribute(i)) || (i == data.classIndex()) || (listNominal.contains(i)))
            m_Unsupported.add(i);
        else
            m_Supported.add(i);
    }
    data.setClassIndex(-1); // PCA is applied unsupervised
    m_NumAttributes = m_Supported.size();
    // the principal components will delete the attributes without any distinct values.
    // this checks which instances will be kept.
    m_Kept = new ArrayList<>();
    for (i = 0; i < m_Supported.size(); i++) {
        if (data.numDistinctValues(m_Supported.get(i)) > 1)
            m_Kept.add(m_Supported.get(i));
    }
    // build a model using the PublicPrincipalComponents
    pca = new PublicPrincipalComponents();
    pca.setMaximumAttributes(m_MaxAttributes);
    pca.setVarianceCovered(m_Variance);
    pca.setMaximumAttributeNames(m_MaxAttributeNames);
    part = null;
    // when unsupported attributes exist, run PCA only on the supported range
    // and pass the rest through unchanged via a partitioned filter
    if (m_Unsupported.size() > 0) {
        rangeUnsupported = new Range();
        rangeUnsupported.setMax(data.numAttributes());
        rangeUnsupported.setIndices(m_Unsupported.toArray());
        rangeSupported = new Range();
        rangeSupported.setMax(data.numAttributes());
        rangeSupported.setIndices(m_Supported.toArray());
        part = new PartitionedMultiFilter2();
        part.setFilters(new Filter[] { pca, new AllFilter(), });
        part.setRanges(new weka.core.Range[] { new weka.core.Range(rangeSupported.getRange()),
                new weka.core.Range(rangeUnsupported.getRange()), });
    }
    try {
        if (part != null)
            part.setInputFormat(data);
        else
            pca.setInputFormat(data);
    } catch (Exception e) {
        result = Utils.handleException(this, "Failed to set data format", e);
    }
    transformed = null;
    if (result == null) {
        try {
            if (part != null)
                filtered = weka.filters.Filter.useFilter(data, part);
            else
                filtered = weka.filters.Filter.useFilter(data, pca);
        } catch (Exception e) {
            result = Utils.handleException(this, "Failed to apply filter", e);
            filtered = null;
        }
        if (filtered != null) {
            conv = new WekaInstancesToSpreadSheet();
            conv.setInput(filtered);
            result = conv.convert();
            if (result == null) {
                transformed = (SpreadSheet) conv.getOutput();
                // shorten column names again: strip the "filtered-N-" prefix
                // that the partitioned filter prepends
                if (part != null) {
                    for (i = 0; i < transformed.getColumnCount(); i++) {
                        colName = transformed.getColumnName(i);
                        colName = colName.replaceFirst("filtered-[0-9]*-", "");
                        transformed.getHeaderRow().getCell(i).setContentAsString(colName);
                    }
                }
            }
        }
    }
    if (result == null) {
        // get the coefficients from the filter and derive the loadings
        m_Scores = transformed;
        coeff = pca.getCoefficients();
        m_Loadings = extractLoadings(data, coeff);
        m_Loadings.setName("Loadings for " + data.relationName());
    }
    return result;
}
From source file: adams.data.instancesanalysis.PLS.java
License: Open Source License
/**
 * Performs the actual analysis: removes rows with a missing class,
 * optionally restricts the attribute range, applies the PLS filter and
 * stores the transformed data (scores) and the loadings matrix as
 * spreadsheets.
 *
 * @param data the data to analyze
 * @return null if successful, otherwise error message
 * @throws Exception if analysis fails
 */
@Override
protected String doAnalyze(Instances data) throws Exception {
    String result;
    Remove remove;
    weka.filters.supervised.attribute.PLS pls;
    WekaInstancesToSpreadSheet conv;
    SpreadSheet transformed;
    Matrix matrix;
    SpreadSheet loadings;
    Row row;
    int i;
    int n;
    m_Loadings = null;
    m_Scores = null;
    // work on a copy; PLS is supervised, so rows without class are useless
    data = new Instances(data);
    data.deleteWithMissingClass();
    // reduce to the configured attribute range, if not "all"
    if (!m_AttributeRange.isAllRange()) {
        if (isLoggingEnabled())
            getLogger().info("Filtering attribute range: " + m_AttributeRange.getRange());
        remove = new Remove();
        remove.setAttributeIndicesArray(m_AttributeRange.getIntIndices());
        remove.setInvertSelection(true); // keep the listed indices
        remove.setInputFormat(data);
        data = Filter.useFilter(data, remove);
    }
    if (isLoggingEnabled())
        getLogger().info("Performing PLS...");
    pls = new weka.filters.supervised.attribute.PLS();
    pls.setAlgorithm(m_Algorithm);
    pls.setInputFormat(data);
    data = Filter.useFilter(data, pls);
    // transformed data -> spreadsheet (becomes the scores)
    conv = new WekaInstancesToSpreadSheet();
    conv.setInput(data);
    result = conv.convert();
    if (result == null) {
        transformed = (SpreadSheet) conv.getOutput();
        // copy the loadings matrix cell-by-cell into a spreadsheet,
        // one column "Loading-i" per PLS component
        matrix = pls.getLoadings();
        loadings = new DefaultSpreadSheet();
        for (i = 0; i < matrix.getColumnDimension(); i++)
            loadings.getHeaderRow().addCell("L-" + (i + 1)).setContentAsString("Loading-" + (i + 1));
        for (n = 0; n < matrix.getRowDimension(); n++) {
            row = loadings.addRow();
            for (i = 0; i < matrix.getColumnDimension(); i++)
                row.addCell("L-" + (i + 1)).setContent(matrix.get(n, i));
        }
        m_Loadings = loadings;
        m_Scores = transformed;
    }
    return result;
}
From source file: adams.flow.transformer.AbstractWekaPredictionsTransformer.java
License: Open Source License
/** * Filters the data accordingly to the selected attribute range. * * @param data the data to filter/*from www . ja va 2 s .co m*/ * @return the filtered data, null if filtering failed */ protected Instances filterTestData(Instances data) { int[] indices; Remove remove; try { m_TestAttributes.setMax(data.numAttributes()); indices = m_TestAttributes.getIntIndices(); remove = new Remove(); remove.setAttributeIndicesArray(indices); remove.setInvertSelection(true); remove.setInputFormat(data); return Filter.useFilter(data, remove); } catch (Exception e) { getLogger().log(Level.SEVERE, "Failed to filter test data using range: " + m_TestAttributes, e); return null; } }
From source file: adams.flow.transformer.WekaInstancesMerge.java
License: Open Source License
/**
 * Executes the flow item: merges the incoming datasets (provided as file
 * names, files, single rows or datasets) either by simple side-by-side
 * merging or, when a unique ID attribute is configured, by joining on
 * row IDs.
 *
 * @return null if everything is fine, otherwise error message
 */
@Override
protected String doExecute() {
    String result;
    String[] filesStr;
    File[] files;
    int i;
    Instances output;
    Instances[] orig;
    Instances[] inst;
    Instance[] rows;
    HashSet ids;
    int max;
    TIntList uniqueList;
    Remove remove;
    result = null;
    // get filenames / datasets, depending on the payload type
    files = null;
    orig = null;
    if (m_InputToken.getPayload() instanceof String[]) {
        filesStr = (String[]) m_InputToken.getPayload();
        files = new File[filesStr.length];
        for (i = 0; i < filesStr.length; i++)
            files[i] = new PlaceholderFile(filesStr[i]);
    } else if (m_InputToken.getPayload() instanceof File[]) {
        files = (File[]) m_InputToken.getPayload();
    } else if (m_InputToken.getPayload() instanceof Instance[]) {
        // wrap each single row in its own one-row dataset
        rows = (Instance[]) m_InputToken.getPayload();
        orig = new Instances[rows.length];
        for (i = 0; i < rows.length; i++) {
            orig[i] = new Instances(rows[i].dataset(), 1);
            orig[i].add((Instance) rows[i].copy());
        }
    } else if (m_InputToken.getPayload() instanceof Instances[]) {
        orig = (Instances[]) m_InputToken.getPayload();
    } else {
        throw new IllegalStateException("Unhandled input type: " + m_InputToken.getPayload().getClass());
    }
    try {
        output = null;
        // simple merge (no unique ID attribute configured)
        if (m_UniqueID.length() == 0) {
            if (files != null) {
                inst = new Instances[1];
                for (i = 0; i < files.length; i++) {
                    if (isStopped())
                        break;
                    inst[0] = DataSource.read(files[i].getAbsolutePath());
                    inst[0] = prepareData(inst[0], i);
                    if (i == 0) {
                        output = inst[0];
                    } else {
                        if (isLoggingEnabled())
                            getLogger().info("Merging with file #" + (i + 1) + ": " + files[i]);
                        output = Instances.mergeInstances(output, inst[0]);
                    }
                }
            } else if (orig != null) {
                inst = new Instances[1];
                for (i = 0; i < orig.length; i++) {
                    if (isStopped())
                        break;
                    inst[0] = prepareData(orig[i], i);
                    if (i == 0) {
                        output = inst[0];
                    } else {
                        if (isLoggingEnabled())
                            getLogger()
                                    .info("Merging with dataset #" + (i + 1) + ": " + orig[i].relationName());
                        output = Instances.mergeInstances(output, inst[0]);
                    }
                }
            }
        }
        // merge based on row IDs
        else {
            m_AttType = -1;
            max = 0; // max #rows over all inputs, used to size the ID set
            m_UniqueIDAtts = new ArrayList<>();
            if (files != null) {
                orig = new Instances[files.length];
                for (i = 0; i < files.length; i++) {
                    if (isStopped())
                        break;
                    if (isLoggingEnabled())
                        getLogger().info("Loading file #" + (i + 1) + ": " + files[i]);
                    orig[i] = DataSource.read(files[i].getAbsolutePath());
                    max = Math.max(max, orig[i].numInstances());
                }
            } else if (orig != null) {
                for (i = 0; i < orig.length; i++)
                    max = Math.max(max, orig[i].numInstances());
            }
            inst = new Instances[orig.length];
            ids = new HashSet(max);
            for (i = 0; i < orig.length; i++) {
                if (isStopped())
                    break;
                if (isLoggingEnabled())
                    getLogger().info("Updating IDs #" + (i + 1));
                updateIDs(i, orig[i], ids);
                if (isLoggingEnabled())
                    getLogger().info("Preparing dataset #" + (i + 1));
                inst[i] = prepareData(orig[i], i);
            }
            output = merge(orig, inst, ids);
            // remove unnecessary unique ID attributes (keep only one copy)
            if (m_KeepOnlySingleUniqueID) {
                uniqueList = new TIntArrayList();
                for (String att : m_UniqueIDAtts)
                    uniqueList.add(output.attribute(att).index());
                if (uniqueList.size() > 0) {
                    if (isLoggingEnabled())
                        getLogger().info("Removing duplicate unique ID attributes: " + m_UniqueIDAtts);
                    remove = new Remove();
                    // default selection: listed indices get REMOVED
                    remove.setAttributeIndicesArray(uniqueList.toArray());
                    remove.setInputFormat(output);
                    output = Filter.useFilter(output, remove);
                }
            }
        }
        if (!isStopped()) {
            m_OutputToken = new Token(output);
            updateProvenance(m_OutputToken);
        }
    } catch (Exception e) {
        result = handleException("Failed to merge: ", e);
    }
    return result;
}
From source file: adams.ml.data.InstancesView.java
License: Open Source License
/**
 * Returns a spreadsheet containing only output columns, i.e., the class
 * columns.
 *
 * @return the output features, null if data has no class columns
 */
@Override
public SpreadSheet getOutputs() {
    int classIdx = m_Data.classIndex();
    if (classIdx == -1)
        return null;

    // work on a copy with the class unset, then keep only the class column
    Instances copy = new Instances(m_Data);
    copy.setClassIndex(-1);

    Remove keepClassOnly = new Remove();
    keepClassOnly.setAttributeIndicesArray(new int[] { classIdx });
    keepClassOnly.setInvertSelection(true);
    try {
        keepClassOnly.setInputFormat(copy);
        return new InstancesView(Filter.useFilter(copy, keepClassOnly));
    } catch (Exception e) {
        throw new IllegalStateException("Failed to apply Remove filter!", e);
    }
}
From source file: app.RunApp.java
License: Open Source License
/**
 * Loads a dataset into the principal tab: parses the ARFF header for
 * multi-view ("-V:") information, generates an XML label file for Meka
 * datasets, enables the UI tabs and fills the view-metric tables.
 *
 * NOTE(review): returnVal is compared against JFileChooser.OPEN_DIALOG
 * rather than APPROVE_OPTION (both happen to be 0) — confirm intent.
 *
 * @param returnVal result of the file-chooser dialog
 * @param fileChooser chooser holding the selected ARFF file
 * @param deleteXML whether the generated XML file must be removed afterwards
 * @return 1 if successful, -1 otherwise
 */
private int loadDataset(int returnVal, JFileChooser fileChooser, boolean deleteXML) {
    if (returnVal == JFileChooser.OPEN_DIALOG) {
        File f1 = fileChooser.getSelectedFile();
        datasetName = f1.getName();
        // strip the ".arff" extension (5 chars) to get the base name
        datasetCurrentName = datasetName.substring(0, datasetName.length() - 5);
        String arffFilename = f1.getAbsolutePath();
        xmlPath = arffFilename.substring(0, arffFilename.length() - 5) + ".xml";
        xmlFilename = DataIOUtils.getFileName(xmlPath);
        File fileTmp = new File(xmlPath);
        FileReader fr;
        try {
            views.clear();
            ((DefaultTableModel) jTable2.getModel()).getDataVector().removeAllElements();
            ((DefaultTableModel) jTable3.getModel()).getDataVector().removeAllElements();
            // NOTE(review): fr/bf are never closed — resource leak; confirm and fix upstream
            fr = new FileReader(arffFilename);
            BufferedReader bf = new BufferedReader(fr);
            String sString = bf.readLine();
            // "-V:" in the relation name marks a multi-view dataset, e.g. "-V:1-5!6-10"
            if (sString.contains("-V:")) {
                mv = true;
                TabPrincipal.setEnabledAt(7, true);
                String s2 = sString.split("'")[1];
                s2 = s2.split("-V:")[1];
                String[] intervals = s2.split("!"); // one interval spec per view
                Vector<Vector<Integer>> newIntervals = new Vector<>();
                // NOTE(review): intervalsSize is never filled, only its length is used
                int[] intervalsSize = new int[intervals.length];
                int max = Integer.MIN_VALUE;
                int min = Integer.MAX_VALUE;
                double mean = 0;
                // expand each interval spec ("a-b" ranges and comma lists) into indices
                for (int i = 0; i < intervals.length; i++) {
                    newIntervals.add(new Vector<Integer>());
                    String[] aux2;
                    viewsIntervals.put("View " + (i + 1), intervals[i]);
                    if (intervals[i].contains(",")) {
                        aux2 = intervals[i].split(",");
                        for (int j = 0; j < aux2.length; j++) {
                            if (aux2[j].contains("-")) {
                                int a = Integer.parseInt(aux2[j].split("-")[0]);
                                int b = Integer.parseInt(aux2[j].split("-")[1]);
                                for (int k = a; k <= b; k++) {
                                    newIntervals.get(i).add(k);
                                }
                            } else {
                                newIntervals.get(i).add(Integer.parseInt(aux2[j]));
                            }
                        }
                    } else {
                        if (intervals[i].contains("-")) {
                            int a = Integer.parseInt(intervals[i].split("-")[0]);
                            int b = Integer.parseInt(intervals[i].split("-")[1]);
                            for (int k = a; k <= b; k++) {
                                newIntervals.get(i).add(k);
                            }
                        } else {
                            newIntervals.get(i).add(Integer.parseInt(intervals[i]));
                        }
                    }
                }
                // register views and collect min/max/mean view sizes for the UI labels
                for (int i = 0; i < newIntervals.size(); i++) {
                    Integer[] indices = new Integer[newIntervals.get(i).size()];
                    for (int j = 0; j < newIntervals.get(i).size(); j++) {
                        indices[j] = newIntervals.get(i).get(j);
                    }
                    System.out.println(Arrays.toString(indices));
                    views.put("View " + (i + 1), indices);
                    if (newIntervals.get(i).size() > max) {
                        max = newIntervals.get(i).size();
                    }
                    if (newIntervals.get(i).size() < min) {
                        min = newIntervals.get(i).size();
                    }
                    mean += newIntervals.get(i).size();
                }
                mean /= intervalsSize.length;
                labelNumViewsValue.setText(Integer.toString(intervalsSize.length));
                labelMaxNumAttrViewValue.setText(Integer.toString(max));
                labelMinNumAttrViewValue.setText(Integer.toString(min));
                labelMeanNumAttrViewValue.setText(Double.toString(mean));
            } else {
                TabPrincipal.setEnabledAt(7, false);
                mv = false;
            }
            int labelFound = 0;
            String labelName;
            String[] labelNamesFound;
            // Meka datasets carry the labels in the ARFF header; extract them
            // and generate the Mulan-style XML label file
            if (DataIOUtils.isMeka(sString)) {
                deleteXML = true;
                isMeka = true;
                int labelCount = DataIOUtils.getLabelsFromARFF(sString);
                if (labelCount > 0) {
                    // labels are the FIRST labelCount attributes
                    labelNamesFound = new String[labelCount];
                    while (labelFound < labelCount) {
                        sString = bf.readLine();
                        labelName = DataIOUtils.getLabelNameFromLine(sString);
                        if (labelName != null) {
                            labelNamesFound[labelFound] = labelName;
                            labelFound++;
                        }
                    }
                } else {
                    // negative count: labels are the LAST |labelCount| attributes;
                    // keep a sliding window of the last lines before "@data"
                    labelCount = Math.abs(labelCount);
                    labelNamesFound = new String[labelCount];
                    String[] sStrings = new String[labelCount];
                    while (!(sString = bf.readLine()).contains("@data")) {
                        if (!sString.trim().equals("")) {
                            for (int s = 0; s < labelCount - 1; s++) {
                                sStrings[s] = sStrings[s + 1];
                            }
                            sStrings[labelCount - 1] = sString;
                        }
                    }
                    for (int i = 0; i < labelCount; i++) {
                        labelName = DataIOUtils.getLabelNameFromLine(sStrings[i]);
                        if (labelName != null) {
                            labelNamesFound[labelFound] = labelName;
                            labelFound++;
                        }
                    }
                }
                BufferedWriter bwXml = new BufferedWriter(new FileWriter(xmlPath));
                PrintWriter wrXml = new PrintWriter(bwXml);
                DataIOUtils.writeXMLFile(wrXml, labelNamesFound);
                // NOTE(review): underlying writer is closed before the PrintWriter —
                // works because PrintWriter writes through, but the order looks inverted
                bwXml.close();
                wrXml.close();
                xmlFilename = DataIOUtils.getFilePath(xmlPath);
                fileTmp = new File(xmlPath);
            } else {
                isMeka = false;
            }
        } catch (FileNotFoundException ex) {
            Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            return -1;
        } catch (IOException ex) {
            Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            return -1;
        }
        // fall back to a derived XML path when none exists next to the ARFF
        if (!fileTmp.exists()) {
            xmlPath = DataIOUtils.getXMLString(arffFilename);
            xmlFilename = DataIOUtils.getFilePath(xmlPath);
        }
        //Enable tabs
        TabPrincipal.setEnabledAt(1, true);
        TabPrincipal.setEnabledAt(2, true);
        TabPrincipal.setEnabledAt(3, true);
        TabPrincipal.setEnabledAt(4, true);
        TabPrincipal.setEnabledAt(5, true);
        TabPrincipal.setEnabledAt(6, true);
        try {
            File f = new File(xmlFilename);
            if (f.exists() && !f.isDirectory()) {
                //MultiLabelInstances dataset_temp = new MultiLabelInstances(filename_database_arff, xmlFilename);
            } else {
                JOptionPane.showMessageDialog(null, "File could not be loaded.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }
        } catch (Exception ex) {
            Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            progressBar.setVisible(false);
            progressFrame.setVisible(false);
            progressFrame.repaint();
            return -1;
        }
        initTableMetrics();
        clearTableMetricsPrincipal();
        // delegate the actual loading to the two-argument overload
        File f = new File(xmlFilename);
        if (f.exists() && !f.isDirectory()) {
            loadDataset(arffFilename, xmlFilename);
        } else {
            loadDataset(arffFilename, null);
        }
        if (deleteXML) {
            File f2 = new File(xmlFilename);
            f2.delete();
        }
        textChooseFile.setText(arffFilename);
    }
    // for multi-view datasets: compute per-view metrics and fill jTable2
    if (mv) {
        if (((DefaultTableModel) jTable2.getModel()).getRowCount() > 0) {
            ((DefaultTableModel) jTable2.getModel()).getDataVector().removeAllElements();
        }
        for (int i = 0; i < views.size(); i++) {
            MultiLabelInstances view = dataset.clone();
            try {
                Instances inst = view.getDataSet();
                int[] attributes = Utils.toPrimitive(views.get("View " + (i + 1)));
                // keep the view's attributes plus all label attributes
                int[] toKeep = new int[attributes.length + dataset.getNumLabels()];
                System.arraycopy(attributes, 0, toKeep, 0, attributes.length);
                int[] labelIndices = dataset.getLabelIndices();
                System.arraycopy(labelIndices, 0, toKeep, attributes.length, dataset.getNumLabels());
                Remove filterRemove = new Remove();
                filterRemove.setAttributeIndicesArray(toKeep);
                filterRemove.setInvertSelection(true); // keep the listed indices
                filterRemove.setInputFormat(inst);
                MultiLabelInstances modifiedDataset = new MultiLabelInstances(
                        Filter.useFilter(view.getDataSet(), filterRemove), dataset.getLabelsMetaData());
                // per-view characterization metrics
                LxIxF lif = new LxIxF();
                lif.calculate(modifiedDataset);
                RatioInstancesToAttributes ratioInstAtt = new RatioInstancesToAttributes();
                ratioInstAtt.calculate(modifiedDataset);
                AvgGainRatio avgGainRatio = new AvgGainRatio();
                avgGainRatio.calculate(modifiedDataset);
                ((DefaultTableModel) jTable2.getModel()).addRow(
                        new Object[] { "View " + (i + 1), attributes.length, getMetricValueFormatted(lif),
                                getMetricValueFormatted(ratioInstAtt), getMetricValueFormatted(avgGainRatio) });
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }
    return 1;
}
From source file: data.statistics.MLStatistics.java
License: Open Source License
/**
 * Calculates the Phi and Chi-square correlation matrices between all label
 * pairs, from the 2x2 contingency table of label co-occurrence.
 *
 * @param dataSet
 *            A multi-label dataset.
 * @throws java.lang.Exception
 *             To be handled in an upper level.
 */
public void calculatePhiChi2(MultiLabelInstances dataSet) throws Exception {
    numLabels = dataSet.getNumLabels();
    // The indices of the label attributes
    int[] labelIndices;
    labelIndices = dataSet.getLabelIndices();
    numLabels = dataSet.getNumLabels();
    phi = new double[numLabels][numLabels];
    chi2 = new double[numLabels][numLabels];
    // keep ONLY the label attributes (invertSelection keeps the listed indices)
    Remove remove = new Remove();
    remove.setInvertSelection(true);
    remove.setAttributeIndicesArray(labelIndices);
    remove.setInputFormat(dataSet.getDataSet());
    Instances result = Filter.useFilter(dataSet.getDataSet(), remove);
    result.setClassIndex(result.numAttributes() - 1);
    for (int i = 0; i < numLabels; i++) {
        // 2x2 contingency counts of label i vs. every label l:
        // a = both "0", d = both "1", b/c = mixed
        int a[] = new int[numLabels];
        int b[] = new int[numLabels];
        int c[] = new int[numLabels];
        int d[] = new int[numLabels];
        // marginal totals
        double e[] = new double[numLabels];
        double f[] = new double[numLabels];
        double g[] = new double[numLabels];
        double h[] = new double[numLabels];
        for (int j = 0; j < result.numInstances(); j++) {
            for (int l = 0; l < numLabels; l++) {
                if (result.instance(j).stringValue(i).equals("0")) {
                    if (result.instance(j).stringValue(l).equals("0")) {
                        a[l]++;
                    } else {
                        c[l]++;
                    }
                } else {
                    if (result.instance(j).stringValue(l).equals("0")) {
                        b[l]++;
                    } else {
                        d[l]++;
                    }
                }
            }
        }
        for (int l = 0; l < numLabels; l++) {
            e[l] = a[l] + b[l];
            f[l] = c[l] + d[l];
            g[l] = a[l] + c[l];
            h[l] = b[l] + d[l];
            double mult = e[l] * f[l] * g[l] * h[l];
            double denominator = Math.sqrt(mult);
            // NOTE(review): int arithmetic here can overflow for very large
            // datasets, and a zero marginal yields NaN (0/0) — confirm acceptable
            double nominator = a[l] * d[l] - b[l] * c[l];
            // phi coefficient; chi^2 = phi^2 * N
            phi[i][l] = nominator / denominator;
            chi2[i][l] = phi[i][l] * phi[i][l] * (a[l] + b[l] + c[l] + d[l]);
        }
    }
}
From source file: edu.oregonstate.eecs.mcplan.abstraction.EvaluateSimilarityFunction.java
License: Open Source License
/**
 * Evaluates a classifier as a clustering: groups the (class-stripped)
 * feature vectors of the test instances by predicted label (U) and by true
 * label (V), and returns the contingency table between the two partitions.
 *
 * Checked exceptions from Weka are rethrown wrapped as RuntimeException.
 */
public static ClusterContingencyTable evaluateClassifier(final Classifier classifier, final Instances test) {
    try {
        // TreeMap -> partitions are collected in label order
        final Map<Integer, Set<RealVector>> Umap = new TreeMap<Integer, Set<RealVector>>();
        final Map<Integer, Set<RealVector>> Vmap = new TreeMap<Integer, Set<RealVector>>();
        // streaming Remove filter that strips the class attribute
        final Remove rm_filter = new Remove();
        rm_filter.setAttributeIndicesArray(new int[] { test.classIndex() });
        rm_filter.setInputFormat(test);
        for (final Instance i : test) {
            // filter one instance at a time to obtain the unlabeled feature vector
            rm_filter.input(i);
            final double[] phi = rm_filter.output().toDoubleArray();
            // group by predicted label
            final int cluster = (int) classifier.classifyInstance(i);
            Set<RealVector> u = Umap.get(cluster);
            if (u == null) {
                u = new HashSet<RealVector>();
                Umap.put(cluster, u);
            }
            u.add(new ArrayRealVector(phi));
            // group by true label
            final int true_label = (int) i.classValue();
            Set<RealVector> v = Vmap.get(true_label);
            if (v == null) {
                v = new HashSet<RealVector>();
                Vmap.put(true_label, v);
            }
            v.add(new ArrayRealVector(phi));
        }
        // flatten both partitions into lists of clusters
        final ArrayList<Set<RealVector>> U = new ArrayList<Set<RealVector>>();
        for (final Map.Entry<Integer, Set<RealVector>> e : Umap.entrySet()) {
            U.add(e.getValue());
        }
        final ArrayList<Set<RealVector>> V = new ArrayList<Set<RealVector>>();
        for (final Map.Entry<Integer, Set<RealVector>> e : Vmap.entrySet()) {
            V.add(e.getValue());
        }
        return new ClusterContingencyTable(U, V);
    } catch (final RuntimeException ex) {
        throw ex;
    } catch (final Exception ex) {
        throw new RuntimeException(ex);
    }
}