Usage examples for weka.filters.supervised.instance.StratifiedRemoveFolds.setInvertSelection
public void setInvertSelection(boolean inverse)
From source file: com.sliit.neuralnetwork.RecurrentNN.java
/**
 * Trains the network on a CSV data set and records one ROC point per training iteration.
 *
 * <p>The method loads (or builds) the model, reads the CSV at {@code filePath} with the last
 * column as class attribute, splits it into a stratified test fold (fold 1 of 5, seed 1) and the
 * complementary training folds, writes both splits next to the input file as
 * {@code normtestadded.csv} / {@code normtrainadded.csv}, and then runs 100 fit/evaluate
 * iterations. After each iteration the false-positive and true-positive rate of class 1 are
 * appended to {@code roc}. Finally the model is saved through {@code loadSaveNN(modelName, true)}.
 *
 * @param modelName name handed to {@code loadSaveNN} for loading and later saving the model
 * @param filePath  path of the source CSV file; the split files are written to its directory
 * @param outputs   forwarded as the third constructor argument of the sequence data-set iterators
 *                  (presumably the number of possible labels - confirm against the DL4J version in use)
 * @param inputsTot forwarded as the fourth constructor argument of the sequence data-set iterators
 *                  (presumably the label column index - confirm against the DL4J version in use)
 * @return the {@code stats()} text of the evaluation produced by the last iteration
 * @throws NeuralException wrapping any exception raised while loading, splitting or training
 */
public String trainModel(String modelName, String filePath, int outputs, int inputsTot) throws NeuralException {
    System.out.println("calling trainModel");
    try {
        System.out.println("Neural Network Training start");
        loadSaveNN(modelName, false);
        if (model == null) {
            buildModel();
        }

        // Load the raw data; the last attribute is treated as the class label.
        File fileGeneral = new File(filePath);
        CSVLoader loader = new CSVLoader();
        loader.setSource(fileGeneral);
        Instances instances = loader.getDataSet();
        instances.setClassIndex(instances.numAttributes() - 1);

        // 5 folds (-N), select fold 1 (-F), random seed 1 (-S).
        StratifiedRemoveFolds stratified = new StratifiedRemoveFolds();
        String[] options = {
            "-N", Integer.toString(5),
            "-F", Integer.toString(1),
            "-S", Integer.toString(1)
        };
        stratified.setOptions(options);

        // Weka filters expect every option to be in place before setInputFormat(), and the filter
        // has to be re-initialised before it is reused; otherwise the second pass may hand back
        // the unfiltered data instead of the complementary folds.
        stratified.setInvertSelection(false);
        stratified.setInputFormat(instances);
        Instances testInstances = Filter.useFilter(instances, stratified);

        stratified.setInvertSelection(true);
        stratified.setInputFormat(instances);
        Instances trainInstances = Filter.useFilter(instances, stratified);

        // File(File, String) copes with a null parent (filePath without a directory part),
        // whereas string concatenation would produce the literal path "null/...".
        File directory = fileGeneral.getParentFile();
        File trainFile = new File(directory, "normtrainadded.csv");
        File testFile = new File(directory, "normtestadded.csv");
        if (trainFile.exists()) {
            trainFile.delete();
        }
        trainFile.createNewFile();
        if (testFile.exists()) {
            testFile.delete();
        }
        testFile.createNewFile();

        CSVSaver saver = new CSVSaver();
        saver.setFile(trainFile);
        saver.setInstances(trainInstances);
        saver.writeBatch();
        saver = new CSVSaver();
        saver.setFile(testFile);
        saver.setInstances(testInstances);
        saver.writeBatch();

        // NOTE(review): the readers skip 0 lines; if CSVSaver emits a header row it will be read
        // as a data row - confirm against the saver configuration.
        SequenceRecordReader recordReader = new CSVSequenceRecordReader(0, ",");
        recordReader.initialize(new org.datavec.api.split.FileSplit(trainFile));
        SequenceRecordReader testReader = new CSVSequenceRecordReader(0, ",");
        testReader.initialize(new org.datavec.api.split.FileSplit(testFile));
        DataSetIterator iterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(
                recordReader, 2, outputs, inputsTot, false);
        DataSetIterator testIterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(
                testReader, 2, outputs, inputsTot, false);

        roc = new ArrayList<Map<String, Double>>();
        String statMsg = "";
        Evaluation evaluation;
        for (int i = 0; i < 100; i++) {
            // NOTE(review): odd iterations fit on the test split and evaluate on the training
            // split, so the reported statistics are not from held-out data - confirm this
            // alternation is intended.
            if (i % 2 == 0) {
                model.fit(iterator);
                evaluation = model.evaluate(testIterator);
            } else {
                model.fit(testIterator);
                evaluation = model.evaluate(iterator);
            }

            // Rates for class 1, computed in floating point (integer division would truncate
            // every rate to 0 or 1 and throw on an empty denominator).
            double fpr = rateForClass(evaluation.falsePositives(), evaluation.trueNegatives(), 1);
            double tpr = rateForClass(evaluation.truePositives(), evaluation.falseNegatives(), 1);

            Map<String, Double> map = new HashMap<String, Double>();
            map.put("FPR", fpr);
            map.put("TPR", tpr);
            roc.add(map);

            statMsg = evaluation.stats();
            iterator.reset();
            testIterator.reset();
        }

        loadSaveNN(modelName, true);
        System.out.println("ROC " + roc);
        return statMsg;
    } catch (Exception e) {
        e.printStackTrace();
        System.out.println("Error ocuured while building neural netowrk :" + e.getMessage());
        throw new NeuralException(e.getLocalizedMessage(), e);
    }
}

/**
 * Computes {@code hits / (hits + others)} for one class as a floating-point ratio.
 *
 * @return the ratio, or {@code 0.0} when both counts are zero or absent (rate undefined)
 */
private static double rateForClass(Map<Integer, Integer> hits, Map<Integer, Integer> others, int classIndex) {
    int numerator = countForClass(hits, classIndex);
    int denominator = numerator + countForClass(others, classIndex);
    if (denominator == 0) {
        return 0.0;
    }
    return (double) numerator / denominator;
}

/** Returns the count stored for {@code classIndex}, treating a missing map or entry as zero. */
private static int countForClass(Map<Integer, Integer> counts, int classIndex) {
    if (counts == null) {
        return 0;
    }
    Integer count = counts.get(classIndex);
    if (count == null) {
        return 0;
    }
    return count;
}