List of usage examples for weka.core Instances relationName
public String relationName()
From source file:cezeri.evaluater.FactoryEvaluation.java
public static Evaluation performCrossValidate(Classifier model, Instances datax, int folds, boolean show_text, boolean show_plot, TFigureAttribute attr) { Random rand = new Random(1); Instances randData = new Instances(datax); randData.randomize(rand);/* ww w . j a va2 s . co m*/ if (randData.classAttribute().isNominal()) { randData.stratify(folds); } Evaluation eval = null; try { // perform cross-validation eval = new Evaluation(randData); // double[] simulated = new double[0]; // double[] observed = new double[0]; // double[] sim = new double[0]; // double[] obs = new double[0]; for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n, rand); Instances validation = randData.testCV(folds, n); // build and evaluate classifier Classifier clsCopy = Classifier.makeCopy(model); clsCopy.buildClassifier(train); // sim = eval.evaluateModel(clsCopy, validation); // obs = validation.attributeToDoubleArray(validation.classIndex()); // if (show_plot) { // double[][] d = new double[2][sim.length]; // d[0] = obs; // d[1] = sim; // CMatrix f1 = CMatrix.getInstance(d); // f1.transpose().plot(attr); // } // if (show_text) { // // output evaluation // System.out.println(); // System.out.println("=== Setup for each Cross Validation fold==="); // System.out.println("Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions())); // System.out.println("Dataset: " + randData.relationName()); // System.out.println("Folds: " + folds); // System.out.println("Seed: " + 1); // System.out.println(); // System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false)); // } simulated = FactoryUtils.concatenate(simulated, eval.evaluateModel(clsCopy, validation)); observed = FactoryUtils.concatenate(observed, validation.attributeToDoubleArray(validation.classIndex())); // simulated = FactoryUtils.mean(simulated,eval.evaluateModel(clsCopy, validation)); // observed = 
FactoryUtils.mean(observed,validation.attributeToDoubleArray(validation.classIndex())); } if (show_plot) { double[][] d = new double[2][simulated.length]; d[0] = observed; d[1] = simulated; CMatrix f1 = CMatrix.getInstance(d); attr.figureCaption = "overall performance"; f1.transpose().plot(attr); } if (show_text) { // output evaluation System.out.println(); System.out.println("=== Setup for Overall Cross Validation==="); System.out.println( "Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions())); System.out.println("Dataset: " + randData.relationName()); System.out.println("Folds: " + folds); System.out.println("Seed: " + 1); System.out.println(); System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false)); } } catch (Exception ex) { Logger.getLogger(FactoryEvaluation.class.getName()).log(Level.SEVERE, null, ex); } return eval; }
From source file:cezeri.evaluater.FactoryEvaluation.java
public static Evaluation performCrossValidateTestAlso(Classifier model, Instances datax, Instances test, boolean show_text, boolean show_plot) { TFigureAttribute attr = new TFigureAttribute(); Random rand = new Random(1); Instances randData = new Instances(datax); randData.randomize(rand);// ww w. jav a 2 s . c om Evaluation eval = null; int folds = randData.numInstances(); try { eval = new Evaluation(randData); for (int n = 0; n < folds; n++) { // randData.randomize(rand); // Instances train = randData; Instances train = randData.trainCV(folds, n); // Instances train = randData.trainCV(folds, n, rand); Classifier clsCopy = Classifier.makeCopy(model); clsCopy.buildClassifier(train); Instances validation = randData.testCV(folds, n); // Instances validation = test.testCV(test.numInstances(), n%test.numInstances()); // CMatrix.fromInstances(train).showDataGrid(); // CMatrix.fromInstances(validation).showDataGrid(); simulated = FactoryUtils.concatenate(simulated, eval.evaluateModel(clsCopy, validation)); observed = FactoryUtils.concatenate(observed, validation.attributeToDoubleArray(validation.classIndex())); } if (show_plot) { double[][] d = new double[2][simulated.length]; d[0] = observed; d[1] = simulated; CMatrix f1 = CMatrix.getInstance(d); attr.figureCaption = "overall performance"; f1.transpose().plot(attr); } if (show_text) { // output evaluation System.out.println(); System.out.println("=== Setup for Overall Cross Validation==="); System.out.println( "Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions())); System.out.println("Dataset: " + randData.relationName()); System.out.println("Folds: " + folds); System.out.println("Seed: " + 1); System.out.println(); System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false)); } } catch (Exception ex) { Logger.getLogger(FactoryEvaluation.class.getName()).log(Level.SEVERE, null, ex); } return eval; }
From source file:cezeri.evaluater.FactoryEvaluation.java
private static Evaluation doTest(boolean isTrained, Classifier model, Instances train, Instances test, boolean show_text, boolean show_plot, TFigureAttribute attr) { Instances data = new Instances(train); Random rand = new Random(1); data.randomize(rand);/*from w w w .j av a 2 s . c o m*/ Evaluation eval = null; try { // double[] simulated = null; eval = new Evaluation(train); if (isTrained) { simulated = eval.evaluateModel(model, test); } else { Classifier clsCopy = Classifier.makeCopy(model); clsCopy.buildClassifier(train); simulated = eval.evaluateModel(clsCopy, test); } if (show_plot) { observed = test.attributeToDoubleArray(test.classIndex()); double[][] d = new double[2][simulated.length]; d[0] = observed; d[1] = simulated; CMatrix f1 = CMatrix.getInstance(d); String[] items = { "Observed", "Simulated" }; attr.items = items; attr.figureCaption = model.getClass().getCanonicalName(); f1.transpose().plot(attr); // if (attr.axis[0].isEmpty() && attr.axis[1].isEmpty()) { // f1.transpose().plot(attr); // } else { // f1.transpose().plot(model.getClass().getCanonicalName(), attr.items, attr.axis); // } } if (show_text) { System.out.println(); System.out.println("=== Setup for Test ==="); System.out.println( "Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions())); System.out.println("Dataset: " + test.relationName()); System.out.println(); System.out.println(eval.toSummaryString("=== Test Results ===", false)); } } catch (Exception ex) { Logger.getLogger(FactoryEvaluation.class.getName()).log(Level.SEVERE, null, ex); } return eval; }
From source file:cezeri.feature.selection.FeatureSelectionInfluence.java
public static Evaluation getEvaluation(Instances randData, Classifier model, int folds) { Evaluation eval = null;/*from w w w . j a va2 s . com*/ try { eval = new Evaluation(randData); for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n); Instances test = randData.testCV(folds, n); // build and evaluate classifier Classifier clsCopy = Classifier.makeCopy(model); clsCopy.buildClassifier(train); eval.evaluateModel(clsCopy, test); // double[] prediction = eval.evaluateModel(clsCopy, test); // double[] original = getAttributeValues(test); // double[][] d = new double[2][prediction.length]; // d[0] = prediction; // d[1] = original; // CMatrix f1 = new CMatrix(d); } // output evaluation System.out.println(); System.out.println("=== Setup ==="); System.out.println( "Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions())); System.out.println("Dataset: " + randData.relationName()); System.out.println("Folds: " + folds); System.out.println(); System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false)); System.out.println(eval.toClassDetailsString("=== Detailed Accuracy By Class ===")); System.out.println(eval.toMatrixString("Confusion Matrix")); double acc = eval.correct() / eval.numInstances() * 100; System.out.println("correct:" + eval.correct() + " " + acc + "%"); } catch (Exception ex) { Logger.getLogger(FeatureSelectionInfluence.class.getName()).log(Level.SEVERE, null, ex); } return eval; }
From source file:cezeri.utils.FactoryInstance.java
/** * * @param m tm dataset//from www .ja v a2 s. c o m * @param val class value deeri val olanlar filtrele * @return */ public static Instances[] getSpecificInstancesBasedOnClassValue(Instances m, String[] cl) { Instances[] ret = new Instances[cl.length]; for (int i = 0; i < ret.length; i++) { ret[i] = FactoryInstance.generateInstances(m.relationName() + "_class=" + cl[i], m.numAttributes()); // ret[i] = m.resampleWithWeights(new Random()); ret[i].delete(); } for (int i = 0; i < m.numInstances(); i++) { Instance ins = m.instance(i); for (int j = 0; j < cl.length; j++) { if (("" + (int) ins.classValue()).equals(cl[j])) { ret[j].add(ins); } } } return ret; }
From source file:cezeri.utils.FactoryInstance.java
/**
 * Builds a new dataset that contains only the attributes named in {@code attList}, in that order,
 * followed by the class attribute of {@code data}.
 *
 * <p>Fixes over the previous implementation:
 * <ul>
 *   <li>All non-class attributes are stripped from the header regardless of where the class
 *       attribute sits. Previously deletion stopped as soon as the class attribute reached
 *       index 0, so attributes located after a non-last class attribute were left in the header
 *       and no longer matched the width of the generated rows.</li>
 *   <li>An attribute name that does not exist in {@code data} is reported with an
 *       {@link IllegalArgumentException} naming it instead of a bare {@link NullPointerException}.</li>
 *   <li>Rows are filled directly from the attribute columns, without the intermediate matrix
 *       transpose.</li>
 * </ul>
 *
 * @param data source dataset; its class index must be set
 * @param attList names of the attributes to keep, in the desired order
 * @return a new dataset with the selected attributes plus the class attribute as last column,
 *         carrying the relation name of {@code data}
 * @throws IllegalArgumentException if a name in {@code attList} is not an attribute of {@code data}
 */
public static Instances getSubsetData(Instances data, String[] attList) {
    Instances temp = new Instances(data);
    // Throws when no class index is set, exactly as the original loop did on its first pass.
    temp.classAttribute();

    // Strip every attribute except the class. Walking backwards keeps the remaining indices
    // stable; classIndex() is re-read because it shifts when an earlier attribute is removed.
    for (int i = temp.numAttributes() - 1; i >= 0; i--) {
        if (i != temp.classIndex()) {
            temp.deleteAttributeAt(i);
        }
    }

    // Collect the value columns and rebuild the header. Inserting at position 0 while walking
    // the list backwards leaves the attributes in attList order, ahead of the class attribute.
    int selected = attList.length;
    double[][] columns = new double[selected + 1][];
    for (int i = selected - 1; i >= 0; i--) {
        Attribute t = data.attribute(attList[i]);
        if (t == null) {
            throw new IllegalArgumentException("Unknown attribute: " + attList[i]);
        }
        columns[i] = data.attributeToDoubleArray(t.index());
        temp.insertAttributeAt(t, 0);
    }
    columns[selected] = data.attributeToDoubleArray(data.classIndex());

    FastVector att = new FastVector();
    for (int i = 0; i < temp.numAttributes(); i++) {
        att.addElement(temp.attribute(i));
    }

    int rows = data.numInstances();
    Instances ret = new Instances(temp.relationName(), att, rows);
    for (int r = 0; r < rows; r++) {
        Instance ins = new Instance(columns.length);
        for (int c = 0; c < columns.length; c++) {
            ins.setValue(c, columns[c][r]);
        }
        ret.add(ins);
    }
    ret.setClassIndex(temp.classIndex());
    return ret;
}
From source file:com.sliit.views.DataVisualizerPanel.java
void getRocCurve() { try {/* w w w . j a v a2 s . co m*/ Instances data; data = new Instances(new BufferedReader(new FileReader(datasetPathText.getText()))); data.setClassIndex(data.numAttributes() - 1); // train classifier Classifier cl = new NaiveBayes(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(cl, data, 10, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) { cp[n] = true; } tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); // display curve String plotName = vmc.getName(); final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName); jf.setSize(500, 400); jf.getContentPane().setLayout(new BorderLayout()); jf.getContentPane().add(vmc, BorderLayout.CENTER); jf.addWindowListener(new java.awt.event.WindowAdapter() { public void windowClosing(java.awt.event.WindowEvent e) { jf.dispose(); } }); jf.setVisible(true); } catch (IOException ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } catch (Exception ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } }
From source file:com.sliit.views.KNNView.java
void getRocCurve() { try {/* w ww . j a v a 2s . c om*/ Instances data; data = new Instances(new BufferedReader(new java.io.FileReader(PredictorPanel.modalText.getText()))); data.setClassIndex(data.numAttributes() - 1); // train classifier Classifier cl = new NaiveBayes(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(cl, data, 10, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) { cp[n] = true; } tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); rocPanel.removeAll(); rocPanel.add(vmc, "vmc", 0); rocPanel.revalidate(); } catch (IOException ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } catch (Exception ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } }
From source file:com.sliit.views.SVMView.java
/** * draw ROC curve// w w w . j av a 2 s. c om */ void getRocCurve() { try { Instances data; data = new Instances(new BufferedReader(new FileReader(PredictorPanel.modalText.getText()))); data.setClassIndex(data.numAttributes() - 1); //train classifier Classifier cl = new NaiveBayes(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(cl, data, 10, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) { cp[n] = true; } tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); // rocPanel.removeAll(); // rocPanel.add(vmc, "vmc", 0); // rocPanel.revalidate(); } catch (IOException ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } catch (Exception ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } }
From source file:com.yahoo.labs.samoa.instances.WekaToSamoaInstanceConverter.java
License:Apache License
/** * Samoa instances information./* w w w. j a v a2 s .com*/ * * @param instances the instances * @return the instances */ public Instances samoaInstancesInformation(weka.core.Instances instances) { Instances samoaInstances; List<Attribute> attInfo = new ArrayList<Attribute>(); for (int i = 0; i < instances.numAttributes(); i++) { attInfo.add(samoaAttribute(i, instances.attribute(i))); } samoaInstances = new Instances(instances.relationName(), attInfo, 0); samoaInstances.setClassIndex(instances.classIndex()); return samoaInstances; }