Example usage for the org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition constructor

List of usage examples for the org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition constructor

Introduction

On this page you can find example usages of the org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition constructor.

Prototype

@JsonCreator
    public MaxEpochsTerminationCondition(int maxEpochs) 

Source Link

Usage

From source file:com.heatonresearch.aifh.examples.ann.LearnAutoMPGBackprop.java

License:Apache License

/**
 * The main method. Loads the Auto MPG data set from the classpath, normalizes it,
 * trains a two-layer network with early stopping, and prints the prediction and the
 * actual label for every example in the validation split.
 *
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 0.01;

        // Setup training data. try-with-resources guarantees the stream is closed even
        // when loading fails; a null resource is legal and is simply not closed.
        final NormalizeDataSet ds;
        try (InputStream istream = LearnAutoMPGBackprop.class.getResourceAsStream("/auto-mpg.data.csv")) {
            if (istream == null) {
                System.out.println("Cannot access data set, make sure the resources are available.");
                System.exit(1);
            }
            ds = NormalizeDataSet.load(istream);
        }

        // The following ranges are set up for the Auto MPG data set. If you wish to normalize
        // other files you will need to modify the function calls below.

        // First remove some columns that we will not use (highest index first):
        ds.deleteColumn(8); // Car name
        ds.deleteColumn(7); // Car origin
        ds.deleteColumn(6); // Year
        ds.deleteUnknowns();

        // Z-score normalize columns 1 through 5.
        ds.normalizeZScore(1);
        ds.normalizeZScore(2);
        ds.normalizeZScore(3);
        ds.normalizeZScore(4);
        ds.normalizeZScore(5);

        // NOTE(review): arguments are presumably (inputBegin, inputCount, idealBegin, idealCount)
        // -- confirm against NormalizeDataSet.extractSupervised.
        DataSet next = ds.extractSupervised(1, 4, 0, 1);
        next.shuffle();

        // Training and validation data split: 75% of the examples go to training.
        int splitTrainNum = (int) (next.numExamples() * .75);
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet trainSet = testAndTrain.getTrain();
        DataSet validationSet = testAndTrain.getTest();

        // Each iterator serves its whole set as a single batch.
        DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(), trainSet.numExamples());

        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationSet.numExamples());

        // Create neural network.
        int numInputs = next.numInputs();
        int numOutputs = next.numOutcomes();
        int numHiddenNodes = 50;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9).list(2)
                .layer(0,
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(1,
                        new OutputLayer.Builder(LossFunction.MSE).weightInit(WeightInit.XAVIER)
                                .activation("identity").nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        // NOTE(review): the trainer below is constructed from conf, not from this model
        // instance, and model is reassigned after training -- this listener may never
        // fire. Confirm against EarlyStoppingTrainer before relying on it.
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(500)) // Max of 500 epochs
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(25))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) // Score on the validation set
                .modelSaver(saver).build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        // Continue with the best model retained by the saver.
        model = saver.getBestModel();

        // Evaluate
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            System.out.println(features + ":Prediction(" + predicted + "):Actual(" + labels + ")");
        }

    } catch (Exception ex) {
        ex.printStackTrace();
    }
}

From source file:com.heatonresearch.aifh.examples.ann.LearnIrisBackprop.java

License:Apache License

/**
 * The main method./*from   w w  w . j a v a  2  s  .co m*/
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 0.1;
        int splitTrainNum = (int) (150 * .75);

        int numInputs = 4;
        int numOutputs = 3;
        int numHiddenNodes = 50;

        // Setup training data.
        final InputStream istream = LearnIrisBackprop.class.getResourceAsStream("/iris.csv");
        if (istream == null) {
            System.out.println("Cannot access data set, make sure the resources are available.");
            System.exit(1);
        }
        final NormalizeDataSet ds = NormalizeDataSet.load(istream);
        final CategoryMap species = ds.encodeOneOfN(4); // species is column 4
        istream.close();

        DataSet next = ds.extractSupervised(0, 4, 4, 3);
        next.shuffle();

        // Training and validation data split
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet trainSet = testAndTrain.getTrain();
        DataSet validationSet = testAndTrain.getTest();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(), trainSet.numExamples());

        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationSet.numExamples());

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9).list(2)
                .layer(0,
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(1,
                        new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                                .weightInit(WeightInit.XAVIER).activation("softmax").nIn(numHiddenNodes)
                                .nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(500)) //Max of 50 epochs
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(25))
                .evaluateEveryNEpochs(1).scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
                .modelSaver(saver).build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            System.out.println(features + ":Prediction(" + findSpecies(labels, species) + "):Actual("
                    + findSpecies(predicted, species) + ")" + predicted);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }
}