Example usage for org.deeplearning4j.eval Evaluation evalTimeSeries

List of usage examples for org.deeplearning4j.eval Evaluation evalTimeSeries

Introduction

On this page you can find an example usage of org.deeplearning4j.eval Evaluation evalTimeSeries.

Prototype

@Deprecated
void evalTimeSeries(INDArray labels, INDArray predicted, INDArray labelsMaskArray);

Source Link

Usage

From source file: seqtest.Pair.java

/**
 * Trains a small LSTM sequence classifier on the UCI synthetic control data set and
 * evaluates test-set performance after every epoch.
 *
 * NOTE(review): {@code Evaluation.evalTimeSeries} (and {@code getFeatureMatrix}) are
 * deprecated in newer DL4J releases; this example intentionally demonstrates the
 * deprecated masked time-series evaluation API.
 *
 * @param args unused command-line arguments
 * @throws Exception if downloading or reading the data files fails
 */
public static void main(String[] args) throws Exception {
    downloadUCIData();

    // ----- Load the training data -----
    //Note that we have 450 training files for features: train/features/0.csv through train/features/449.csv
    SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
    trainFeatures
            .initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
    SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
    trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));

    int miniBatchSize = 10;
    int numLabelClasses = 6;
    //ALIGN_END aligns the (single) label with the end of each feature sequence
    DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels,
            miniBatchSize, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    //Normalize the training data (optional here; uncomment to enable)
    //DataNormalization normalizer = new NormalizerStandardize();
    //normalizer.fit(trainData);              //Collect training data statistics
    //trainData.reset();

    //Use previously collected statistics to normalize on-the-fly. Each DataSet returned by 'trainData' iterator will be normalized
    //trainData.setPreProcessor(normalizer);

    // ----- Load the test data -----
    //Same process as for the training data, but only 150 files (0.csv through 149.csv)
    SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
    testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
    SequenceRecordReader testLabels = new CSVSequenceRecordReader();
    testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149));

    DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels, miniBatchSize,
            numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    //testData.setPreProcessor(normalizer);   //Note that we are using the exact same normalization process as the training data

    // ----- Configure the network -----
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) //Random number generator seed for improved repeatability. Optional.
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .weightInit(WeightInit.XAVIER).updater(Updater.NESTEROVS).momentum(0.9).learningRate(0.005)
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) //Not always required, but helps with this data set
            .gradientNormalizationThreshold(0.5).list(2)
            .layer(0, new GravesLSTM.Builder().activation("tanh").nIn(1).nOut(10).build())
            .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation("softmax")
                    .nIn(10).nOut(numLabelClasses).build())
            .pretrain(false).backprop(true).build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    net.setListeners(new ScoreIterationListener(20)); //Print the score (loss function value) every 20 iterations

    // ----- Train the network, evaluating the test set performance at each epoch -----
    int nEpochs = 40;
    String str = "Test set evaluation at epoch %d: Accuracy = %.2f, F1 = %.2f";
    for (int i = 0; i < nEpochs; i++) {
        net.fit(trainData);

        //Evaluate on the test set, accumulating stats over all mini-batches:
        Evaluation evaluation = new Evaluation();
        while (testData.hasNext()) {
            DataSet t = testData.next();
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray inMask = t.getFeaturesMaskArray();
            INDArray outMask = t.getLabelsMaskArray();
            INDArray predicted = net.output(features, false, inMask, outMask);

            //The label mask restricts evaluation to the valid (unpadded) time steps
            evaluation.evalTimeSeries(labels, predicted, outMask);
        }

        System.out.println(String.format(str, i, evaluation.accuracy(), evaluation.f1()));

        //Reset both iterators so the next epoch sees the full data again
        testData.reset();
        trainData.reset();
    }

    System.out.println("----- Example Complete -----");
}