Example usage for org.deeplearning4j.eval Evaluation truePositives

List of usage examples for org.deeplearning4j.eval Evaluation truePositives

Introduction

On this page you can find example usage for org.deeplearning4j.eval Evaluation truePositives.

Prototype

public Map<Integer, Integer> truePositives() 

Source Link

Document

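True positives: correctly accepted, i.e. positive examples that were predicted as positive ("correctly rejected" describes true negatives).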

Usage

From source file: com.sliit.neuralnetwork.RecurrentNN.java

/**
 * Trains the network on the CSV file at {@code filePath} and returns the
 * statistics text of the last evaluation.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>Calls {@code loadSaveNN(modelName, false)} and, if {@code model} is still
 *       {@code null}, {@code buildModel()}.</li>
 *   <li>Loads the CSV with Weka, uses the last attribute as the class, and splits it
 *       with {@code StratifiedRemoveFolds} (options {@code -N 5 -F 1 -S 1}) into a
 *       selected fold and its inverse.</li>
 *   <li>Writes both parts next to the input file as {@code normtrainadded.csv} and
 *       {@code normtestadded.csv}, replacing any existing files.</li>
 *   <li>Runs 100 fit/evaluate rounds, appending an {@code {"FPR", "TPR"}} map for
 *       class index 1 to {@code roc} after each round.</li>
 *   <li>Calls {@code loadSaveNN(modelName, true)}.</li>
 * </ol>
 *
 * @param modelName name passed to {@code loadSaveNN}
 * @param filePath  path of the CSV data set to train on
 * @param outputs   third argument of {@code SequenceRecordReaderDataSetIterator}
 *                  (presumably the number of possible labels - confirm against DL4J docs)
 * @param inputsTot fourth argument of {@code SequenceRecordReaderDataSetIterator}
 *                  (presumably the label column index - confirm against DL4J docs)
 * @return the {@code stats()} text of the evaluation from the final round
 * @throws NeuralException wrapping any exception raised during loading, training or saving
 */
public String trainModel(String modelName, String filePath, int outputs, int inputsTot) throws NeuralException {
    System.out.println("calling trainModel");
    try {

        System.out.println("Neural Network Training start");
        loadSaveNN(modelName, false);
        if (model == null) {

            buildModel();
        }

        File fileGeneral = new File(filePath);
        CSVLoader loader = new CSVLoader();
        loader.setSource(fileGeneral);
        Instances instances = loader.getDataSet();
        instances.setClassIndex(instances.numAttributes() - 1);

        // Stratified split: the selected fold becomes the test set, its inverse the training set.
        StratifiedRemoveFolds stratified = new StratifiedRemoveFolds();
        String[] options = new String[6];
        options[0] = "-N";
        options[1] = Integer.toString(5);
        options[2] = "-F";
        options[3] = Integer.toString(1);
        options[4] = "-S";
        options[5] = Integer.toString(1);
        stratified.setOptions(options);
        stratified.setInputFormat(instances);
        stratified.setInvertSelection(false);
        Instances testInstances = Filter.useFilter(instances, stratified);
        stratified.setInvertSelection(true);
        Instances trainInstances = Filter.useFilter(instances, stratified);

        // File(File, String) copes with a null parent (bare file name), whereas string
        // concatenation with getParent() would produce the literal path "null/...".
        File directory = fileGeneral.getParentFile();
        CSVSaver saver = new CSVSaver();
        File trainFile = new File(directory, "normtrainadded.csv");
        File testFile = new File(directory, "normtestadded.csv");
        if (trainFile.exists()) {

            trainFile.delete();
        }
        trainFile.createNewFile();
        if (testFile.exists()) {

            testFile.delete();
        }
        testFile.createNewFile();
        saver.setFile(trainFile);
        saver.setInstances(trainInstances);
        saver.writeBatch();
        saver = new CSVSaver();
        saver.setFile(testFile);
        saver.setInstances(testInstances);
        saver.writeBatch();

        SequenceRecordReader recordReader = new CSVSequenceRecordReader(0, ",");
        recordReader.initialize(new org.datavec.api.split.FileSplit(trainFile));
        SequenceRecordReader testReader = new CSVSequenceRecordReader(0, ",");
        testReader.initialize(new org.datavec.api.split.FileSplit(testFile));
        DataSetIterator iterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(
                recordReader, 2, outputs, inputsTot, false);
        DataSetIterator testIterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(
                testReader, 2, outputs, inputsTot, false);
        roc = new ArrayList<Map<String, Double>>();
        String statMsg = "";
        Evaluation evaluation;

        for (int i = 0; i < 100; i++) {
            // NOTE(review): odd rounds fit on the held-out fold and evaluate on the training
            // fold, so after the first round neither set is unseen - confirm this is intended.
            if (i % 2 == 0) {

                model.fit(iterator);
                evaluation = model.evaluate(testIterator);
            } else {

                model.fit(testIterator);
                evaluation = model.evaluate(iterator);
            }
            Map<String, Double> map = new HashMap<String, Double>();
            Map<Integer, Integer> falsePositives = evaluation.falsePositives();
            Map<Integer, Integer> trueNegatives = evaluation.trueNegatives();
            Map<Integer, Integer> truePositives = evaluation.truePositives();
            Map<Integer, Integer> falseNegatives = evaluation.falseNegatives();
            // FPR = FP / (FP + TN) and TPR = TP / (TP + FN), both for class index 1.
            double fpr = classRate(falsePositives, trueNegatives, 1);
            double tpr = classRate(truePositives, falseNegatives, 1);
            map.put("FPR", fpr);
            map.put("TPR", tpr);
            roc.add(map);
            statMsg = evaluation.stats();
            iterator.reset();
            testIterator.reset();
        }
        loadSaveNN(modelName, true);
        System.out.println("ROC " + roc);

        return statMsg;

    } catch (Exception e) {
        e.printStackTrace();
        System.out.println("Error ocuured while building neural netowrk :" + e.getMessage());
        throw new NeuralException(e.getLocalizedMessage(), e);
    }
}

/**
 * Computes {@code a / (a + b)} in floating point, where {@code a} and {@code b} are the
 * counts stored for {@code classIndex} in the two maps.
 *
 * <p>The division is done on {@code double}s so the result is a real fraction rather
 * than a truncated integer quotient. A missing map entry counts as 0, and a zero
 * denominator yields 0.0 instead of an arithmetic error.
 *
 * @param numeratorCounts  per-class counts used as the numerator (e.g. true positives)
 * @param complementCounts per-class counts added to the numerator to form the denominator
 * @param classIndex       class whose rate is wanted
 * @return the rate in the range [0, 1], or 0.0 when both counts are zero or absent
 */
private static double classRate(Map<Integer, Integer> numeratorCounts,
        Map<Integer, Integer> complementCounts, int classIndex) {
    Integer hits = numeratorCounts.get(classIndex);
    Integer others = complementCounts.get(classIndex);
    double numerator = hits == null ? 0 : hits;
    double denominator = numerator + (others == null ? 0 : others);
    return denominator == 0.0 ? 0.0 : numerator / denominator;
}