Example usage for weka.core SerializationHelper write

List of usage examples for weka.core SerializationHelper write

Introduction

This page collects usage examples for the write method of weka.core.SerializationHelper.

Prototype

public static void write(OutputStream stream, Object o) throws Exception 

Source Link

Document

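Serializes the given object to the specified stream. Note that the examples below call the overload write(String filename, Object o), which serializes the object to the named file instead.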

Usage

From source file:homemadeWEKA.java

/**
 * Serializes the given classifier to the file {@code j48.model}.
 *
 * @param cls the classifier to persist
 * @throws Exception if {@code SerializationHelper.write} fails
 */
public static void save_model(Classifier cls) throws Exception {
    // Fixed output file name used for the serialized model.
    final String modelFile = "j48.model";
    SerializationHelper.write(modelFile, cls);
}

From source file:ann.ANN.java

/**
 * Serializes a classifier to a file and reports the outcome on standard output.
 *
 * <p>Failures are not propagated: a failure message is printed to standard output and the
 * underlying exception is printed to standard error so the reason is not lost.
 *
 * @param modelname path of the file the model is written to
 * @param model     the classifier to serialize
 */
public void saveModel(String modelname, Classifier model) {
    try {
        SerializationHelper.write(modelname, model);
        // Indonesian: "successfully created".
        System.out.println(modelname + " berhasil dibuat\n");
    } catch (Exception ex) {
        // Indonesian: "could not be created".
        System.out.println(modelname + " tidak bisa dibuat\n");
        // Keep the best-effort behaviour, but surface why the write failed.
        System.err.println(modelname + ": " + ex);
    }
}

From source file:asap.CrossValidation.java

/**
 * Runs a single-threaded k-fold cross-validation of a classifier on a data set read from
 * {@code dataInput}, then trains the classifier on the complete filtered data set and
 * optionally serializes it.
 *
 * @param dataInput       location of the data set, handed to {@code DataSource.read}
 * @param classIndex      {@code "first"}, {@code "last"}, a 1-based attribute number, or an
 *                        attribute name
 * @param removeIndices   attribute index specification handed to the {@code Remove} filter
 * @param cls             classifier to evaluate; copies are trained per fold and this instance
 *                        is trained in place on the full filtered data
 * @param seed            seed for the random generator used to shuffle the data
 * @param folds           number of cross-validation folds
 * @param modelOutputFile file the trained classifier is serialized to; {@code null} or empty
 *                        to skip serialization
 * @return a report listing classifier, options, data set, folds, seed and the evaluation summary
 * @throws IllegalArgumentException if {@code classIndex} is neither a keyword, a number, nor the
 *                                  name of an attribute in the data set
 * @throws Exception                if loading, filtering, training, evaluating or serializing fails
 */
public static String performCrossValidation(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation ST");

    PerformanceCounters.startTimer("cross-validation init ST");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;

    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            // numeric class indices are 1-based on input
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            // not a number: treat it as an attribute name and fail clearly if it does not exist
            if (data.attribute(clsIndex) == null) {
                throw new IllegalArgumentException("Unknown class attribute: " + clsIndex);
            }
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }

    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // prepare the per-fold train/test splits and classifier copies up front so that the
    // "init" timer covers the set-up and the "folds+train" timer covers only the training
    Evaluation eval = new Evaluation(randData);
    Instances[] trainSets = new Instances[folds];
    Instances[] testSets = new Instances[folds];
    Classifier[] foldCls = new Classifier[folds];

    for (int n = 0; n < folds; n++) {
        trainSets[n] = randData.trainCV(folds, n);
        testSets[n] = randData.testCV(folds, n);
        foldCls[n] = AbstractClassifier.makeCopy(cls);
    }

    PerformanceCounters.stopTimer("cross-validation init ST");
    PerformanceCounters.startTimer("cross-validation folds+train ST");

    // build and evaluate one classifier copy per fold
    for (int n = 0; n < folds; n++) {
        Classifier clsCopy = foldCls[n];
        clsCopy.buildClassifier(trainSets[n]);
        eval.evaluateModel(clsCopy, testSets[n]);
    }

    // train the caller's classifier on the complete filtered data set
    cls.buildClassifier(data);

    PerformanceCounters.stopTimer("cross-validation folds+train ST");
    PerformanceCounters.startTimer("cross-validation post ST");
    // output evaluation
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    // a null file name means "do not save", as in the Instances-based overload
    if (modelOutputFile != null && !modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post ST");
    PerformanceCounters.stopTimer("cross-validation ST");

    return out;
}

From source file:asap.CrossValidation.java

/**
 *
 * @param dataInput/*from   w  w  w  .  ja  va2  s.c o  m*/
 * @param classIndex
 * @param removeIndices
 * @param cls
 * @param seed
 * @param folds
 * @param modelOutputFile
 * @return
 * @throws Exception
 */
public static String performCrossValidationMT(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation MT");

    PerformanceCounters.startTimer("cross-validation init MT");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;

    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }

    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval = new Evaluation(randData);
    List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                AbstractClassifier.makeCopy(cls)));

        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init MT");
    PerformanceCounters.startTimer("cross-validation folds+train MT");
    //paralelize!!:--------------------------------------------------------------
    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        //use the current thread to run the cross-validation instead of using the Thread instance created here:
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    cls.buildClassifier(data);

    for (Thread foldThread : foldThreads) {
        foldThread.join();
    }

    //until here!-----------------------------------------------------------------
    PerformanceCounters.stopTimer("cross-validation folds+train MT");
    PerformanceCounters.startTimer("cross-validation post MT");
    // evaluation for output:
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (!modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post MT");
    PerformanceCounters.stopTimer("cross-validation MT");
    return out;
}

From source file:asap.CrossValidation.java

/**
 * Runs a k-fold cross-validation of a classifier on an already loaded data set, delegating the
 * fold work to {@code CrossValidationFoldThread} runners, while the calling thread trains the
 * classifier on the complete data set. Errors are logged rather than thrown.
 *
 * <p>When {@code Config.getNumThreads()} is greater than one, up to
 * {@code Config.getNumThreads() - 1} fold threads are started; otherwise a single fold runner
 * is executed on the calling thread.
 *
 * <p>If the calling thread is interrupted while waiting for a fold thread, the interruption is
 * logged and the thread's interrupt status is restored before this method returns. If training
 * the classifier on the full data set fails, no model file is written.
 *
 * @param data            data set with its class attribute already set
 * @param cls             classifier to evaluate; copies go into the fold sets and this instance
 *                        is trained in place on {@code data}
 * @param seed            seed for the random generator used to shuffle the data
 * @param folds           number of cross-validation folds
 * @param modelOutputFile file the trained classifier is serialized to; {@code null} or empty
 *                        to skip serialization
 * @return a report listing classifier, options, data set, folds, seed and the evaluation
 *         summary, or an error message if the evaluation object could not be created
 */
static String performCrossValidationMT(Instances data, AbstractClassifier cls, int seed, int folds,
        String modelOutputFile) {

    PerformanceCounters.startTimer("cross-validation MT");

    PerformanceCounters.startTimer("cross-validation init MT");

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval;
    try {
        eval = new Evaluation(randData);
    } catch (Exception ex) {
        Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        return "Error creating evaluation instance for given data!";
    }
    // synchronizedList already returns the element-typed list, so no cast is needed
    List<Thread> foldThreads = Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        try {
            foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                    AbstractClassifier.makeCopy(cls)));
        } catch (Exception ex) {
            Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        }

        // one slot is left for the calling thread, which trains the full model below
        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init MT");
    PerformanceCounters.startTimer("cross-validation folds+train MT");

    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        // single-threaded configuration: run the fold work on the current thread
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    // train the caller's classifier on the complete data set while the folds run
    boolean modelBuilt = false;
    try {
        cls.buildClassifier(data);
        modelBuilt = true;
    } catch (Exception ex) {
        Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
    }

    // remember an interruption instead of dropping it; the flag is restored just before
    // returning so that the remaining joins and the reporting below are not cut short
    boolean interrupted = false;
    for (Thread foldThread : foldThreads) {
        try {
            foldThread.join();
        } catch (InterruptedException ex) {
            interrupted = true;
            Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    PerformanceCounters.stopTimer("cross-validation folds+train MT");
    PerformanceCounters.startTimer("cross-validation post MT");
    // evaluation for output:
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (modelOutputFile != null && !modelOutputFile.isEmpty()) {
        if (modelBuilt) {
            try {
                SerializationHelper.write(modelOutputFile, cls);
            } catch (Exception ex) {
                Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
            }
        } else {
            // do not persist a classifier whose training failed
            Logger.getLogger(CrossValidation.class.getName()).log(Level.WARNING,
                    "Classifier could not be built; model file {0} was not written", modelOutputFile);
        }
    }

    PerformanceCounters.stopTimer("cross-validation post MT");
    PerformanceCounters.stopTimer("cross-validation MT");

    if (interrupted) {
        Thread.currentThread().interrupt();
    }
    return out;
}

From source file:asap.NLPSystem.java

/**
 * Runs a k-fold cross-validation of this system's classifier on its training set, delegating
 * the fold work to {@code CrossValidationFoldThread} runners, and records the resulting
 * correlation coefficient. Errors are logged rather than thrown.
 *
 * <p>When {@code Config.getNumThreads()} is greater than one, up to
 * {@code Config.getNumThreads() - 1} fold threads are started; otherwise a single fold runner
 * is executed on the calling thread. The method waits until every started fold thread has
 * finished, even if the calling thread is interrupted meanwhile; in that case the interrupt
 * status is restored before returning.
 *
 * @param seed            seed for the random generator used to shuffle the training set
 * @param folds           number of cross-validation folds
 * @param modelOutputFile file the classifier is serialized to; {@code null} or empty to skip
 *                        serialization
 * @return a report listing classifier, options, data set, folds, seed and the evaluation
 *         summary, or an error message if the evaluation object could not be created
 */
private String crossValidate(int seed, int folds, String modelOutputFile) {

    PerformanceCounters.startTimer("cross-validation");
    PerformanceCounters.startTimer("cross-validation init");

    AbstractClassifier abstractClassifier = (AbstractClassifier) classifier;
    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(trainingSet);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval;
    try {
        eval = new Evaluation(randData);
    } catch (Exception ex) {
        Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        return "Error creating evaluation instance for given data!";
    }
    // synchronizedList already returns the element-typed list, so no cast is needed
    List<Thread> foldThreads = Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        try {
            foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                    AbstractClassifier.makeCopy(abstractClassifier)));
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        }

        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init");
    PerformanceCounters.startTimer("cross-validation folds+train");

    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        // single-threaded configuration: run the fold work on the current thread
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    // Wait for every fold thread to finish. An interruption must not abort the wait, but it
    // must not be lost either: remember it here and restore the flag at the end of the method
    // (restoring it inside the loop would make join() throw again immediately and spin).
    boolean interrupted = false;
    for (Thread foldThread : foldThreads) {
        while (foldThread.isAlive()) {
            try {
                foldThread.join();
            } catch (InterruptedException ex) {
                interrupted = true;
                Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }

    PerformanceCounters.stopTimer("cross-validation folds+train");
    PerformanceCounters.startTimer("cross-validation post");
    // evaluation for output:
    String out = String.format(
            "\n=== Setup ===\nClassifier: %s %s\n" + "Dataset: %s\nFolds: %s\nSeed: %s\n\n%s\n",
            abstractClassifier.getClass().getName(), Utils.joinOptions(abstractClassifier.getOptions()),
            trainingSet.relationName(), folds, seed,
            eval.toSummaryString(String.format("=== %s-fold Cross-validation ===", folds), false));

    try {
        crossValidationPearsonsCorrelation = eval.correlationCoefficient();
    } catch (Exception ex) {
        Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
    }

    // NOTE(review): this method never calls buildClassifier on abstractClassifier itself; only
    // copies of it are placed in the fold sets. Confirm the instance serialized here is trained
    // elsewhere before this point.
    if (modelOutputFile != null && !modelOutputFile.isEmpty()) {
        try {
            SerializationHelper.write(modelOutputFile, abstractClassifier);
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    classifierBuiltWithCrossValidation = true;
    PerformanceCounters.stopTimer("cross-validation post");
    PerformanceCounters.stopTimer("cross-validation");

    if (interrupted) {
        Thread.currentThread().interrupt();
    }
    return out;
}

From source file:classifier.SellerClassifier.java

/**
 * Retrains the classifier from the named data set and writes the result to disk.
 *
 * <p>The data is loaded, passed through feature extraction, and used to train a new
 * {@code RandomForest}. The trained classifier is serialized to {@code modelPath} and the
 * filter held in {@code myFilter} to {@code refPath}. Any failure is logged at SEVERE level
 * and not propagated.
 *
 * @param dataset identifier of the data set, handed to {@code loadData}
 */
public void rebuildModel(String dataset) {
    try {
        // load the raw data and derive the training instances from it
        myInstances = startFeatureExtraction(loadData(dataset));

        // train a fresh random forest on the extracted instances
        myClassifier = new RandomForest();
        myClassifier.buildClassifier(myInstances);

        // persist the classifier first, then the filter
        SerializationHelper.write(modelPath, myClassifier);
        SerializationHelper.write(refPath, myFilter);
    } catch (Exception ex) {
        final Logger log = Logger.getLogger(SellerClassifier.class.getName());
        log.log(Level.SEVERE, null, ex);
    }
}

From source file:com.deafgoat.ml.prognosticator.AppClassifier.java

License:Apache License

/**
 * Performs cross-validation on the training instances and, depending on the configuration,
 * persists the resulting model and evaluation.
 *
 * <p>For each fold a copy of the configured classifier is trained on the fold's training split
 * and evaluated on its test split. Failures while training or evaluating a fold are logged at
 * debug level, together with the causing exception, and the loop continues.
 *
 * @throws Exception if creating the evaluation, copying the classifier, or persisting the
 *                   model or evaluation fails
 */
public void crossValidate() throws Exception {
    // stratify nominal target class
    if (_trainInstances.classAttribute().isNominal()) {
        _trainInstances.stratify(_folds);
    }
    _eval = new Evaluation(_trainInstances);
    for (int n = 0; n < _folds; n++) {

        if (_logger.isDebugEnabled()) {
            _logger.debug("Cross validation fold: " + (n + 1));
        }
        _train = _trainInstances.trainCV(_folds, n);
        _test = _trainInstances.testCV(_folds, n);
        _clsCopy = AbstractClassifier.makeCopy(_cls);
        try {
            _clsCopy.buildClassifier(_train);
        } catch (Exception e) {
            // pass the exception along so the actual cause is not lost
            _logger.debug(_config._classifier + " can not handle " + getAttributeType(_test.classAttribute())
                    + " class attributes", e);
        }

        try {
            _eval.evaluateModel(_clsCopy, _test);
        } catch (Exception e) {
            _logger.debug("Can not evaluate model", e);
        }
    }

    // NOTE(review): _clsCopy at this point is the copy trained on the last fold's training
    // split, not on the full training set; that is the model persisted below. Confirm this is
    // intended.
    if (_config._writeToMongoDB) {
        _logger.info("Writing model to mongoDB");
        // save the trained model
        saveModel();
        // save CV performance of trained model
        writeToMongoDB(_eval);
    }

    // NOTE(review): if saveModel() also writes the model file when _writeToFile is set, the
    // file is written twice when both flags are enabled; verify against saveModel().
    if (_config._writeToFile) {
        _logger.info("Writing model to file");
        SerializationHelper.write(_config._modelFile, _clsCopy);
    }
}

From source file:com.deafgoat.ml.prognosticator.AppClassifier.java

License:Apache License

/**
 * Saves the trained model to MongoDB and/or to a file, depending on the configuration.
 *
 * <p>The MongoDB connection wrapper is closed even when writing the model fails.
 *
 * @throws Exception
 *             If the model can not be saved
 */
public void saveModel() throws Exception {
    if (_logger.isDebugEnabled()) {
        _logger.debug("Serializing model");
    }

    if (_config._writeToMongoDB) {
        MongoResult mongoResult = new MongoResult(_config._host, _config._port, _config._db,
                _config._modelCollection);
        try {
            mongoResult.writeModel(_config._relation, _clsCopy);
        } finally {
            // release the connection whether or not the write succeeded
            mongoResult.close();
        }
    }

    if (_config._writeToFile) {
        SerializationHelper.write(_config._modelFile, _clsCopy);
    }
}

From source file:core.classification.Classifiers.java

License:Open Source License

/**
 * Serializes the five classifiers SCA, SCB, SCC1, SCC2 and SCC3, in that order, to the given
 * files.
 *
 * @param filename1 file that receives SCA
 * @param filename2 file that receives SCB
 * @param filename3 file that receives SCC1
 * @param filename4 file that receives SCC2
 * @param filename5 file that receives SCC3
 * @return this instance, to allow call chaining
 * @throws Exception if any write fails; later classifiers are then not written
 */
public Classifiers saveSC(String filename1, String filename2, String filename3, String filename4,
        String filename5) throws Exception {
    // pair each target file with its classifier, preserving the original write order
    final String[] targets = { filename1, filename2, filename3, filename4, filename5 };
    final Object[] models = { SCA, SCB, SCC1, SCC2, SCC3 };

    for (int i = 0; i < targets.length; i++) {
        SerializationHelper.write(targets[i], models[i]);
    }
    return this;
}