Example usage for weka.filters.supervised.instance StratifiedRemoveFolds StratifiedRemoveFolds

List of usage examples for the weka.filters.supervised.instance.StratifiedRemoveFolds constructor StratifiedRemoveFolds()

Introduction

On this page you can find example usages of the StratifiedRemoveFolds() constructor of the class weka.filters.supervised.instance.StratifiedRemoveFolds.

Prototype

StratifiedRemoveFolds

Source Link

Usage

From source file:com.deafgoat.ml.prognosticator.InstancesFilter.java

License:Apache License

/**
 * Applies a filter to remove stratified folds from the set of instances
 * /*from  w ww.  ja va2s . c o m*/
 * @param fold
 *            The fold number to pick for remove
 * @param numFolds
 *            The number of folds for a stratified cross-validation
 * @param invert
 *            Flag indicating whether to remove this fold or all others
 * @throws Exception
 *             If filter could not be applied
 */
public void removeStratifiedFoldsFilter(Integer fold, Integer numFolds, boolean invert) throws Exception {
    if (_logger.isDebugEnabled()) {
        _logger.debug("Applying stratified remove folds filter");
    }
    StratifiedRemoveFolds srf = new StratifiedRemoveFolds();
    String[] options;
    if (invert) {
        options = new String[6];
        options[0] = "-S";
        options[1] = "-9";
        options[2] = "-N";
        options[3] = numFolds.toString();
        options[4] = "-F";
        options[5] = fold.toString();
    } else {
        options = new String[7];
        options[0] = "-S";
        options[1] = "-9";
        options[2] = "-V";
        options[3] = "-N";
        options[4] = numFolds.toString();
        options[5] = "-F";
        options[6] = fold.toString();
    }
    srf.setOptions(options);
    srf.setInputFormat(_instances);
    _instances = Filter.useFilter(_instances, srf);
}

From source file:com.sliit.neuralnetwork.RecurrentNN.java

/**
 * Trains the network on a CSV data set and returns the statistics of the last
 * evaluation.
 * <p>
 * Steps performed:
 * <ol>
 * <li>Loads a previously saved model via {@code loadSaveNN(modelName, false)};
 * builds a new one if none is available.</li>
 * <li>Loads the CSV at {@code filePath}, uses the last attribute as the class
 * and splits it with {@code StratifiedRemoveFolds} (5 folds, fold 1, seed 1)
 * into a test set (the selected fold) and a training set (everything else).</li>
 * <li>Writes both sets next to the input file as {@code normtrainadded.csv}
 * and {@code normtestadded.csv}, replacing existing files.</li>
 * <li>Runs 100 fit/evaluate rounds, recording the false-positive and
 * true-positive rate of class label 1 for every round in {@code roc}.</li>
 * <li>Saves the model via {@code loadSaveNN(modelName, true)}.</li>
 * </ol>
 *
 * @param modelName
 *            Name handed to {@code loadSaveNN} for loading and saving the model
 * @param filePath
 *            Path of the CSV file holding the complete data set
 * @param outputs
 *            Third argument of the sequence data set iterators (presumably the
 *            number of possible labels - confirm against the DL4J API)
 * @param inputsTot
 *            Fourth argument of the sequence data set iterators (presumably
 *            the index of the label column - confirm against the DL4J API)
 * @return The statistics text of the last evaluation round
 * @throws NeuralException
 *             If any step of loading, splitting, training or saving fails
 */
public String trainModel(String modelName, String filePath, int outputs, int inputsTot) throws NeuralException {
    System.out.println("calling trainModel");
    try {

        System.out.println("Neural Network Training start");
        loadSaveNN(modelName, false);
        if (model == null) {

            buildModel();
        }

        File fileGeneral = new File(filePath);
        CSVLoader loader = new CSVLoader();
        loader.setSource(fileGeneral);
        Instances instances = loader.getDataSet();
        instances.setClassIndex(instances.numAttributes() - 1);

        StratifiedRemoveFolds stratified = new StratifiedRemoveFolds();
        String[] options = { "-N", Integer.toString(5), "-F", Integer.toString(1), "-S", Integer.toString(1) };
        stratified.setOptions(options);

        // setInputFormat() is called right before every useFilter() call: it
        // resets the filter's batch state. Without the reset the second pass
        // would be treated as a follow-up batch and not be split again.
        stratified.setInvertSelection(false);
        stratified.setInputFormat(instances);
        Instances testInstances = Filter.useFilter(instances, stratified);
        stratified.setInvertSelection(true);
        stratified.setInputFormat(instances);
        Instances trainInstances = Filter.useFilter(instances, stratified);

        // Resolve the output files against the parent directory; a null parent
        // (bare file name) then resolves to the working directory instead of
        // a literal "null/" directory.
        File parentDir = fileGeneral.getParentFile();
        File trainFile = new File(parentDir, "normtrainadded.csv");
        File testFile = new File(parentDir, "normtestadded.csv");
        if (trainFile.exists()) {

            trainFile.delete();
        }
        trainFile.createNewFile();
        if (testFile.exists()) {

            testFile.delete();
        }
        testFile.createNewFile();

        CSVSaver saver = new CSVSaver();
        saver.setFile(trainFile);
        saver.setInstances(trainInstances);
        saver.writeBatch();
        saver = new CSVSaver();
        saver.setFile(testFile);
        saver.setInstances(testInstances);
        saver.writeBatch();

        SequenceRecordReader recordReader = new CSVSequenceRecordReader(0, ",");
        recordReader.initialize(new org.datavec.api.split.FileSplit(trainFile));
        SequenceRecordReader testReader = new CSVSequenceRecordReader(0, ",");
        testReader.initialize(new org.datavec.api.split.FileSplit(testFile));
        DataSetIterator iterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(
                recordReader, 2, outputs, inputsTot, false);
        DataSetIterator testIterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(
                testReader, 2, outputs, inputsTot, false);
        roc = new ArrayList<Map<String, Double>>();
        String statMsg = "";
        Evaluation evaluation;

        for (int i = 0; i < 100; i++) {
            // NOTE(review): odd rounds fit on the test iterator and evaluate on
            // the training iterator, so the model also learns from the test
            // fold. Kept as in the original - confirm this is intended.
            if (i % 2 == 0) {

                model.fit(iterator);
                evaluation = model.evaluate(testIterator);
            } else {

                model.fit(testIterator);
                evaluation = model.evaluate(iterator);
            }
            Map<String, Double> map = new HashMap<String, Double>();
            // Rates are taken for class label 1 (presumably the positive class).
            double fpr = classRate(evaluation.falsePositives().get(1), evaluation.trueNegatives().get(1));
            double tpr = classRate(evaluation.truePositives().get(1), evaluation.falseNegatives().get(1));
            map.put("FPR", fpr);
            map.put("TPR", tpr);
            roc.add(map);
            statMsg = evaluation.stats();
            iterator.reset();
            testIterator.reset();
        }
        loadSaveNN(modelName, true);
        System.out.println("ROC " + roc);

        return statMsg;

    } catch (Exception e) {
        e.printStackTrace();
        System.out.println("Error ocuured while building neural netowrk :" + e.getMessage());
        throw new NeuralException(e.getLocalizedMessage(), e);
    }
}

/**
 * Computes {@code hits / (hits + others)} in floating point.
 * <p>
 * Integer division would truncate every rate to 0 or 1. A missing count is
 * treated as 0, and an empty denominator yields 0.0 instead of an arithmetic
 * error.
 *
 * @param hits
 *            Count forming the numerator, may be {@code null}
 * @param others
 *            Count added to {@code hits} to form the denominator, may be
 *            {@code null}
 * @return The rate in the range [0, 1], or 0.0 if both counts are zero
 */
private static double classRate(Integer hits, Integer others) {
    int numerator = hits == null ? 0 : hits.intValue();
    int remainder = others == null ? 0 : others.intValue();
    int total = numerator + remainder;
    return total == 0 ? 0.0 : (double) numerator / total;
}