Example usage for weka.filters.supervised.instance StratifiedRemoveFolds setInputFormat

List of usage examples for weka.filters.supervised.instance StratifiedRemoveFolds setInputFormat

Introduction

On this page you can find example usages of weka.filters.supervised.instance StratifiedRemoveFolds setInputFormat.

Prototype

@Override
public boolean setInputFormat(Instances instanceInfo) throws Exception 

Source Link

Document

Sets the format of the input instances.

Usage

From source file:com.deafgoat.ml.prognosticator.InstancesFilter.java

License: Apache License

/**
 * Replaces {@code _instances} with the output of a Weka {@link StratifiedRemoveFolds}
 * filter configured for one fold of a stratified cross-validation split. The random
 * seed is fixed at -9 so the fold assignment is repeatable between calls.
 *
 * <p>When {@code invert} is {@code false} the filter receives the {@code -V}
 * (invert selection) flag; when it is {@code true} the flag is omitted. Per the Weka
 * documentation, the filter outputs only the selected fold without {@code -V}, and
 * every instance except that fold with {@code -V}.
 *
 * @param fold
 *            index of the fold to select (Weka numbers folds starting at 1)
 * @param numFolds
 *            number of folds the instances are split into
 * @param invert
 *            {@code false} to drop the selected fold and keep all others;
 *            {@code true} to keep only the selected fold
 * @throws Exception
 *             if Weka rejects the options or the filter cannot be applied
 */
public void removeStratifiedFoldsFilter(Integer fold, Integer numFolds, boolean invert) throws Exception {
    if (_logger.isDebugEnabled()) {
        _logger.debug("Applying stratified remove folds filter");
    }
    final String numFoldsArg = numFolds.toString();
    final String foldArg = fold.toString();
    // "-V" is only added for the non-inverted case: see the Javadoc for why.
    final String[] filterOptions = invert
            ? new String[] {"-S", "-9", "-N", numFoldsArg, "-F", foldArg}
            : new String[] {"-S", "-9", "-V", "-N", numFoldsArg, "-F", foldArg};
    final StratifiedRemoveFolds remover = new StratifiedRemoveFolds();
    remover.setOptions(filterOptions);
    remover.setInputFormat(_instances);
    _instances = Filter.useFilter(_instances, remover);
}

From source file:com.sliit.neuralnetwork.RecurrentNN.java

/**
 * Trains the recurrent network on a CSV data set and returns the evaluation
 * statistics of the last training pass.
 *
 * <p>Steps:
 * <ol>
 * <li>Load the model named {@code modelName}, building a new one if none is loaded.</li>
 * <li>Split the CSV with a Weka {@link StratifiedRemoveFolds} filter (5 folds,
 * fold 1, seed 1). The selected fold becomes the test split; all other folds become
 * the training split.</li>
 * <li>Write the splits as {@code normtrainadded.csv} and {@code normtestadded.csv}
 * in the input file's directory.</li>
 * <li>Run 100 fit/evaluate passes, appending the false-positive and true-positive
 * rate for class label 1 of each pass to {@code roc}.</li>
 * <li>Save the model.</li>
 * </ol>
 *
 * @param modelName name passed to {@code loadSaveNN} to load and later save the model
 * @param filePath  path of the input CSV; its last column is used as the class attribute
 * @param outputs   passed as the third argument of SequenceRecordReaderDataSetIterator
 *                  (presumably the number of possible labels; confirm)
 * @param inputsTot passed as the fourth argument of SequenceRecordReaderDataSetIterator
 *                  (presumably the label column index; confirm)
 * @return the {@code Evaluation.stats()} text of the final pass
 * @throws NeuralException wrapping any exception raised while loading, splitting,
 *                         writing, training or evaluating
 */
public String trainModel(String modelName, String filePath, int outputs, int inputsTot) throws NeuralException {
    System.out.println("calling trainModel");
    try {

        System.out.println("Neural Network Training start");
        loadSaveNN(modelName, false);
        if (model == null) {

            buildModel();
        }

        File fileGeneral = new File(filePath);
        CSVLoader loader = new CSVLoader();
        loader.setSource(fileGeneral);
        Instances instances = loader.getDataSet();
        // The last CSV column is the class attribute used for stratification.
        instances.setClassIndex(instances.numAttributes() - 1);
        StratifiedRemoveFolds stratified = new StratifiedRemoveFolds();
        String[] options = new String[6];
        options[0] = "-N";
        options[1] = Integer.toString(5);
        options[2] = "-F";
        options[3] = Integer.toString(1);
        options[4] = "-S";
        options[5] = Integer.toString(1);
        stratified.setOptions(options);
        stratified.setInputFormat(instances);
        // Not inverted: only the selected fold is kept (test split).
        stratified.setInvertSelection(false);
        Instances testInstances = Filter.useFilter(instances, stratified);
        // Inverted: every fold except the selected one is kept (training split).
        stratified.setInvertSelection(true);
        Instances trainInstances = Filter.useFilter(instances, stratified);
        // File.getParent() returns null for a bare file name such as "data.csv".
        // Resolving against the absolute path keeps the outputs next to the input
        // instead of under a literal "null" directory.
        File directory = fileGeneral.getAbsoluteFile().getParentFile();
        CSVSaver saver = new CSVSaver();
        File trainFile = new File(directory, "normtrainadded.csv");
        File testFile = new File(directory, "normtestadded.csv");
        if (trainFile.exists()) {

            trainFile.delete();
        }
        trainFile.createNewFile();
        if (testFile.exists()) {

            testFile.delete();
        }
        testFile.createNewFile();
        saver.setFile(trainFile);
        saver.setInstances(trainInstances);
        saver.writeBatch();
        saver = new CSVSaver();
        saver.setFile(testFile);
        saver.setInstances(testInstances);
        saver.writeBatch();
        SequenceRecordReader recordReader = new CSVSequenceRecordReader(0, ",");
        recordReader.initialize(new org.datavec.api.split.FileSplit(trainFile));
        SequenceRecordReader testReader = new CSVSequenceRecordReader(0, ",");
        testReader.initialize(new org.datavec.api.split.FileSplit(testFile));
        DataSetIterator iterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(
                recordReader, 2, outputs, inputsTot, false);
        DataSetIterator testIterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(
                testReader, 2, outputs, inputsTot, false);
        roc = new ArrayList<Map<String, Double>>();
        String statMsg = "";
        Evaluation evaluation;
        // Class label whose FPR/TPR are recorded for the ROC points.
        final int positiveLabel = 1;

        for (int i = 0; i < 100; i++) {
            // NOTE(review): odd iterations fit on the held-out test split and
            // evaluate on the training split. After the first odd pass, the "test"
            // metrics are no longer measured on unseen data. Confirm this
            // alternation is intended.
            if (i % 2 == 0) {

                model.fit(iterator);
                evaluation = model.evaluate(testIterator);
            } else {

                model.fit(testIterator);
                evaluation = model.evaluate(iterator);
            }
            Map<String, Double> map = new HashMap<String, Double>();
            // Computed in floating point: the previous Integer/Integer division
            // truncated every rate to 0 or 1 and threw on a zero denominator.
            double fpr = ratioForLabel(evaluation.falsePositives(), evaluation.trueNegatives(), positiveLabel);
            double tpr = ratioForLabel(evaluation.truePositives(), evaluation.falseNegatives(), positiveLabel);
            map.put("FPR", fpr);
            map.put("TPR", tpr);
            roc.add(map);
            statMsg = evaluation.stats();
            iterator.reset();
            testIterator.reset();
        }
        loadSaveNN(modelName, true);
        System.out.println("ROC " + roc);

        return statMsg;

    } catch (Exception e) {
        e.printStackTrace();
        System.out.println("Error ocuured while building neural netowrk :" + e.getMessage());
        throw new NeuralException(e.getLocalizedMessage(), e);
    }
}

/**
 * Returns {@code hits / (hits + complement)} for {@code label} as a floating-point
 * ratio. A missing map or entry counts as zero. Returns {@link Double#NaN} when both
 * counts are zero, because the rate is undefined in that case.
 */
private static double ratioForLabel(Map<Integer, Integer> hits, Map<Integer, Integer> complement, int label) {
    long numerator = countForLabel(hits, label);
    long denominator = numerator + countForLabel(complement, label);
    return denominator == 0L ? Double.NaN : (double) numerator / (double) denominator;
}

/** Returns the count stored for {@code label}, or 0 when the map or the entry is absent. */
private static int countForLabel(Map<Integer, Integer> counts, int label) {
    if (counts == null) {
        return 0;
    }
    Integer count = counts.get(label);
    return count == null ? 0 : count;
}