Example usage for org.deeplearning4j.nn.multilayer MultiLayerNetwork output

List of usage examples for org.deeplearning4j.nn.multilayer MultiLayerNetwork output

Introduction

In this page you can find the example usage for org.deeplearning4j.nn.multilayer MultiLayerNetwork output.

Prototype

public synchronized <T> T output(@NonNull INDArray inputs, INDArray inputMasks, INDArray labelMasks,
        @NonNull OutputAdapter<T> outputAdapter) 

Source Link

Document

This method uses the provided OutputAdapter to return a custom object built from an INDArray. PLEASE NOTE: this method uses a dedicated Workspace for output generation to avoid redundant allocations.

Usage

From source file:seqtest.Pair.java

/**
 * Trains a small LSTM sequence classifier on the UCI synthetic control data set
 * and evaluates test-set accuracy/F1 after every epoch.
 *
 * <p>Data layout (produced by {@code downloadUCIData()}): one CSV per sequence,
 * {@code train/features/0.csv .. 449.csv} with matching label files, and 150
 * test sequences. NOTE(review): this code targets an old DL4J API
 * ({@code iterations(...)}, {@code list(int)}, {@code pretrain/backprop}) —
 * keep the matching library version or the build will break.
 *
 * @param args unused
 * @throws Exception on download or record-reader initialization failure
 */
public static void main(String[] args) throws Exception {
    downloadUCIData();

    // ----- Load the training data -----
    // 450 training files for features: train/features/0.csv through train/features/449.csv
    SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
    trainFeatures
            .initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
    SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
    trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));

    int miniBatchSize = 10;
    int numLabelClasses = 6;
    // ALIGN_END: labels are aligned to the last time step of each (variable-length) sequence
    DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels,
            miniBatchSize, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    //Normalize the training data
    //DataNormalization normalizer = new NormalizerStandardize();
    //normalizer.fit(trainData);              //Collect training data statistics
    //trainData.reset();

    //Use previously collected statistics to normalize on-the-fly. Each DataSet returned by 'trainData' iterator will be normalized
    //trainData.setPreProcessor(normalizer);

    // ----- Load the test data -----
    //Same process as for the training data.
    SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
    testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
    SequenceRecordReader testLabels = new CSVSequenceRecordReader();
    testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149));

    DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels, miniBatchSize,
            numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    //testData.setPreProcessor(normalizer);   //Note that we are using the exact same normalization process as the training data

    // ----- Configure the network -----
    // Single GravesLSTM layer (1 input feature -> 10 hidden units) feeding a
    // softmax RNN output layer with MCXENT (multi-class cross-entropy) loss.
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) //Random number generator seed for improved repeatability. Optional.
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .weightInit(WeightInit.XAVIER).updater(Updater.NESTEROVS).momentum(0.9).learningRate(0.005)
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) //Not always required, but helps with this data set
            .gradientNormalizationThreshold(0.5).list(2)
            .layer(0, new GravesLSTM.Builder().activation("tanh").nIn(1).nOut(10).build())
            .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation("softmax")
                    .nIn(10).nOut(numLabelClasses).build())
            .pretrain(false).backprop(true).build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    net.setListeners(new ScoreIterationListener(20)); //Print the score (loss function value) every 20 iterations

    // ----- Train the network, evaluating the test set performance at each epoch -----
    int nEpochs = 40;
    String str = "Test set evaluation at epoch %d: Accuracy = %.2f, F1 = %.2f";
    for (int i = 0; i < nEpochs; i++) {
        net.fit(trainData);

        //Evaluate on the test set. Pass the class count so per-class statistics
        //and labeling are correct even if a class is absent from a mini-batch.
        Evaluation evaluation = new Evaluation(numLabelClasses);
        while (testData.hasNext()) {
            DataSet t = testData.next();
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            //Mask arrays mark valid time steps in padded variable-length sequences
            INDArray inMask = t.getFeaturesMaskArray();
            INDArray outMask = t.getLabelsMaskArray();
            INDArray predicted = net.output(features, false, inMask, outMask);

            evaluation.evalTimeSeries(labels, predicted, outMask);
        }

        System.out.println(String.format(str, i, evaluation.accuracy(), evaluation.f1()));

        //Iterators are one-shot per pass: rewind both before the next epoch
        testData.reset();
        trainData.reset();
    }

    System.out.println("----- Example Complete -----");
}