List of usage examples for weka.core Instances setClassIndex
public void setClassIndex(int classIndex)
From source file:core.classification.Classifiers.java
License:Open Source License
public void trainYNC() throws Exception { // ---//from ww w. ja v a 2s . c o m // Retrieve the instances in the database // --- InstanceQuery query = new InstanceQuery(); query.setDatabaseURL(dbase); query.setUsername(""); query.setPassword(""); String sql = "SELECT "; sql += "YNCdata.PCLASS, YNCdata.CCLASS, YNCdata.RAREA, YNCdata.H, YNCdata.D, YNCdata.V, "; sql += "YNCdata.YN "; sql += "FROM YNCdata "; query.setQuery(sql); Instances data = query.retrieveInstances(); // --- // Setting options // --- String[] options = Utils.splitOptions("-R -N 3 -Q 1 -M 30"); YNC.setOptions(options); data.setClassIndex(data.numAttributes() - 1); // --- // Train // --- System.out.println("Building YC..."); YNC.buildClassifier(data); System.out.println("Done."); // --- // Evaluation // --- System.out.println("Cross-validation for YNC..."); Evaluation eval = new Evaluation(data); eval.crossValidateModel(YNC, data, 10, new Random(1)); System.out.println("Done."); System.out.println(eval.toSummaryString("\n Results for YNC: \n\n", false)); }
From source file:core.ClusterEvaluationEX.java
License:Open Source License
/**
 * Evaluates a clusterer with the options given in an array of strings.
 *
 * Options read by this method:
 * <ul>
 * <li>"-h" / "-help": throws an exception whose message carries the option
 *     synopsis (plus global info when "-synopsis" or "-info" is present).</li>
 * <li>"-t": training file; "-T": test file.</li>
 * <li>"-l": object file to load a clusterer from; "-d": object file the
 *     clusterer (and training header) is saved to.</li>
 * <li>"-g": file the graph is written to when the clusterer is Drawable.</li>
 * <li>"-p": attribute range; when present only the cluster assignments are
 *     returned ("0" means no attributes are listed).</li>
 * <li>"-s": random seed (default 1).</li>
 * <li>"-x": number of folds (default 10). Cross-validation is only run when
 *     this option is given, no test file and no "-l" file are supplied, and
 *     the clusterer is a DensityBasedClusterer.</li>
 * <li>"-c": class attribute for class based evaluation: "first", "last" or
 *     a 1-based index. Without "-c", a class index already defined by the
 *     training data is used.</li>
 * </ul>
 * Remaining options are handed to the clusterer if it is an OptionHandler.
 *
 * NOTE: when a class attribute is in effect and no clusterer is loaded from
 * file, the method returns the class based training statistics directly and
 * none of the later steps (predictions, test statistics, saving, graph) run.
 *
 * @param clusterer machine learning clusterer
 * @param options the array of string containing the options
 * @throws Exception if model could not be evaluated successfully
 * @return a string describing the results
 */
public static String evaluateClusterer(Clusterer clusterer, String[] options) throws Exception {
    int seed = 1, folds = 10;
    boolean doXval = false;
    Instances train = null;
    Random random;
    String trainFileName, testFileName, seedString, foldsString;
    String objectInputFileName, objectOutputFileName, attributeRangeString;
    String graphFileName;
    String[] savedOptions = null;
    boolean printClusterAssignments = false;
    Range attributesToOutput = null;
    StringBuffer text = new StringBuffer();
    int theClass = -1; // class based evaluation of clustering (1-based attribute number, -1 = none)
    boolean updateable = (clusterer instanceof UpdateableClusterer);
    DataSource source = null;
    Instance inst;

    if (Utils.getFlag('h', options) || Utils.getFlag("help", options)) {
        // global info requested as well?
        boolean globalInfo = Utils.getFlag("synopsis", options) || Utils.getFlag("info", options);
        // Help is delivered through the exception message.
        throw new Exception("Help requested."
                + makeOptionString(clusterer, globalInfo));
    }

    try {
        // Get basic options (options the same for all clusterers)
        //printClusterAssignments = Utils.getFlag('p', options);
        objectInputFileName = Utils.getOption('l', options);
        objectOutputFileName = Utils.getOption('d', options);
        trainFileName = Utils.getOption('t', options);
        testFileName = Utils.getOption('T', options);
        graphFileName = Utils.getOption('g', options);

        // Check -p option
        try {
            attributeRangeString = Utils.getOption('p', options);
        } catch (Exception e) {
            throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. "
                    + "It now expects a parameter specifying a range of attributes "
                    + "to list with the predictions. Use '-p 0' for none.");
        }
        if (attributeRangeString.length() != 0) {
            printClusterAssignments = true;
            // "0" means: print assignments, but list no attributes with them
            if (!attributeRangeString.equals("0"))
                attributesToOutput = new Range(attributeRangeString);
        }

        // Without a training file, both a model file and a test file are required.
        if (trainFileName.length() == 0) {
            if (objectInputFileName.length() == 0) {
                throw new Exception("No training file and no object " + "input file given.");
            }
            if (testFileName.length() == 0) {
                throw new Exception("No training file and no test file given.");
            }
        } else {
            if ((objectInputFileName.length() != 0) && (printClusterAssignments == false)) {
                throw new Exception("Can't use both train and model file " + "unless -p specified.");
            }
        }

        seedString = Utils.getOption('s', options);
        if (seedString.length() != 0) {
            seed = Integer.parseInt(seedString);
        }
        foldsString = Utils.getOption('x', options);
        if (foldsString.length() != 0) {
            folds = Integer.parseInt(foldsString);
            // cross-validation is opt-in: only requested when -x is present
            doXval = true;
        }
    } catch (Exception e) {
        // Re-wrap so the caller sees the problem followed by the option synopsis.
        throw new Exception('\n' + e.getMessage() + makeOptionString(clusterer, false));
    }

    try {
        if (trainFileName.length() != 0) {
            source = new DataSource(trainFileName);
            train = source.getStructure();

            String classString = Utils.getOption('c', options);
            if (classString.length() != 0) {
                // theClass is kept 1-based here; it is converted to 0-based below
                if (classString.compareTo("last") == 0)
                    theClass = train.numAttributes();
                else if (classString.compareTo("first") == 0)
                    theClass = 1;
                else
                    theClass = Integer.parseInt(classString);
                if (theClass != -1) {
                    // an explicit -c is incompatible with test/xval, loading and saving
                    if (doXval || testFileName.length() != 0)
                        throw new Exception("Can only do class based evaluation on the " + "training data");
                    if (objectInputFileName.length() != 0)
                        throw new Exception("Can't load a clusterer and do class based " + "evaluation");
                    if (objectOutputFileName.length() != 0)
                        throw new Exception("Can't do class based evaluation and save clusterer");
                }
            } else {
                // if the dataset defines a class attribute, use it
                if (train.classIndex() != -1) {
                    theClass = train.classIndex() + 1;
                    System.err
                            .println("Note: using class attribute from dataset, i.e., attribute #" + theClass);
                }
            }
            if (theClass != -1) {
                if (theClass < 1 || theClass > train.numAttributes())
                    throw new Exception("Class is out of range!");
                if (!train.attribute(theClass - 1).isNominal())
                    throw new Exception("Class must be nominal!");
                train.setClassIndex(theClass - 1);
            }
        }
    } catch (Exception e) {
        throw new Exception("ClusterEvaluation: " + e.getMessage() + '.');
    }

    // Save options (copy taken before the clusterer receives them; used for cross-validation)
    if (options != null) {
        savedOptions = new String[options.length];
        System.arraycopy(options, 0, savedOptions, 0, options.length);
    }

    // A loaded clusterer cannot take scheme options, so nothing may be left over.
    if (objectInputFileName.length() != 0)
        Utils.checkForRemainingOptions(options);

    // Set options for clusterer
    if (clusterer instanceof OptionHandler)
        ((OptionHandler) clusterer).setOptions(options);

    Utils.checkForRemainingOptions(options);

    Instances trainHeader = train;
    if (objectInputFileName.length() != 0) {
        // Load the clusterer from file
        // clusterer = (Clusterer) SerializationHelper.read(objectInputFileName);
        // NOTE(review): this stream is never closed in this method - confirm whether that is intended.
        java.io.ObjectInputStream ois = new java.io.ObjectInputStream(
                new java.io.BufferedInputStream(new java.io.FileInputStream(objectInputFileName)));
        clusterer = (Clusterer) ois.readObject();
        // try and get the training header
        try {
            trainHeader = (Instances) ois.readObject();
        } catch (Exception ex) {
            // don't moan if we cant
        }
    } else {
        // Build the clusterer if no object file provided
        if (theClass == -1) {
            if (updateable) {
                // incremental build: initialise on the header, then feed one instance at a time
                clusterer.buildClusterer(source.getStructure());
                while (source.hasMoreElements(train)) {
                    inst = source.nextElement(train);
                    ((UpdateableClusterer) clusterer).updateClusterer(inst);
                }
                ((UpdateableClusterer) clusterer).updateFinished();
            } else {
                clusterer.buildClusterer(source.getDataSet());
            }
        } else {
            // class based evaluation: the class attribute is filtered out before clustering
            Remove removeClass = new Remove();
            removeClass.setAttributeIndices("" + theClass);
            removeClass.setInvertSelection(false);
            removeClass.setInputFormat(train);
            if (updateable) {
                Instances clusterTrain = Filter.useFilter(train, removeClass);
                clusterer.buildClusterer(clusterTrain);
                trainHeader = clusterTrain;
                while (source.hasMoreElements(train)) {
                    inst = source.nextElement(train);
                    removeClass.input(inst);
                    removeClass.batchFinished();
                    Instance clusterTrainInst = removeClass.output();
                    ((UpdateableClusterer) clusterer).updateClusterer(clusterTrainInst);
                }
                ((UpdateableClusterer) clusterer).updateFinished();
            } else {
                Instances clusterTrain = Filter.useFilter(source.getDataSet(), removeClass);
                clusterer.buildClusterer(clusterTrain);
                trainHeader = clusterTrain;
            }
            ClusterEvaluationEX ce = new ClusterEvaluationEX();
            ce.setClusterer(clusterer);
            ce.evaluateClusterer(train, trainFileName);

            // Early exit: nothing below this point runs for class based evaluation.
            return "\n\n=== Clustering stats for training data ===\n\n" + ce.clusterResultsToString();
        }
    }

    /* Output cluster predictions only (for the test data if specified,
       otherwise for the training data */
    if (printClusterAssignments) {
        return printClusterings(clusterer, trainFileName, testFileName, attributesToOutput);
    }

    text.append(clusterer.toString());
    text.append(
            "\n\n=== Clustering stats for training data ===\n\n" + printClusterStats(clusterer, trainFileName));

    if (testFileName.length() != 0) {
        // check header compatibility
        DataSource test = new DataSource(testFileName);
        Instances testStructure = test.getStructure();
        if (!trainHeader.equalHeaders(testStructure)) {
            throw new Exception("Training and testing data are not compatible\n");
        }

        text.append("\n\n=== Clustering stats for testing data ===\n\n"
                + printClusterStats(clusterer, testFileName));
    }

    if ((clusterer instanceof DensityBasedClusterer) && (doXval == true) && (testFileName.length() == 0)
            && (objectInputFileName.length() == 0)) {
        // cross validate the log likelihood on the training data
        random = new Random(seed);
        random.setSeed(seed);
        train = source.getDataSet();
        train.randomize(random);
        // a fresh clusterer is named by class and configured from the saved option copy
        text.append(crossValidateModel(clusterer.getClass().getName(), train, folds, savedOptions, random));
    }

    // Save the clusterer if an object output file is provided
    if (objectOutputFileName.length() != 0) {
        //SerializationHelper.write(objectOutputFileName, clusterer);
        saveClusterer(objectOutputFileName, clusterer, trainHeader);
    }

    // If classifier is drawable output string describing graph
    if ((clusterer instanceof Drawable) && (graphFileName.length() != 0)) {
        // NOTE(review): the writer is not closed if write() throws - confirm whether that matters here.
        BufferedWriter writer = new BufferedWriter(new FileWriter(graphFileName));
        writer.write(((Drawable) clusterer).graph());
        writer.newLine();
        writer.flush();
        writer.close();
    }

    return text.toString();
}
From source file:core.Core.java
/**
 * Trains every configured classifier on {@code src/files/powerpuffgirls.arff},
 * classifies the feature vector returned by {@code classifyImage()} with each
 * of them and renders the per-class probabilities as an HTML fragment.
 *
 * A class whose rounded probability reaches {@code CUT_NOTE} counts as one vote
 * ("blossom" when {@code getClassValue} is true for the label, "bubbles"
 * otherwise); probabilities below the cut are shown as "Rejeitado!". Votes are
 * accumulated over all classifiers and the final heading names the winner;
 * a tie is reported as "Lindinha".
 *
 * @return the HTML fragment with one section per classifier plus the verdict
 * @throws Exception propagated from loading the data, training or classifying
 */
public String run() throws Exception {
    ConverterUtils.DataSource source = new ConverterUtils.DataSource("src/files/powerpuffgirls.arff");

    Map<String, Classifier> classifiers = new HashMap<>();
    classifiers.put("J48", new J48());
    classifiers.put("NaiveBayes", new NaiveBayes());
    classifiers.put("IBk=1", new IBk(1));
    classifiers.put("IBk=3", new IBk(3));
    classifiers.put("MultilayerPerceptron", new MultilayerPerceptron());
    classifiers.put("LibSVM", new LibSVM());

    // The data set has four feature attributes followed by the class attribute.
    Instances ins = source.getDataSet();
    ins.setClassIndex(4);

    StringBuilder sb = new StringBuilder();
    int blossom = 0;
    int bubbles = 0;

    for (Map.Entry<String, Classifier> entry : classifiers.entrySet()) {
        Classifier c = entry.getValue();
        c.buildClassifier(ins);

        // Build the instance to classify: four features, class value left unset.
        Instance test = new Instance(5);
        float[] array = classifyImage();
        test.setDataset(ins);
        test.setValue(0, array[0]);
        test.setValue(1, array[1]);
        test.setValue(2, array[2]);
        test.setValue(3, array[3]);

        double[] prob = c.distributionForInstance(test);

        sb.append("<em>");
        sb.append(entry.getKey());
        sb.append(":</em>");
        sb.append("<br/>");

        for (int i = 0; i < prob.length; i++) {
            String value = test.classAttribute().value(i);
            // Only probabilities at or above the cut count as a vote.
            if (getRoundedValue(prob[i]) >= CUT_NOTE) {
                if (getClassValue(value)) {
                    blossom++;
                } else {
                    bubbles++;
                }
            }
            sb.append(getClassName(value));
            sb.append(": ");
            sb.append("<strong>");
            sb.append(getRoundedValue(prob[i]) < CUT_NOTE ? "Rejeitado!" : getValueFormatted(prob[i]));
            sb.append("</strong>");
            sb.append(" ");
        }
        sb.append("<br/>");

        // Running vote totals after this classifier.
        System.out.println("blossom: " + blossom);
        System.out.println("bubbles: " + bubbles);
        System.out.println("=================\n");
    }

    sb.append(blossom > bubbles ? "<h3> a Florzinha!</h3>" : "<h3> a Lindinha!</h3>");
    return sb.toString();
}
From source file:cotraining.copy.Evaluation_D.java
License:Open Source License
/**
 * Evaluates a classifier with the options given in an array of
 * strings. <p/>
 *
 * Valid options are: <p/>
 *
 * -t name of training file <br/>
 * Name of the file with the training data. (required) <p/>
 *
 * -T name of test file <br/>
 * Name of the file with the test data. If missing a cross-validation
 * is performed. <p/>
 *
 * -c class index <br/>
 * Index of the class attribute (1, 2, ...; default: last). <p/>
 *
 * -x number of folds <br/>
 * The number of folds for the cross-validation (default: 10). <p/>
 *
 * -no-cv <br/>
 * No cross validation. If no test file is provided, no evaluation
 * is done. <p/>
 *
 * -split-percentage percentage <br/>
 * Sets the percentage for the train/test set split, e.g., 66. <p/>
 *
 * -preserve-order <br/>
 * Preserves the order in the percentage split instead of randomizing
 * the data first with the seed value ('-s'). <p/>
 *
 * -s seed <br/>
 * Random number seed for the cross-validation and percentage split
 * (default: 1). <p/>
 *
 * -m file with cost matrix <br/>
 * The name of a file containing a cost matrix. <p/>
 *
 * -l filename <br/>
 * Loads classifier from the given file. In case the filename ends with
 * ".xml", options are loaded from XML (the PMML loading path is commented
 * out in this copy). <p/>
 *
 * -d filename <br/>
 * Saves classifier built from the training data into the given file. In case
 * the filename ends with ".xml" the options are saved XML, not the model. <p/>
 *
 * -v <br/>
 * Outputs no statistics for the training data. <p/>
 *
 * -o <br/>
 * Outputs statistics only, not the classifier. <p/>
 *
 * -i <br/>
 * Outputs detailed information-retrieval statistics per class. <p/>
 *
 * -k <br/>
 * Outputs information-theoretic statistics. <p/>
 *
 * -p range <br/>
 * Outputs predictions for test instances (or the train instances if no test
 * instances provided and -no-cv is used), along with the attributes in the specified range
 * (and nothing else). Use '-p 0' if no attributes are desired. <p/>
 *
 * -distribution <br/>
 * Outputs the distribution instead of only the prediction
 * in conjunction with the '-p' option (only nominal classes). <p/>
 *
 * -r <br/>
 * Outputs cumulative margin distribution (and nothing else). <p/>
 *
 * -g <br/>
 * Only for classifiers that implement "Graphable." Outputs
 * the graph representation of the classifier (and nothing
 * else). <p/>
 *
 * -z class name <br/>
 * Only for classifiers that implement Sourcable: returns the classifier
 * as source code wrapped under the given class name (and nothing else). <p/>
 *
 * -threshold-file file <br/>
 * For nominal classes: writes the threshold curve computed from the test
 * predictions to the given file. <p/>
 *
 * -threshold-label label <br/>
 * Class label the threshold curve is computed for (default: first label). <p/>
 *
 * -xml filename | xml-string <br/>
 * Retrieves the options from the XML-data instead of the command line. <p/>
 *
 * @param classifier machine learning classifier
 * @param options the array of string containing the options
 * @throws Exception if model could not be evaluated successfully
 * @return a string describing the results
 */
public static String evaluateModel(Classifier classifier, String[] options) throws Exception {

    Instances train = null, tempTrain, test = null, template = null;
    int seed = 1, folds = 10, classIndex = -1;
    boolean noCrossValidation = false;
    String trainFileName, testFileName, sourceClass, classIndexString, seedString, foldsString,
            objectInputFileName, objectOutputFileName, attributeRangeString;
    boolean noOutput = false, printClassifications = false, trainStatistics = true, printMargins = false,
            printComplexityStatistics = false, printGraph = false, classStatistics = false,
            printSource = false;
    StringBuffer text = new StringBuffer();
    DataSource trainSource = null, testSource = null;
    ObjectInputStream objectInputStream = null;
    BufferedInputStream xmlInputStream = null;
    CostMatrix costMatrix = null;
    StringBuffer schemeOptionsText = null;
    Range attributesToOutput = null;
    long trainTimeStart = 0, trainTimeElapsed = 0, testTimeStart = 0, testTimeElapsed = 0;
    String xml = "";
    String[] optionsTmp = null;
    Classifier classifierBackup;
    Classifier classifierClassifications = null;
    boolean printDistribution = false;
    int actualClassIndex = -1; // 0-based class index
    String splitPercentageString = "";
    int splitPercentage = -1;
    boolean preserveOrder = false;
    boolean trainSetPresent = false;
    boolean testSetPresent = false;
    String thresholdFile;
    String thresholdLabel;
    StringBuffer predsBuff = null; // predictions from cross-validation

    // help requested?
    if (Utils.getFlag("h", options) || Utils.getFlag("help", options)) {

        // global info requested as well?
        boolean globalInfo = Utils.getFlag("synopsis", options) || Utils.getFlag("info", options);

        // Help is delivered through the exception message.
        throw new Exception("\nHelp requested." + makeOptionString(classifier, globalInfo));
    }

    try {
        // do we get the input from XML instead of normal parameters?
        xml = Utils.getOption("xml", options);
        if (!xml.equals(""))
            options = new XMLOptions(xml).toArray();

        // is the input model only the XML-Options, i.e. w/o built model?
        // (the probe below works on a copy so that '-l' stays available in options)
        optionsTmp = new String[options.length];
        for (int i = 0; i < options.length; i++)
            optionsTmp[i] = options[i];

        String tmpO = Utils.getOption('l', optionsTmp);
        //if (Utils.getOption('l', optionsTmp).toLowerCase().endsWith(".xml")) {
        if (tmpO.endsWith(".xml")) {
            // try to load file as PMML first
            // NOTE(review): the PMML code is commented out, so success is always false here
            // and every ".xml" model file is treated as serialized options.
            boolean success = false;
            try {
                //PMMLModel pmmlModel = PMMLFactory.getPMMLModel(tmpO);
                //if (pmmlModel instanceof PMMLClassifier) {
                //classifier = ((PMMLClassifier)pmmlModel);
                // success = true;
                //}
            } catch (IllegalArgumentException ex) {
                success = false;
            }
            if (!success) {
                // load options from serialized data ('-l' is automatically erased!)
                XMLClassifier xmlserial = new XMLClassifier();
                Classifier cl = (Classifier) xmlserial.read(Utils.getOption('l', options));

                // merge options (stored classifier options first, then the command line ones)
                optionsTmp = new String[options.length + cl.getOptions().length];
                System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length);
                System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length);
                options = optionsTmp;
            }
        }

        noCrossValidation = Utils.getFlag("no-cv", options);
        // Get basic options (options the same for all schemes)
        classIndexString = Utils.getOption('c', options);
        if (classIndexString.length() != 0) {
            // classIndex is 1-based; -1 stands for "last attribute"
            if (classIndexString.equals("first"))
                classIndex = 1;
            else if (classIndexString.equals("last"))
                classIndex = -1;
            else
                classIndex = Integer.parseInt(classIndexString);
        }
        trainFileName = Utils.getOption('t', options);
        objectInputFileName = Utils.getOption('l', options);
        objectOutputFileName = Utils.getOption('d', options);
        testFileName = Utils.getOption('T', options);
        foldsString = Utils.getOption('x', options);
        if (foldsString.length() != 0) {
            folds = Integer.parseInt(foldsString);
        }
        seedString = Utils.getOption('s', options);
        if (seedString.length() != 0) {
            seed = Integer.parseInt(seedString);
        }
        // Without a training file, both a model file and a test file are required.
        if (trainFileName.length() == 0) {
            if (objectInputFileName.length() == 0) {
                throw new Exception("No training file and no object " + "input file given.");
            }
            if (testFileName.length() == 0) {
                throw new Exception("No training file and no test " + "file given.");
            }
        } else if ((objectInputFileName.length() != 0)
                && ((!(classifier instanceof UpdateableClassifier)) || (testFileName.length() == 0))) {
            throw new Exception("Classifier not incremental, or no " + "test file provided: can't "
                    + "use both train and model file.");
        }
        try {
            if (trainFileName.length() != 0) {
                trainSetPresent = true;
                trainSource = new DataSource(trainFileName);
            }
            if (testFileName.length() != 0) {
                testSetPresent = true;
                testSource = new DataSource(testFileName);
            }
            if (objectInputFileName.length() != 0) {
                if (objectInputFileName.endsWith(".xml")) {
                    // if this is the case then it means that a PMML classifier was
                    // successfully loaded earlier in the code
                    // NOTE(review): stale - PMML loading is commented out above; verify this branch.
                    objectInputStream = null;
                    xmlInputStream = null;
                } else {
                    InputStream is = new FileInputStream(objectInputFileName);
                    if (objectInputFileName.endsWith(".gz")) {
                        is = new GZIPInputStream(is);
                    }
                    // load from KOML?
                    if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent())) {
                        objectInputStream = new ObjectInputStream(is);
                        xmlInputStream = null;
                    } else {
                        objectInputStream = null;
                        xmlInputStream = new BufferedInputStream(is);
                    }
                }
            }
        } catch (Exception e) {
            throw new Exception("Can't open file " + e.getMessage() + '.');
        }
        if (testSetPresent) {
            template = test = testSource.getStructure();
            if (classIndex != -1) {
                test.setClassIndex(classIndex - 1);
            } else {
                // fall back to the last attribute unless the file already defines a class
                // (an explicit "-c last" always forces the last attribute)
                if ((test.classIndex() == -1) || (classIndexString.length() != 0))
                    test.setClassIndex(test.numAttributes() - 1);
            }
            actualClassIndex = test.classIndex();
        } else {
            // percentage split
            splitPercentageString = Utils.getOption("split-percentage", options);
            if (splitPercentageString.length() != 0) {
                if (foldsString.length() != 0)
                    throw new Exception("Percentage split cannot be used in conjunction with "
                            + "cross-validation ('-x').");
                splitPercentage = Integer.parseInt(splitPercentageString);
                if ((splitPercentage <= 0) || (splitPercentage >= 100))
                    throw new Exception("Percentage split value needs be >0 and <100.");
            } else {
                splitPercentage = -1;
            }
            preserveOrder = Utils.getFlag("preserve-order", options);
            if (preserveOrder) {
                // NOTE(review): the message names '-percentage-split' but the option parsed
                // above is "split-percentage".
                if (splitPercentage == -1)
                    throw new Exception("Percentage split ('-percentage-split') is missing.");
            }
            // create new train/test sources
            if (splitPercentage > 0) {
                testSetPresent = true;
                // NOTE(review): actualClassIndex is still -1 at this point (it is only
                // assigned further down) - confirm getDataSet(-1) behaves as intended.
                Instances tmpInst = trainSource.getDataSet(actualClassIndex);
                if (!preserveOrder)
                    tmpInst.randomize(new Random(seed));
                // integer arithmetic: the train part is rounded down, the rest is the test part
                int trainSize = tmpInst.numInstances() * splitPercentage / 100;
                int testSize = tmpInst.numInstances() - trainSize;
                Instances trainInst = new Instances(tmpInst, 0, trainSize);
                Instances testInst = new Instances(tmpInst, trainSize, testSize);
                trainSource = new DataSource(trainInst);
                testSource = new DataSource(testInst);
                template = test = testSource.getStructure();
                if (classIndex != -1) {
                    test.setClassIndex(classIndex - 1);
                } else {
                    if ((test.classIndex() == -1) || (classIndexString.length() != 0))
                        test.setClassIndex(test.numAttributes() - 1);
                }
                actualClassIndex = test.classIndex();
            }
        }
        if (trainSetPresent) {
            // the training header takes precedence as template when both are present
            template = train = trainSource.getStructure();
            if (classIndex != -1) {
                train.setClassIndex(classIndex - 1);
            } else {
                if ((train.classIndex() == -1) || (classIndexString.length() != 0))
                    train.setClassIndex(train.numAttributes() - 1);
            }
            actualClassIndex = train.classIndex();
            if ((testSetPresent) && !test.equalHeaders(train)) {
                throw new IllegalArgumentException("Train and test file not compatible!");
            }
        }
        if (template == null) {
            throw new Exception("No actual dataset provided to use as template");
        }
        costMatrix = handleCostOption(Utils.getOption('m', options), template.numClasses());

        classStatistics = Utils.getFlag('i', options);
        noOutput = Utils.getFlag('o', options);
        trainStatistics = !Utils.getFlag('v', options);
        printComplexityStatistics = Utils.getFlag('k', options);
        printMargins = Utils.getFlag('r', options);
        printGraph = Utils.getFlag('g', options);
        sourceClass = Utils.getOption('z', options);
        printSource = (sourceClass.length() != 0);
        printDistribution = Utils.getFlag("distribution", options);
        thresholdFile = Utils.getOption("threshold-file", options);
        thresholdLabel = Utils.getOption("threshold-label", options);

        // Check -p option
        try {
            attributeRangeString = Utils.getOption('p', options);
        } catch (Exception e) {
            throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. "
                    + "It now expects a parameter specifying a range of attributes "
                    + "to list with the predictions. Use '-p 0' for none.");
        }
        if (attributeRangeString.length() != 0) {
            // -p implies that the model itself is not printed
            printClassifications = true;
            noOutput = true;
            if (!attributeRangeString.equals("0"))
                attributesToOutput = new Range(attributeRangeString);
        }

        if (!printClassifications && printDistribution)
            throw new Exception("Cannot print distribution without '-p' option!");

        // if no training file given, we don't have any priors
        if ((!trainSetPresent) && (printComplexityStatistics))
            throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!");

        // If a model file is given, we can't process
        // scheme-specific options
        if (objectInputFileName.length() != 0) {
            Utils.checkForRemainingOptions(options);
        } else {

            // Set options for classifier
            if (classifier instanceof OptionHandler) {
                // record the remaining (scheme specific) options for the report,
                // quoting any option that contains a blank
                for (int i = 0; i < options.length; i++) {
                    if (options[i].length() != 0) {
                        if (schemeOptionsText == null) {
                            schemeOptionsText = new StringBuffer();
                        }
                        if (options[i].indexOf(' ') != -1) {
                            schemeOptionsText.append('"' + options[i] + "\" ");
                        } else {
                            schemeOptionsText.append(options[i] + " ");
                        }
                    }
                }
                ((OptionHandler) classifier).setOptions(options);
            }
        }
        Utils.checkForRemainingOptions(options);
    } catch (Exception e) {
        // Re-wrap so the caller sees the problem followed by the option synopsis.
        throw new Exception("\nWeka exception: " + e.getMessage() + makeOptionString(classifier, false));
    }

    // Setup up evaluation objects
    Evaluation_D trainingEvaluation = new Evaluation_D(new Instances(template, 0), costMatrix);
    Evaluation_D testingEvaluation = new Evaluation_D(new Instances(template, 0), costMatrix);

    // disable use of priors if no training file given
    if (!trainSetPresent)
        testingEvaluation.useNoPriors();

    if (objectInputFileName.length() != 0) {
        // Load classifier from file
        if (objectInputStream != null) {
            classifier = (Classifier) objectInputStream.readObject();
            // try and read a header (if present)
            Instances savedStructure = null;
            try {
                savedStructure = (Instances) objectInputStream.readObject();
            } catch (Exception ex) {
                // don't make a fuss
            }
            if (savedStructure != null) {
                // test for compatibility with template
                if (!template.equalHeaders(savedStructure)) {
                    throw new Exception("training and test set are not compatible");
                }
            }
            objectInputStream.close();
        } else if (xmlInputStream != null) {
            // whether KOML is available has already been checked (objectInputStream would null otherwise)!
            classifier = (Classifier) KOML.read(xmlInputStream);
            xmlInputStream.close();
        }
    }

    // backup of fully setup classifier for cross-validation
    classifierBackup = Classifier.makeCopy(classifier);

    // Build the classifier if no object file provided
    if ((classifier instanceof UpdateableClassifier) && (testSetPresent || noCrossValidation)
            && (costMatrix == null) && (trainSetPresent)) {
        // Build classifier incrementally
        trainingEvaluation.setPriors(train);
        testingEvaluation.setPriors(train);
        trainTimeStart = System.currentTimeMillis();
        // a classifier loaded from file is only updated, not rebuilt
        if (objectInputFileName.length() == 0) {
            classifier.buildClassifier(train);
        }
        Instance trainInst;
        while (trainSource.hasMoreElements(train)) {
            trainInst = trainSource.nextElement(train);
            trainingEvaluation.updatePriors(trainInst);
            testingEvaluation.updatePriors(trainInst);
            ((UpdateableClassifier) classifier).updateClassifier(trainInst);
        }
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    } else if (objectInputFileName.length() == 0) {
        // Build classifier in one go
        tempTrain = trainSource.getDataSet(actualClassIndex);
        trainingEvaluation.setPriors(tempTrain);
        testingEvaluation.setPriors(tempTrain);
        trainTimeStart = System.currentTimeMillis();
        classifier.buildClassifier(tempTrain);
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    }

    // backup of fully trained classifier for printing the classifications
    if (printClassifications)
        classifierClassifications = Classifier.makeCopy(classifier);

    // Save the classifier if an object output file is provided
    if (objectOutputFileName.length() != 0) {
        // NOTE(review): os is not closed if one of the writes below throws - confirm.
        OutputStream os = new FileOutputStream(objectOutputFileName);
        // binary
        if (!(objectOutputFileName.endsWith(".xml")
                || (objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) {
            if (objectOutputFileName.endsWith(".gz")) {
                os = new GZIPOutputStream(os);
            }
            ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
            objectOutputStream.writeObject(classifier);
            // template cannot be null here (checked above); the header is always written
            if (template != null) {
                objectOutputStream.writeObject(template);
            }
            objectOutputStream.flush();
            objectOutputStream.close();
        }
        // KOML/XML
        else {
            BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
            if (objectOutputFileName.endsWith(".xml")) {
                XMLSerialization xmlSerial = new XMLClassifier();
                xmlSerial.write(xmlOutputStream, classifier);
            } else
            // whether KOML is present has already been checked
            // if not present -> ".koml" is interpreted as binary - see above
            if (objectOutputFileName.endsWith(".koml")) {
                KOML.write(xmlOutputStream, classifier);
            }
            xmlOutputStream.close();
        }
    }

    // If classifier is drawable output string describing graph
    if ((classifier instanceof Drawable) && (printGraph)) {
        return ((Drawable) classifier).graph();
    }

    // Output the classifier as equivalent source
    if ((classifier instanceof Sourcable) && (printSource)) {
        return wekaStaticWrapper((Sourcable) classifier, sourceClass);
    }

    // Output model
    if (!(noOutput || printMargins)) {
        if (classifier instanceof OptionHandler) {
            if (schemeOptionsText != null) {
                text.append("\nOptions: " + schemeOptionsText);
                text.append("\n");
            }
        }
        text.append("\n" + classifier.toString() + "\n");
    }

    if (!printMargins && (costMatrix != null)) {
        text.append("\n=== Evaluation Cost Matrix ===\n\n");
        text.append(costMatrix.toString());
    }

    // Output test instance predictions only
    if (printClassifications) {
        DataSource source = testSource;
        predsBuff = new StringBuffer();
        // no test set -> use train set
        if (source == null && noCrossValidation) {
            source = trainSource;
            predsBuff.append("\n=== Predictions on training data ===\n\n");
        } else {
            predsBuff.append("\n=== Predictions on test data ===\n\n");
        }
        if (source != null) {
            /*      return printClassifications(classifierClassifications, new Instances(template, 0),
                    source, actualClassIndex + 1, attributesToOutput,
                    printDistribution); */
            // predictions are collected in predsBuff and appended to the report at the end
            printClassifications(classifierClassifications, new Instances(template, 0), source,
                    actualClassIndex + 1, attributesToOutput, printDistribution, predsBuff);
            //        return predsText.toString();
        }
    }

    // Compute error estimate from training data
    if ((trainStatistics) && (trainSetPresent)) {

        if ((classifier instanceof UpdateableClassifier) && (testSetPresent) && (costMatrix == null)) {

            // Classifier was trained incrementally, so we have to
            // reset the source.
            trainSource.reset();

            // Incremental testing
            train = trainSource.getStructure(actualClassIndex);
            testTimeStart = System.currentTimeMillis();
            Instance trainInst;
            while (trainSource.hasMoreElements(train)) {
                trainInst = trainSource.nextElement(train);
                trainingEvaluation.evaluateModelOnce((Classifier) classifier, trainInst);
            }
            testTimeElapsed = System.currentTimeMillis() - testTimeStart;
        } else {
            testTimeStart = System.currentTimeMillis();
            trainingEvaluation.evaluateModel(classifier, trainSource.getDataSet(actualClassIndex));
            testTimeElapsed = System.currentTimeMillis() - testTimeStart;
        }

        // Print the results of the training evaluation
        if (printMargins) {
            return trainingEvaluation.toCumulativeMarginDistributionString();
        } else {
            if (!printClassifications) {
                text.append("\nTime taken to build model: "
                        + Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds");

                if (splitPercentage > 0)
                    text.append("\nTime taken to test model on training split: ");
                else
                    text.append("\nTime taken to test model on training data: ");
                text.append(Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds");

                if (splitPercentage > 0)
                    text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training" + " split ===\n",
                            printComplexityStatistics));
                else
                    text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training" + " data ===\n",
                            printComplexityStatistics));

                if (template.classAttribute().isNominal()) {
                    if (classStatistics) {
                        text.append("\n\n" + trainingEvaluation.toClassDetailsString());
                    }
                    // the training confusion matrix is suppressed together with cross-validation
                    if (!noCrossValidation)
                        text.append("\n\n" + trainingEvaluation.toMatrixString());
                }
            }
        }
    }

    // Compute proper error estimates
    if (testSource != null) {
        // Testing is on the supplied test data
        testSource.reset();
        test = testSource.getStructure(test.classIndex());
        Instance testInst;
        while (testSource.hasMoreElements(test)) {
            testInst = testSource.nextElement(test);
            testingEvaluation.evaluateModelOnceAndRecordPrediction((Classifier) classifier, testInst);
        }

        if (splitPercentage > 0) {
            if (!printClassifications) {
                text.append("\n\n" + testingEvaluation.toSummaryString("=== Error on test split ===\n",
                        printComplexityStatistics));
            }
        } else {
            if (!printClassifications) {
                text.append("\n\n" + testingEvaluation.toSummaryString("=== Error on test data ===\n",
                        printComplexityStatistics));
            }
        }

    } else if (trainSource != null) {
        if (!noCrossValidation) {
            // Testing is via cross-validation on training data
            Random random = new Random(seed);
            // use untrained (!) classifier for cross-validation
            classifier = Classifier.makeCopy(classifierBackup);
            if (!printClassifications) {
                testingEvaluation.crossValidateModel(classifier, trainSource.getDataSet(actualClassIndex),
                        folds, random);
                if (template.classAttribute().isNumeric()) {
                    text.append("\n\n\n" + testingEvaluation.toSummaryString("=== Cross-validation ===\n",
                            printComplexityStatistics));
                } else {
                    text.append("\n\n\n" + testingEvaluation.toSummaryString(
                            "=== Stratified " + "cross-validation ===\n", printComplexityStatistics));
                }
            } else {
                predsBuff = new StringBuffer();
                predsBuff.append("\n=== Predictions under cross-validation ===\n\n");
                testingEvaluation.crossValidateModel(classifier, trainSource.getDataSet(actualClassIndex),
                        folds, random, predsBuff, attributesToOutput, new Boolean(printDistribution));
                /* if (template.classAttribute().isNumeric()) {
                     text.append("\n\n\n" + testingEvaluation.
                         toSummaryString("=== Cross-validation ===\n",
                             printComplexityStatistics));
                   } else {
                     text.append("\n\n\n" + testingEvaluation.
                         toSummaryString("=== Stratified " +
                             "cross-validation ===\n",
                             printComplexityStatistics));
                   } */
            }
        }
    }
    if (template.classAttribute().isNominal()) {
        if (classStatistics && !noCrossValidation && !printClassifications) {
            text.append("\n\n" + testingEvaluation.toClassDetailsString());
        }
        if (!noCrossValidation && !printClassifications)
            text.append("\n\n" + testingEvaluation.toMatrixString());
    }

    // predictions from cross-validation?
    if (predsBuff != null) {
        text.append("\n" + predsBuff);
    }

    if ((thresholdFile.length() != 0) && template.classAttribute().isNominal()) {
        // default to the first class label unless -threshold-label names another one
        int labelIndex = 0;
        if (thresholdLabel.length() != 0)
            labelIndex = template.classAttribute().indexOfValue(thresholdLabel);
        if (labelIndex == -1)
            throw new IllegalArgumentException("Class label '" + thresholdLabel + "' is unknown!");
        ThresholdCurve tc = new ThresholdCurve();
        Instances result = tc.getCurve(testingEvaluation.predictions(), labelIndex);
        DataSink.write(thresholdFile, result);
    }

    return text.toString();
}
From source file:cotraining.copy.Evaluation_D.java
License:Open Source License
/** * Prints the predictions for the given dataset into a supplied StringBuffer * /*www . j a va 2 s . c o m*/ * @param classifier the classifier to use * @param train the training data * @param testSource the test set * @param classIndex the class index (1-based), if -1 ot does not * override the class index is stored in the data * file (by using the last attribute) * @param attributesToOutput the indices of the attributes to output * @param printDistribution prints the complete distribution for nominal * classes, not just the predicted value * @param text StringBuffer to hold the printed predictions * @throws Exception if test file cannot be opened */ public static void printClassifications(Classifier classifier, Instances train, DataSource testSource, int classIndex, Range attributesToOutput, boolean printDistribution, StringBuffer text) throws Exception { if (testSource != null) { Instances test = testSource.getStructure(); if (classIndex != -1) { test.setClassIndex(classIndex - 1); } else { if (test.classIndex() == -1) test.setClassIndex(test.numAttributes() - 1); } // print the header printClassificationsHeader(test, attributesToOutput, printDistribution, text); // print predictions int i = 0; testSource.reset(); test = testSource.getStructure(test.classIndex()); while (testSource.hasMoreElements(test)) { Instance inst = testSource.nextElement(test); text.append(predictionText(classifier, inst, i, attributesToOutput, printDistribution)); i++; } } // return text.toString(); }
From source file:cs.man.ac.uk.predict.Predictor.java
License:Open Source License
public static void makePredictionsEnsembleNew(String trainPath, String testPath, String resultPath) { System.out.println("Training set: " + trainPath); System.out.println("Test set: " + testPath); /**/*from w ww . j a v a2 s . com*/ * The ensemble classifiers. This is a heterogeneous ensemble. */ J48 learner1 = new J48(); SMO learner2 = new SMO(); NaiveBayes learner3 = new NaiveBayes(); MultilayerPerceptron learner5 = new MultilayerPerceptron(); System.out.println("Training Ensemble."); long startTime = System.nanoTime(); try { BufferedReader reader = new BufferedReader(new FileReader(trainPath)); Instances data = new Instances(reader); data.setClassIndex(data.numAttributes() - 1); System.out.println("Training data length: " + data.numInstances()); learner1.buildClassifier(data); learner2.buildClassifier(data); learner3.buildClassifier(data); learner5.buildClassifier(data); long endTime = System.nanoTime(); long nanoseconds = endTime - startTime; double seconds = (double) nanoseconds / 1000000000.0; System.out.println("Training Ensemble completed in " + nanoseconds + " (ns) or " + seconds + " (s)."); } catch (IOException e) { System.out.println("Could not train Ensemble classifier IOException on training data file."); } catch (Exception e) { System.out.println("Could not train Ensemble classifier Exception building model."); } try { String line = ""; // Read the file and display it line by line. BufferedReader in = null; // Read in and store each positive prediction in the tree map. try { //open stream to file in = new BufferedReader(new FileReader(testPath)); while ((line = in.readLine()) != null) { if (line.toLowerCase().contains("@data")) break; } } catch (Exception e) { } // A different ARFF loader used here (compared to above) as // the ARFF file may be extremely large. In which case the whole // file cannot be read in. Instead it is read in incrementally. 
ArffLoader loader = new ArffLoader(); loader.setFile(new File(testPath)); Instances data = loader.getStructure(); data.setClassIndex(data.numAttributes() - 1); System.out.println("Ensemble Classifier is ready."); System.out.println("Testing on all instances avaialable."); startTime = System.nanoTime(); int instanceNumber = 0; // label instances Instance current; while ((current = loader.getNextInstance(data)) != null) { instanceNumber += 1; line = in.readLine(); double classification1 = learner1.classifyInstance(current); double classification2 = learner2.classifyInstance(current); double classification3 = learner3.classifyInstance(current); double classification5 = learner5.classifyInstance(current); // All classifiers must agree. This is a very primitive ensemble strategy! if (classification1 == 1 && classification2 == 1 && classification3 == 1 && classification5 == 1) { if (line != null) { //System.out.println("Instance: "+instanceNumber+"\t"+line); //System.in.read(); } Writer.append(resultPath, instanceNumber + "\n"); } } in.close(); System.out.println("Test set instances: " + instanceNumber); long endTime = System.nanoTime(); long duration = endTime - startTime; double seconds = (double) duration / 1000000000.0; System.out.println("Testing Ensemble completed in " + duration + " (ns) or " + seconds + " (s)."); } catch (Exception e) { System.out.println("Could not test Ensemble classifier due to an error."); } }
From source file:cs.man.ac.uk.predict.Predictor.java
License:Open Source License
public static void makePredictionsEnsembleStream(String trainPath, String testPath, String resultPath) { System.out.println("Training set: " + trainPath); System.out.println("Test set: " + testPath); /**/*ww w . ja va 2 s .c om*/ * The ensemble classifiers. This is a heterogeneous ensemble. */ J48 learner1 = new J48(); SMO learner2 = new SMO(); NaiveBayes learner3 = new NaiveBayes(); MultilayerPerceptron learner5 = new MultilayerPerceptron(); System.out.println("Training Ensemble."); long startTime = System.nanoTime(); try { BufferedReader reader = new BufferedReader(new FileReader(trainPath)); Instances data = new Instances(reader); data.setClassIndex(data.numAttributes() - 1); System.out.println("Training data length: " + data.numInstances()); learner1.buildClassifier(data); learner2.buildClassifier(data); learner3.buildClassifier(data); learner5.buildClassifier(data); long endTime = System.nanoTime(); long nanoseconds = endTime - startTime; double seconds = (double) nanoseconds / 1000000000.0; System.out.println("Training Ensemble completed in " + nanoseconds + " (ns) or " + seconds + " (s)."); } catch (IOException e) { System.out.println("Could not train Ensemble classifier IOException on training data file."); } catch (Exception e) { System.out.println("Could not train Ensemble classifier Exception building model."); } try { // A different ARFF loader used here (compared to above) as // the ARFF file may be extremely large. In which case the whole // file cannot be read in. Instead it is read in incrementally. 
ArffLoader loader = new ArffLoader(); loader.setFile(new File(testPath)); Instances data = loader.getStructure(); data.setClassIndex(data.numAttributes() - 1); System.out.println("Ensemble Classifier is ready."); System.out.println("Testing on all instances avaialable."); startTime = System.nanoTime(); int instanceNumber = 0; // label instances Instance current; while ((current = loader.getNextInstance(data)) != null) { instanceNumber += 1; double classification1 = learner1.classifyInstance(current); double classification2 = learner2.classifyInstance(current); double classification3 = learner3.classifyInstance(current); double classification5 = learner5.classifyInstance(current); // All classifiers must agree. This is a very primitive ensemble strategy! if (classification1 == 1 && classification2 == 1 && classification3 == 1 && classification5 == 1) { Writer.append(resultPath, instanceNumber + "\n"); } } System.out.println("Test set instances: " + instanceNumber); long endTime = System.nanoTime(); long duration = endTime - startTime; double seconds = (double) duration / 1000000000.0; System.out.println("Testing Ensemble completed in " + duration + " (ns) or " + seconds + " (s)."); } catch (Exception e) { System.out.println("Could not test Ensemble classifier due to an error."); } }
From source file:cs.man.ac.uk.predict.Predictor.java
License:Open Source License
public static void makePredictionsJ48(String trainPath, String testPath, String resultPath) { /**/*from ww w . j a v a 2s .c om*/ * The decision tree classifier. */ J48 learner = new J48(); System.out.println("Training set: " + trainPath); System.out.println("Test set: " + testPath); System.out.println("Training J48"); long startTime = System.nanoTime(); try { BufferedReader reader = new BufferedReader(new FileReader(trainPath)); Instances data = new Instances(reader); data.setClassIndex(data.numAttributes() - 1); System.out.println("Training data length: " + data.numInstances()); learner.buildClassifier(data); long endTime = System.nanoTime(); long nanoseconds = endTime - startTime; double seconds = (double) nanoseconds / 1000000000.0; System.out.println("Training J48 completed in " + nanoseconds + " (ns) or " + seconds + " (s)"); } catch (IOException e) { System.out.println("Could not train J48 classifier IOException on training data file"); } catch (Exception e) { System.out.println("Could not train J48 classifier Exception building model"); } try { // Prepare data for testing //BufferedReader reader = new BufferedReader( new FileReader(testPath)); //Instances data = new Instances(reader); //data.setClassIndex(data.numAttributes() - 1); ArffLoader loader = new ArffLoader(); loader.setFile(new File(testPath)); Instances data = loader.getStructure(); data.setClassIndex(data.numAttributes() - 1); System.out.println("J48 Classifier is ready."); System.out.println("Testing on all instances avaialable."); System.out.println("Test set instances: " + data.numInstances()); startTime = System.nanoTime(); int instanceNumber = 0; // label instances Instance current; //for (int i = 0; i < data.numInstances(); i++) while ((current = loader.getNextInstance(data)) != null) { instanceNumber += 1; //double classification = learner.classifyInstance(data.instance(i)); double classification = learner.classifyInstance(current); //String instanceClass= 
Double.toString(data.instance(i).classValue()); if (classification == 1)// Predicted positive, actually negative { Writer.append(resultPath, instanceNumber + "\n"); } } long endTime = System.nanoTime(); long duration = endTime - startTime; double seconds = (double) duration / 1000000000.0; System.out.println("Testing J48 completed in " + duration + " (ns) or " + seconds + " (s)"); } catch (Exception e) { System.out.println("Could not test J48 classifier due to an error"); } }
From source file:csav2.Weka_additive.java
public void createTrainingFeatureFile1(String input) throws Exception { String file = "Classifier\\featurefile_additive_trial1.arff"; ArffLoader loader = new ArffLoader(); //ATTRIBUTES/*from www.j a va 2s. co m*/ Attribute attr[] = new Attribute[50]; //numeric attr[0] = new Attribute("Autosentiment"); //class FastVector classValue = new FastVector(3); classValue.addElement("p"); classValue.addElement("n"); classValue.addElement("o"); attr[1] = new Attribute("answer", classValue); FastVector attrs = new FastVector(); attrs.addElement(attr[0]); attrs.addElement(attr[1]); // Add Instances Instances dataset = new Instances("my_dataset", attrs, 0); if (new File(file).isFile()) { loader.setFile(new File(file)); dataset = loader.getDataSet(); } System.out.println("-----------------------------------------"); System.out.println(input); System.out.println("-----------------------------------------"); StringTokenizer tokenizer = new StringTokenizer(input); while (tokenizer.hasMoreTokens()) { Instance example = new Instance(2); for (int j = 0; j < 2; j++) { String st = tokenizer.nextToken(); System.out.println(j + " " + st); if (j == 0) example.setValue(attr[j], Float.parseFloat(st)); else if (j == 1) example.setValue(attr[j], st); else example.setValue(attr[j], Integer.parseInt(st)); } dataset.add(example); } //Save dataset ArffSaver saver = new ArffSaver(); saver.setInstances(dataset); saver.setFile(new File(file)); saver.writeBatch(); //Read dataset loader.setFile(new File(file)); dataset = loader.getDataSet(); //Build classifier dataset.setClassIndex(1); Classifier classifier = new J48(); classifier.buildClassifier(dataset); //Save classifier String file1 = "Classifier\\classifier_add_autosentiment.model"; OutputStream os = new FileOutputStream(file1); ObjectOutputStream objectOutputStream = new ObjectOutputStream(os); objectOutputStream.writeObject(classifier); // Comment out if not needed //Read classifier back InputStream is = new FileInputStream(file1); 
ObjectInputStream objectInputStream = new ObjectInputStream(is); classifier = (Classifier) objectInputStream.readObject(); objectInputStream.close(); //Evaluate resample if needed //dataset = dataset.resample(new Random(42)); //split to 70:30 learn and test set double percent = 70.0; int trainSize = (int) Math.round(dataset.numInstances() * percent / 100); int testSize = dataset.numInstances() - trainSize; Instances train = new Instances(dataset, 0, trainSize); Instances test = new Instances(dataset, trainSize, testSize); train.setClassIndex(1); test.setClassIndex(1); //Evaluate Evaluation eval = new Evaluation(dataset); //trainset eval.crossValidateModel(classifier, dataset, 10, new Random(1)); System.out.println("EVALUATION:\n" + eval.toSummaryString()); System.out.println("WEIGHTED MEASURE:\n" + eval.weightedFMeasure()); System.out.println("WEIGHTED PRECISION:\n" + eval.weightedPrecision()); System.out.println("WEIGHTED RECALL:\n" + eval.weightedRecall()); }
From source file:csav2.Weka_additive.java
public void createTrainingFeatureFile2(String input) throws Exception { String file = "Classifier\\featurefile_additive_trial2.arff"; ArffLoader loader = new ArffLoader(); //ATTRIBUTES//from w w w . j a v a 2s.co m Attribute attr[] = new Attribute[50]; //numeric attr[0] = new Attribute("Autosentiment"); attr[1] = new Attribute("PositiveMatch"); attr[2] = new Attribute("NegativeMatch"); //class FastVector classValue = new FastVector(3); classValue.addElement("p"); classValue.addElement("n"); classValue.addElement("o"); attr[3] = new Attribute("answer", classValue); FastVector attrs = new FastVector(); attrs.addElement(attr[0]); attrs.addElement(attr[1]); attrs.addElement(attr[2]); attrs.addElement(attr[3]); // Add Instances Instances dataset = new Instances("my_dataset", attrs, 0); if (new File(file).isFile()) { loader.setFile(new File(file)); dataset = loader.getDataSet(); } System.out.println("-----------------------------------------"); System.out.println(input); System.out.println("-----------------------------------------"); StringTokenizer tokenizer = new StringTokenizer(input); while (tokenizer.hasMoreTokens()) { Instance example = new Instance(4); for (int j = 0; j < 4; j++) { String st = tokenizer.nextToken(); System.out.println(j + " " + st); if (j == 0) example.setValue(attr[j], Float.parseFloat(st)); else if (j == 3) example.setValue(attr[j], st); else example.setValue(attr[j], Integer.parseInt(st)); } dataset.add(example); } //Save dataset ArffSaver saver = new ArffSaver(); saver.setInstances(dataset); saver.setFile(new File(file)); saver.writeBatch(); //Read dataset loader.setFile(new File(file)); dataset = loader.getDataSet(); //Build classifier dataset.setClassIndex(3); Classifier classifier = new J48(); classifier.buildClassifier(dataset); //Save classifier String file1 = "Classifier\\classifier_add_asAndpolarwords.model"; OutputStream os = new FileOutputStream(file1); ObjectOutputStream objectOutputStream = new ObjectOutputStream(os); 
objectOutputStream.writeObject(classifier); // Comment out if not needed //Read classifier back InputStream is = new FileInputStream(file1); ObjectInputStream objectInputStream = new ObjectInputStream(is); classifier = (Classifier) objectInputStream.readObject(); objectInputStream.close(); //Evaluate resample if needed //dataset = dataset.resample(new Random(42)); //split to 70:30 learn and test set double percent = 70.0; int trainSize = (int) Math.round(dataset.numInstances() * percent / 100); int testSize = dataset.numInstances() - trainSize; Instances train = new Instances(dataset, 0, trainSize); Instances test = new Instances(dataset, trainSize, testSize); train.setClassIndex(3); test.setClassIndex(3); //Evaluate Evaluation eval = new Evaluation(dataset); //trainset eval.crossValidateModel(classifier, dataset, 10, new Random(1)); System.out.println("EVALUATION:\n" + eval.toSummaryString()); System.out.println("WEIGHTED MEASURE:\n" + eval.weightedFMeasure()); System.out.println("WEIGHTED PRECISION:\n" + eval.weightedPrecision()); System.out.println("WEIGHTED RECALL:\n" + eval.weightedRecall()); }