Usage examples for weka.core.SerializationHelper.write
public static void write(OutputStream stream, Object o) throws Exception (note: the examples below call the companion overload write(String filename, Object o))
From source file: homemadeWEKA.java
/**
 * Serializes the given classifier to the relative path {@code j48.model}.
 *
 * @param cls the classifier to serialize
 * @throws Exception if {@code SerializationHelper.write} fails
 */
public static void save_model(Classifier cls) throws Exception {
    save_model(cls, "j48.model");
}

/**
 * Serializes the given classifier to the given file path.
 *
 * @param cls  the classifier to serialize
 * @param path destination file for the serialized model
 * @throws Exception if {@code SerializationHelper.write} fails
 */
public static void save_model(Classifier cls, String path) throws Exception {
    SerializationHelper.write(path, cls);
}
From source file: ann.ANN.java
public void saveModel(String modelname, Classifier model) { try {//ww w . j a v a2 s . co m SerializationHelper.write(modelname, model); System.out.println(modelname + " berhasil dibuat\n"); } catch (Exception ex) { System.out.println(modelname + " tidak bisa dibuat\n"); } }
From source file: asap.CrossValidation.java
/** * * @param dataInput//from w w w. j ava 2 s . c o m * @param classIndex * @param removeIndices * @param cls * @param seed * @param folds * @param modelOutputFile * @return * @throws Exception */ public static String performCrossValidation(String dataInput, String classIndex, String removeIndices, AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception { PerformanceCounters.startTimer("cross-validation ST"); PerformanceCounters.startTimer("cross-validation init ST"); // loads data and set class index Instances data = DataSource.read(dataInput); String clsIndex = classIndex; switch (clsIndex) { case "first": data.setClassIndex(0); break; case "last": data.setClassIndex(data.numAttributes() - 1); break; default: try { data.setClassIndex(Integer.parseInt(clsIndex) - 1); } catch (NumberFormatException e) { data.setClassIndex(data.attribute(clsIndex).index()); } break; } Remove removeFilter = new Remove(); removeFilter.setAttributeIndices(removeIndices); removeFilter.setInputFormat(data); data = Filter.useFilter(data, removeFilter); // randomize data Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) { randData.stratify(folds); } // perform cross-validation and add predictions Evaluation eval = new Evaluation(randData); Instances trainSets[] = new Instances[folds]; Instances testSets[] = new Instances[folds]; Classifier foldCls[] = new Classifier[folds]; for (int n = 0; n < folds; n++) { trainSets[n] = randData.trainCV(folds, n); testSets[n] = randData.testCV(folds, n); foldCls[n] = AbstractClassifier.makeCopy(cls); } PerformanceCounters.stopTimer("cross-validation init ST"); PerformanceCounters.startTimer("cross-validation folds+train ST"); //paralelize!!:-------------------------------------------------------------- for (int n = 0; n < folds; n++) { Instances train = trainSets[n]; Instances test = testSets[n]; // the above code is used by the 
StratifiedRemoveFolds filter, the // code below by the Explorer/Experimenter: // Instances train = randData.trainCV(folds, n, rand); // build and evaluate classifier Classifier clsCopy = foldCls[n]; clsCopy.buildClassifier(train); eval.evaluateModel(clsCopy, test); } cls.buildClassifier(data); //until here!----------------------------------------------------------------- PerformanceCounters.stopTimer("cross-validation folds+train ST"); PerformanceCounters.startTimer("cross-validation post ST"); // output evaluation String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: " + folds + "\n" + "Seed: " + seed + "\n" + "\n" + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n"; if (!modelOutputFile.isEmpty()) { SerializationHelper.write(modelOutputFile, cls); } PerformanceCounters.stopTimer("cross-validation post ST"); PerformanceCounters.stopTimer("cross-validation ST"); return out; }
From source file: asap.CrossValidation.java
/** * * @param dataInput/*from w w w . ja va2 s.c o m*/ * @param classIndex * @param removeIndices * @param cls * @param seed * @param folds * @param modelOutputFile * @return * @throws Exception */ public static String performCrossValidationMT(String dataInput, String classIndex, String removeIndices, AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception { PerformanceCounters.startTimer("cross-validation MT"); PerformanceCounters.startTimer("cross-validation init MT"); // loads data and set class index Instances data = DataSource.read(dataInput); String clsIndex = classIndex; switch (clsIndex) { case "first": data.setClassIndex(0); break; case "last": data.setClassIndex(data.numAttributes() - 1); break; default: try { data.setClassIndex(Integer.parseInt(clsIndex) - 1); } catch (NumberFormatException e) { data.setClassIndex(data.attribute(clsIndex).index()); } break; } Remove removeFilter = new Remove(); removeFilter.setAttributeIndices(removeIndices); removeFilter.setInputFormat(data); data = Filter.useFilter(data, removeFilter); // randomize data Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) { randData.stratify(folds); } // perform cross-validation and add predictions Evaluation eval = new Evaluation(randData); List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>()); List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>()); for (int n = 0; n < folds; n++) { foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n), AbstractClassifier.makeCopy(cls))); if (n < Config.getNumThreads() - 1) { Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval)); foldThreads.add(foldThread); } } PerformanceCounters.stopTimer("cross-validation init MT"); PerformanceCounters.startTimer("cross-validation folds+train MT"); 
//paralelize!!:-------------------------------------------------------------- if (Config.getNumThreads() > 1) { for (Thread foldThread : foldThreads) { foldThread.start(); } } else { //use the current thread to run the cross-validation instead of using the Thread instance created here: new CrossValidationFoldThread(0, foldSets, eval).run(); } cls.buildClassifier(data); for (Thread foldThread : foldThreads) { foldThread.join(); } //until here!----------------------------------------------------------------- PerformanceCounters.stopTimer("cross-validation folds+train MT"); PerformanceCounters.startTimer("cross-validation post MT"); // evaluation for output: String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: " + folds + "\n" + "Seed: " + seed + "\n" + "\n" + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n"; if (!modelOutputFile.isEmpty()) { SerializationHelper.write(modelOutputFile, cls); } PerformanceCounters.stopTimer("cross-validation post MT"); PerformanceCounters.stopTimer("cross-validation MT"); return out; }
From source file: asap.CrossValidation.java
static String performCrossValidationMT(Instances data, AbstractClassifier cls, int seed, int folds, String modelOutputFile) { PerformanceCounters.startTimer("cross-validation MT"); PerformanceCounters.startTimer("cross-validation init MT"); // randomize data Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand);//from w ww. j ava2 s.co m if (randData.classAttribute().isNominal()) { randData.stratify(folds); } // perform cross-validation and add predictions Evaluation eval; try { eval = new Evaluation(randData); } catch (Exception ex) { Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex); return "Error creating evaluation instance for given data!"; } List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>()); List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>()); for (int n = 0; n < folds; n++) { try { foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n), AbstractClassifier.makeCopy(cls))); } catch (Exception ex) { Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex); } //TODO: use Config.getNumThreads() for limiting these:: if (n < Config.getNumThreads() - 1) { Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval)); foldThreads.add(foldThread); } } PerformanceCounters.stopTimer("cross-validation init MT"); PerformanceCounters.startTimer("cross-validation folds+train MT"); //paralelize!!:-------------------------------------------------------------- if (Config.getNumThreads() > 1) { for (Thread foldThread : foldThreads) { foldThread.start(); } } else { new CrossValidationFoldThread(0, foldSets, eval).run(); } try { cls.buildClassifier(data); } catch (Exception ex) { Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex); } for (Thread foldThread : foldThreads) { try { foldThread.join(); } catch (InterruptedException 
ex) { Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex); } } //until here!----------------------------------------------------------------- PerformanceCounters.stopTimer("cross-validation folds+train MT"); PerformanceCounters.startTimer("cross-validation post MT"); // evaluation for output: String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: " + folds + "\n" + "Seed: " + seed + "\n" + "\n" + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n"; if (modelOutputFile != null) { if (!modelOutputFile.isEmpty()) { try { SerializationHelper.write(modelOutputFile, cls); } catch (Exception ex) { Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex); } } } PerformanceCounters.stopTimer("cross-validation post MT"); PerformanceCounters.stopTimer("cross-validation MT"); return out; }
From source file: asap.NLPSystem.java
private String crossValidate(int seed, int folds, String modelOutputFile) { PerformanceCounters.startTimer("cross-validation"); PerformanceCounters.startTimer("cross-validation init"); AbstractClassifier abstractClassifier = (AbstractClassifier) classifier; // randomize data Random rand = new Random(seed); Instances randData = new Instances(trainingSet); randData.randomize(rand);// w ww.j a v a 2s. c o m if (randData.classAttribute().isNominal()) { randData.stratify(folds); } // perform cross-validation and add predictions Evaluation eval; try { eval = new Evaluation(randData); } catch (Exception ex) { Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex); return "Error creating evaluation instance for given data!"; } List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>()); List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>()); for (int n = 0; n < folds; n++) { try { foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n), AbstractClassifier.makeCopy(abstractClassifier))); } catch (Exception ex) { Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex); } if (n < Config.getNumThreads() - 1) { Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval)); foldThreads.add(foldThread); } } PerformanceCounters.stopTimer("cross-validation init"); PerformanceCounters.startTimer("cross-validation folds+train"); if (Config.getNumThreads() > 1) { for (Thread foldThread : foldThreads) { foldThread.start(); } } else { new CrossValidationFoldThread(0, foldSets, eval).run(); } for (Thread foldThread : foldThreads) { while (foldThread.isAlive()) { try { foldThread.join(); } catch (InterruptedException ex) { Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex); } } } PerformanceCounters.stopTimer("cross-validation folds+train"); PerformanceCounters.startTimer("cross-validation post"); // evaluation for 
output: String out = String.format( "\n=== Setup ===\nClassifier: %s %s\n" + "Dataset: %s\nFolds: %s\nSeed: %s\n\n%s\n", abstractClassifier.getClass().getName(), Utils.joinOptions(abstractClassifier.getOptions()), trainingSet.relationName(), folds, seed, eval.toSummaryString(String.format("=== %s-fold Cross-validation ===", folds), false)); try { crossValidationPearsonsCorrelation = eval.correlationCoefficient(); } catch (Exception ex) { Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex); } if (modelOutputFile != null) { if (!modelOutputFile.isEmpty()) { try { SerializationHelper.write(modelOutputFile, abstractClassifier); } catch (Exception ex) { Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex); } } } classifierBuiltWithCrossValidation = true; PerformanceCounters.stopTimer("cross-validation post"); PerformanceCounters.stopTimer("cross-validation"); return out; }
From source file: classifier.SellerClassifier.java
public void rebuildModel(String dataset) { try {//from w w w . jav a 2 s. co m myInstances = startFeatureExtraction(loadData(dataset)); myClassifier = new RandomForest(); // build the model myClassifier.buildClassifier(myInstances); SerializationHelper.write(modelPath, myClassifier); SerializationHelper.write(refPath, myFilter); } catch (Exception ex) { Logger.getLogger(SellerClassifier.class.getName()).log(Level.SEVERE, null, ex); } }
From source file: com.deafgoat.ml.prognosticator.AppClassifier.java
License: Apache License
/** * Perform cross-validation on data set/builds model * // ww w . ja v a2s. com * @throws Exception */ public void crossValidate() throws Exception { // stratify nominal target class if (_trainInstances.classAttribute().isNominal()) { _trainInstances.stratify(_folds); } _eval = new Evaluation(_trainInstances); for (int n = 0; n < _folds; n++) { if (_logger.isDebugEnabled()) { _logger.debug("Cross validation fold: " + (n + 1)); } _train = _trainInstances.trainCV(_folds, n); _test = _trainInstances.testCV(_folds, n); _clsCopy = AbstractClassifier.makeCopy(_cls); try { _clsCopy.buildClassifier(_train); } catch (Exception e) { _logger.debug(_config._classifier + " can not handle " + getAttributeType(_test.classAttribute()) + " class attributes"); } try { _eval.evaluateModel(_clsCopy, _test); } catch (Exception e) { _logger.debug("Can not evaluate model"); } } if (_config._writeToMongoDB) { _logger.info("Writing model to mongoDB"); // save the trained model saveModel(); // save CV performance of trained model writeToMongoDB(_eval); } if (_config._writeToFile) { _logger.info("Writing model to file"); SerializationHelper.write(_config._modelFile, _clsCopy); } }
From source file: com.deafgoat.ml.prognosticator.AppClassifier.java
License: Apache License
/**
 * Saves the trained model ({@code _clsCopy}) to MongoDB and/or to a file, depending on the
 * configuration flags.
 *
 * <p>The {@code MongoResult} connection is now closed even when {@code writeModel} throws.
 * Previously it leaked on that path.
 *
 * @throws Exception if the model cannot be saved
 */
public void saveModel() throws Exception {
    if (_logger.isDebugEnabled()) {
        _logger.debug("Serializing model");
    }
    if (_config._writeToMongoDB) {
        MongoResult mongoResult = new MongoResult(_config._host, _config._port, _config._db,
                _config._modelCollection);
        try {
            mongoResult.writeModel(_config._relation, _clsCopy);
        } finally {
            mongoResult.close();
        }
    }
    if (_config._writeToFile) {
        SerializationHelper.write(_config._modelFile, _clsCopy);
    }
}
From source file: core.classification.Classifiers.java
License: Open Source License
/**
 * Serializes the five sub-classifiers, in order, to the five given files.
 *
 * <p>The pairs are ({@code SCA}, {@code filename1}), ({@code SCB}, {@code filename2}),
 * ({@code SCC1}, {@code filename3}), ({@code SCC2}, {@code filename4}) and
 * ({@code SCC3}, {@code filename5}). Writing stops at the first failure.
 *
 * @return this instance, for call chaining
 * @throws Exception if any write fails
 */
public Classifiers saveSC(String filename1, String filename2, String filename3, String filename4,
        String filename5) throws Exception {
    final String[] targets = {filename1, filename2, filename3, filename4, filename5};
    final Object[] models = {SCA, SCB, SCC1, SCC2, SCC3};
    for (int i = 0; i < targets.length; i++) {
        SerializationHelper.write(targets[i], models[i]);
    }
    return this;
}