Example usage for the org.deeplearning4j.earlystopping.saver.InMemoryModelSaver constructor

List of usage examples for the org.deeplearning4j.earlystopping.saver.InMemoryModelSaver constructor

Introduction

On this page you can find example usages of the org.deeplearning4j.earlystopping.saver.InMemoryModelSaver constructor.

Prototype

public InMemoryModelSaver()

Source Link

Usage

From source file:com.heatonresearch.aifh.examples.ann.LearnAutoMPGBackprop.java

License:Apache License

/**
 * The main method./*from   w  w  w.jav a  2s  . c  om*/
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 0.01;

        // Setup training data.
        final InputStream istream = LearnAutoMPGBackprop.class.getResourceAsStream("/auto-mpg.data.csv");
        if (istream == null) {
            System.out.println("Cannot access data set, make sure the resources are available.");
            System.exit(1);
        }
        final NormalizeDataSet ds = NormalizeDataSet.load(istream);
        istream.close();

        // The following ranges are setup for the Auto MPG data set.  If you wish to normalize other files you will
        // need to modify the below function calls other files.

        // First remove some columns that we will not use:
        ds.deleteColumn(8); // Car name
        ds.deleteColumn(7); // Car origin
        ds.deleteColumn(6); // Year
        ds.deleteUnknowns();

        ds.normalizeZScore(1);
        ds.normalizeZScore(2);
        ds.normalizeZScore(3);
        ds.normalizeZScore(4);
        ds.normalizeZScore(5);

        DataSet next = ds.extractSupervised(1, 4, 0, 1);
        next.shuffle();

        // Training and validation data split
        int splitTrainNum = (int) (next.numExamples() * .75);
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet trainSet = testAndTrain.getTrain();
        DataSet validationSet = testAndTrain.getTest();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(), trainSet.numExamples());

        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationSet.numExamples());

        // Create neural network.
        int numInputs = next.numInputs();
        int numOutputs = next.numOutcomes();
        int numHiddenNodes = 50;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9).list(2)
                .layer(0,
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(1,
                        new OutputLayer.Builder(LossFunction.MSE).weightInit(WeightInit.XAVIER)
                                .activation("identity").nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(500)) //Max of 50 epochs
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(25))
                .evaluateEveryNEpochs(1).scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
                .modelSaver(saver).build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            System.out.println(features + ":Prediction(" + predicted + "):Actual(" + labels + ")");
        }

    } catch (Exception ex) {
        ex.printStackTrace();
    }
}

From source file:com.heatonresearch.aifh.examples.ann.LearnDigitsBackprop.java

License:Apache License

/**
 * The main method./*ww  w  .j  a  v  a2s .co  m*/
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols() * trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 200;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9).regularization(true).dropOut(0.50).list(2)
                .layer(0,
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(1,
                        new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                                .weightInit(WeightInit.XAVIER).activation("softmax").nIn(numHiddenNodes)
                                .nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1).scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
                .modelSaver(saver).build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }

}

From source file:com.heatonresearch.aifh.examples.ann.LearnDigitsConv.java

License:Apache License

/**
 * The main method./*from   ww w  . j a va  2 s .co  m*/
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;
        int channels = 1;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        int numOutputs = 10;

        // Create neural network.
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .regularization(true).l2(0.0005).learningRate(0.01).weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.NESTEROVS)
                .momentum(0.9).list(4)
                .layer(0,
                        new ConvolutionLayer.Builder(5, 5).nIn(channels).stride(1, 1).nOut(20).dropOut(0.5)
                                .activation("relu").build())
                .layer(1,
                        new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
                                .stride(2, 2).build())
                .layer(2, new DenseLayer.Builder().activation("relu").nOut(500).build())
                .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10)
                        .activation("softmax").build())
                .backprop(true).pretrain(false);

        new ConvolutionLayerSetup(builder, 28, 28, 1);
        MultiLayerConfiguration conf = builder.build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1).scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
                .modelSaver(saver).build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }

}

From source file:com.heatonresearch.aifh.examples.ann.LearnDigitsDropout.java

License:Apache License

/**
 * The main method./*from  ww w.  ja va 2  s  . c  o  m*/
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols() * trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 100;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9).list(2)
                .layer(0,
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(1,
                        new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                                .weightInit(WeightInit.XAVIER).activation("softmax").nIn(numHiddenNodes)
                                .nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1).scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
                .modelSaver(saver).build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }

}

From source file:com.heatonresearch.aifh.examples.ann.LearnIrisBackprop.java

License:Apache License

/**
 * The main method./*from w  w w  . ja  v a 2s  .c o  m*/
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 0.1;
        int splitTrainNum = (int) (150 * .75);

        int numInputs = 4;
        int numOutputs = 3;
        int numHiddenNodes = 50;

        // Setup training data.
        final InputStream istream = LearnIrisBackprop.class.getResourceAsStream("/iris.csv");
        if (istream == null) {
            System.out.println("Cannot access data set, make sure the resources are available.");
            System.exit(1);
        }
        final NormalizeDataSet ds = NormalizeDataSet.load(istream);
        final CategoryMap species = ds.encodeOneOfN(4); // species is column 4
        istream.close();

        DataSet next = ds.extractSupervised(0, 4, 4, 3);
        next.shuffle();

        // Training and validation data split
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet trainSet = testAndTrain.getTrain();
        DataSet validationSet = testAndTrain.getTest();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(), trainSet.numExamples());

        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationSet.numExamples());

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9).list(2)
                .layer(0,
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(1,
                        new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                                .weightInit(WeightInit.XAVIER).activation("softmax").nIn(numHiddenNodes)
                                .nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(500)) //Max of 50 epochs
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(25))
                .evaluateEveryNEpochs(1).scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
                .modelSaver(saver).build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            System.out.println(features + ":Prediction(" + findSpecies(labels, species) + "):Actual("
                    + findSpecies(predicted, species) + ")" + predicted);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }
}