List of usage examples for org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output
public synchronized <T> T output(@NonNull INDArray inputs, INDArray inputMasks, INDArray labelMasks, @NonNull OutputAdapter<T> outputAdapter)
From source file: seqtest.Pair.java
public static void main(String[] args) throws Exception { downloadUCIData();/*from ww w . ja v a 2s . co m*/ // ----- Load the training data ----- //Note that we have 450 training files for features: train/features/0.csv through train/features/449.csv SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(); trainFeatures .initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449)); SequenceRecordReader trainLabels = new CSVSequenceRecordReader(); trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449)); int miniBatchSize = 10; int numLabelClasses = 6; DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); //Normalize the training data //DataNormalization normalizer = new NormalizerStandardize(); //normalizer.fit(trainData); //Collect training data statistics //trainData.reset(); //Use previously collected statistics to normalize on-the-fly. Each DataSet returned by 'trainData' iterator will be normalized //trainData.setPreProcessor(normalizer); // ----- Load the test data ----- //Same process as for the training data. 
SequenceRecordReader testFeatures = new CSVSequenceRecordReader(); testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); SequenceRecordReader testLabels = new CSVSequenceRecordReader(); testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels, miniBatchSize, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); //testData.setPreProcessor(normalizer); //Note that we are using the exact same normalization process as the training data // ----- Configure the network ----- MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) //Random number generator seed for improved repeatability. Optional. .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1) .weightInit(WeightInit.XAVIER).updater(Updater.NESTEROVS).momentum(0.9).learningRate(0.005) .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) //Not always required, but helps with this data set .gradientNormalizationThreshold(0.5).list(2) .layer(0, new GravesLSTM.Builder().activation("tanh").nIn(1).nOut(10).build()) .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation("softmax") .nIn(10).nOut(numLabelClasses).build()) .pretrain(false).backprop(true).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); net.setListeners(new ScoreIterationListener(20)); //Print the score (loss function value) every 20 iterations // ----- Train the network, evaluating the test set performance at each epoch ----- int nEpochs = 40; String str = "Test set evaluation at epoch %d: Accuracy = %.2f, F1 = %.2f"; for (int i = 0; i < nEpochs; i++) { net.fit(trainData); //Evaluate on the test set: Evaluation evaluation = new Evaluation(); while (testData.hasNext()) { DataSet t = testData.next(); INDArray 
features = t.getFeatureMatrix(); INDArray lables = t.getLabels(); INDArray inMask = t.getFeaturesMaskArray(); INDArray outMask = t.getLabelsMaskArray(); INDArray predicted = net.output(features, false, inMask, outMask); evaluation.evalTimeSeries(lables, predicted, outMask); } System.out.println(String.format(str, i, evaluation.accuracy(), evaluation.f1())); testData.reset(); trainData.reset(); } System.out.println("----- Example Complete -----"); }