Example usage for org.deeplearning4j.nn.api OptimizationAlgorithm STOCHASTIC_GRADIENT_DESCENT

Introduction

On this page you can find example usage of org.deeplearning4j.nn.api OptimizationAlgorithm STOCHASTIC_GRADIENT_DESCENT.

Prototype

OptimizationAlgorithm STOCHASTIC_GRADIENT_DESCENT

To view the source code for org.deeplearning4j.nn.api OptimizationAlgorithm STOCHASTIC_GRADIENT_DESCENT, click the Source Link.
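
Before the full examples, here is a minimal, self-contained sketch of how this constant is typically passed to a network configuration. It assumes the legacy DL4J 0.x builder API used throughout the examples below (learningRate(), string-based activations, and the explicit pretrain()/backprop() calls belong to that era); the layer sizes and hyperparameters are illustrative only.

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;

public class SgdConfigSketch {
    public static void main(String[] args) {
        // Select stochastic gradient descent as the optimization algorithm.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(123)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(0.01)
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(4).nOut(10).activation("relu").build())
                .layer(1, new OutputLayer.Builder().nIn(10).nOut(3).activation("softmax").build())
                .pretrain(false).backprop(true)
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // Print the configuration to confirm the optimizer setting took effect.
        System.out.println(conf.toJson());
    }
}

With SGD selected, each call to fit() performs minibatch gradient-descent updates using whatever updater (for example Nesterovs momentum) is configured alongside it.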

Usage

From source file: aiLogicImplementation.RNNBasic.java

License: Apache License

public void create(StatsStorage statsStorage) {
    NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
    builder.iterations(1000);
    builder.learningRate(0.01);
    builder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
    builder.seed(123);
    builder.biasInit(0.2);
    builder.miniBatch(true);
    //        builder.updater(Updater.NESTEROVS);
    builder.updater(new Nesterovs(0.01));
    builder.weightInit(WeightInit.XAVIER);
    builder.regularization(true);
    builder.l2(0.001);
    //        builder.gradientNormalization(GradientNormalization.ClipL2PerParamType);
    //        builder.gradientNormalizationThreshold(0.5);

    ListBuilder listBuilder = builder.list();

    for (int i = 0; i < HIDDEN_LAYER_CONT; i++) {
        GravesLSTM.Builder hiddenLayerBuilder = new GravesLSTM.Builder();
        hiddenLayerBuilder.nIn(i == 0 ? NUMBER_OF_FEATURE_INPUT : HIDDEN_LAYER_WIDTH);
        hiddenLayerBuilder.nOut(HIDDEN_LAYER_WIDTH);
        hiddenLayerBuilder.activation(Activation.TANH);
        listBuilder.layer(i, hiddenLayerBuilder.build());
    }

    //        RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.XENT);
    RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.L2);
    //        RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.MCXENT);
    //        outputLayerBuilder.activation(Activation.SIGMOID);
    outputLayerBuilder.activation(Activation.RELU);
    //        outputLayerBuilder.activation(Activation.SOFTMAX);
    outputLayerBuilder.nIn(HIDDEN_LAYER_WIDTH);
    outputLayerBuilder.nOut(NUMBER_OF_FEATURE_OUTPUT);
    outputLayerBuilder.weightInit(WeightInit.XAVIER);
    //        outputLayerBuilder.dist(new UniformDistribution(0, 1));
    listBuilder.layer(HIDDEN_LAYER_CONT, outputLayerBuilder.build());
    //        listBuilder.backpropType(BackpropType.TruncatedBPTT);
    //        listBuilder.tBPTTForwardLength(tbpttLength);
    //        listBuilder.tBPTTBackwardLength(tbpttLength);

    listBuilder.pretrain(false);
    listBuilder.backprop(true);

    // create network
    MultiLayerConfiguration conf = listBuilder.build();
    net = new MultiLayerNetwork(conf);
    net.init();

    //        net.setListeners(new ScoreIterationListener(configuration.getListenerFrequency()));
    //Then add the StatsListener to collect this information from the network, as it trains
    List<IterationListener> list = new ArrayList<>();
    list.add(new StatsListener(statsStorage, configuration.getListenerFrequency()));
    list.add(new ScoreIterationListener(configuration.getListenerFrequency()));
    net.setListeners(list);

    /*
     * CREATE OUR TRAINING DATA
     */
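    // One training row is created per (maze cell, player, food) combination.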
    input = Nd4j.zeros(maze.length * maze.length * hibertMaze.getHibertMazeGUI().getAlPlayers().size()
            * hibertMaze.getHibertMazeGUI().getAlFoods().size(), NUMBER_OF_FEATURE_INPUT);
    labels = Nd4j.zeros(maze.length * maze.length * hibertMaze.getHibertMazeGUI().getAlPlayers().size()
            * hibertMaze.getHibertMazeGUI().getAlFoods().size(), NUMBER_OF_FEATURE_OUTPUT);
    positionMoving = new PositionMoving(maze);
    int counter = 0;
    for (int i = 0; i < maze.length; i++) {
        for (int j = 0; j < maze[i].length; j++) {
            int mazeFieldValue = maze[i][j];
            // encode the cell coordinates, maze value, player, and food state as input features
            for (Player player : hibertMaze.getHibertMazeGUI().getAlPlayers()) {
                for (Food food : hibertMaze.getHibertMazeGUI().getAlFoods()) {
                    Position position = new Position(i, j);
                    input.putScalar(counter, 0, i);
                    input.putScalar(counter, 1, j);
                    input.putScalar(counter, 2, mazeFieldValue);
                    input.putScalar(counter, 3, player.getPlayerNumber());
                    input.putScalar(counter, 4, i);
                    input.putScalar(counter, 5, j);
                    input.putScalar(counter, 6, food.getFoodNumber());
                    input.putScalar(counter, 7, food.getPosition().getX());
                    input.putScalar(counter, 8, food.getPosition().getY());
                    input.putScalar(counter, 9, positionMoving.moveNorth(position) ? 1.00d : 0.00d);
                    input.putScalar(counter, 10, positionMoving.moveSouth(position) ? 1.00d : 0.00d);
                    input.putScalar(counter, 11, positionMoving.moveEast(position) ? 1.00d : 0.00d);
                    input.putScalar(counter, 12, positionMoving.moveWest(position) ? 1.00d : 0.00d);
                    input.putScalar(counter, 13,
                            player.getPosition().getX() == food.getPosition().getX()
                                    && player.getPosition().getY() == food.getPosition().getY() ? 1.00d
                                            : 0.00d);

                    labels.putScalar(counter, 0, positionMoving.moveNorth(position) ? 1.00d : 0.00d);
                    labels.putScalar(counter, 1, positionMoving.moveSouth(position) ? 1.00d : 0.00d);
                    labels.putScalar(counter, 2, positionMoving.moveEast(position) ? 1.00d : 0.00d);
                    labels.putScalar(counter, 3, positionMoving.moveWest(position) ? 1.00d : 0.00d);
                    counter = counter + 1;
                }
            }
        }
    }
    trainingData = new DataSet(input, labels);
}

From source file: cnn.image.classification.CNNImageClassification.java

public static void main(String[] args) {
    int nChannels = 3;
    int outputNum = 10;
    //        int numExamples = 80;
    int batchSize = 10;
    int nEpochs = 20;
    int iterations = 1;
    int seed = 123;
    int height = 32;
    int width = 32;
    Random randNumGen = new Random(seed);
    System.out.println("Load data....");

    File parentDir = new File("train1/");

    FileSplit filesInDir = new FileSplit(parentDir, allowedExtensions, randNumGen);

    //Labels are derived from each image's parent directory name
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();

    BalancedPathFilter pathFilter = new BalancedPathFilter(randNumGen, allowedExtensions, labelMaker);

    //Sample the image files into train and test splits. Note: the weights used here are
    //100/0 and 0/100, so the full set is sampled for each split rather than an 80%/20% partition.
    InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter, 100, 0);
    InputSplit[] filesInDirSplitTest = filesInDir.sample(pathFilter, 0, 100);

    InputSplit trainData = filesInDirSplit[0];
    InputSplit testData = filesInDirSplitTest[1];

    System.out.println("train = " + trainData.length());
    System.out.println("test = " + testData.length());
    //Specifying a new record reader with the height and width you want the images to be resized to.
    //Note that the images in this example are all of different size
    //They will all be resized to the height and width specified below
    //Use separate record readers for train and test: RecordReaderDataSetIterator reads
    //lazily, so re-initializing a single shared reader would redirect both iterators.
    ImageRecordReader recordReaderTrain = new ImageRecordReader(height, width, nChannels, labelMaker);
    ImageRecordReader recordReaderTest = new ImageRecordReader(height, width, nChannels, labelMaker);

    //Often there is a need to transform images to artificially increase the size of the dataset

    recordReaderTrain.initialize(trainData);
    DataSetIterator dataIterTrain = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, outputNum);
    recordReaderTest.initialize(testData);
    DataSetIterator dataIterTest = new RecordReaderDataSetIterator(recordReaderTest, batchSize, 1, outputNum);

    // Normalize pixel values to the [0, 1] range
    DataNormalization scaler = new ImagePreProcessingScaler(0, 1);

    dataIterTrain.setPreProcessor(scaler);
    dataIterTest.setPreProcessor(scaler);

    System.out.println("Build model....");
    MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed)
            .iterations(iterations).regularization(true).l2(0.0005)
            //                .dropOut(0.5)
            .learningRate(0.001)//.biasLearningRate(0.02)
            //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
            .weightInit(WeightInit.XAVIER).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.NESTEROVS).momentum(0.9).list()
            .layer(0,
                    new ConvolutionLayer.Builder(5, 5).nIn(nChannels).stride(1, 1).nOut(20)
                            .activation("identity").build())
            .layer(1,
                    new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2)
                            .build())
            .layer(2, new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50).activation("identity").build())
            .layer(3,
                    new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                            .kernelSize(2, 2).stride(2, 2).build())
            .layer(4, new DenseLayer.Builder().activation("relu").nOut(500).build())
            .layer(5,
                    new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum)
                            .activation("softmax").build())
            .setInputType(InputType.convolutional(height, width, nChannels))
            .backprop(true).pretrain(false);

    // Alternative configuration (built here for comparison but not used below)
    MultiLayerConfiguration b = new NeuralNetConfiguration.Builder().seed(seed).iterations(iterations)
            .regularization(false).l2(0.005) // tried 0.0001, 0.0005
            .learningRate(0.0001) // tried 0.00001, 0.00005, 0.000001
            .weightInit(WeightInit.XAVIER).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.NESTEROVS).momentum(0.9).list().layer(0, new ConvolutionLayer.Builder(5, 5)
                    //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
                    .nIn(nChannels).stride(1, 1).nOut(50) // tried 10, 20, 40, 50
                    .activation("relu").build())
            .layer(1,
                    new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                            .kernelSize(2, 2).stride(2, 2).build())
            .layer(2, new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(100) // tried 25, 50, 100
                    .activation("relu").build())
            .layer(3,
                    new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                            .kernelSize(2, 2).stride(2, 2).build())
            .layer(4, new DenseLayer.Builder().activation("relu").nOut(500).build())
            .layer(5,
                    new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum)
                            .activation("softmax").build())
            .backprop(true).pretrain(false).cnnInputSize(height, width, nChannels).build();

    MultiLayerConfiguration conf = builder.build();
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();

    System.out.println("Train model....");
    model.setListeners(new ScoreIterationListener(1));
    //        for( int i=0; i<nEpochs; i++ ) {
    //            model.setListeners(new HistogramIterationListener(1));

    // Wrap the training iterator so a single fit() call trains for nEpochs epochs
    MultipleEpochsIterator trainIter = new MultipleEpochsIterator(nEpochs, dataIterTrain, 2);
    model.fit(trainIter);

    //            System.out.println("*** Completed epoch - " + i + "  ***");

    System.out.println("Evaluate model....");
    //            Evaluation eval = new Evaluation(outputNum);
    //            while(dataIterTest.hasNext()){
    //                DataSet ds = dataIterTest.next();
    //                INDArray output = model.output(ds.getFeatureMatrix(), false);
    //                eval.eval(ds.getLabels(), output);
    //            }
    //            System.out.println(eval.stats());
    //            dataIterTest.reset();
    //        }

    Evaluation eval1 = model.evaluate(dataIterTest);
    System.out.println(eval1.stats());

    System.out.println("****************Example finished********************");
}

From source file: com.circle_technologies.cnn4j.predictive.provider.network.SimpleHiddenLayerNetworkProvider.java

License: Apache License

@Override
public MultiLayerNetwork provideNetwork(int inputs, int outputs) {
    NeuralNetConfiguration.ListBuilder listBuilder = new NeuralNetConfiguration.Builder().iterations(1)
            .weightInit(WeightInit.XAVIER).activation(getActivationFunction())
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(getLearningRate())
            .list().layer(0, new DenseLayer.Builder().nIn(inputs).nOut(inputs).activation("relu").build());

    // Append mHiddenLayerCount additional dense hidden layers of the same width
    for (int i = 0; i < mHiddenLayerCount; i++) {
        listBuilder = listBuilder.layer(i + 1,
                new DenseLayer.Builder().nIn(inputs).nOut(inputs).activation(getActivationFunction()).build());
    }

    MultiLayerConfiguration configuration = listBuilder
            .layer(mHiddenLayerCount + 1, new OutputLayer.Builder().nIn(inputs).nOut(outputs).build())
            .backprop(true).pretrain(false).build();

    return new MultiLayerNetwork(configuration);
}

From source file: com.circle_technologies.cnn4j.predictive.provider.network.SimpleSingleLayerNetworkProvider.java

License: Apache License

@Override
public MultiLayerNetwork provideNetwork(int inputs, int outputs) {
    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder().iterations(1)
            .weightInit(WeightInit.XAVIER).activation(getActivationFunction())
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(getLearningRate())
            .list()
            .layer(0,
                    new DenseLayer.Builder().nIn(inputs).nOut(inputs).activation(getActivationFunction())
                            .build())
            .layer(1, new OutputLayer.Builder().nIn(inputs).nOut(outputs).build()).backprop(true)
            .pretrain(false).build();

    return new MultiLayerNetwork(configuration);
}

From source file: com.example.android.displayingbitmaps.ui.ImageGridActivity.java

License: Apache License

public void trainMLP() throws Exception {
    Nd4j.ENFORCE_NUMERICAL_STABILITY = true;
    final int numRows = 28;
    final int numColumns = 28;
    int outputNum = 10;
    int numSamples = 10000;
    int batchSize = 500;
    int iterations = 10;
    int seed = 123;
    int listenerFreq = iterations / 5;
    int splitTrainNum = (int) (batchSize * .8);
    DataSet mnist;
    SplitTestAndTrain trainTest;
    DataSet trainInput;
    List<INDArray> testInput = new ArrayList<>();
    List<INDArray> testLabels = new ArrayList<>();

    log.info("Load data....");
    DataSetIterator mnistIter = new MnistDataSetIterator(batchSize, numSamples, true);

    log.info("Build model....");
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(iterations)
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).learningRate(1e-1f)
            .momentum(0.5).momentumAfter(Collections.singletonMap(3, 0.9)).useDropConnect(true).list(2)
            .layer(0,
                    new DenseLayer.Builder().nIn(numRows * numColumns).nOut(1000).activation("relu")
                            .weightInit(WeightInit.XAVIER).build())
            .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).nIn(1000).nOut(outputNum)
                    .activation("softmax").weightInit(WeightInit.XAVIER).build())
            .build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();

    log.info("Train model....");
    model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(listenerFreq)));
    while (mnistIter.hasNext()) {
        mnist = mnistIter.next();
        trainTest = mnist.splitTestAndTrain(splitTrainNum, new Random(seed)); // split each mini-batch into train and test portions
        trainInput = trainTest.getTrain(); // get feature matrix and labels for training
        testInput.add(trainTest.getTest().getFeatureMatrix());
        testLabels.add(trainTest.getTest().getLabels());
        model.fit(trainInput);
    }

    log.info("Evaluate model....");
    Evaluation eval = new Evaluation(outputNum);
    for (int i = 0; i < testInput.size(); i++) {
        INDArray output = model.output(testInput.get(i));
        eval.eval(testLabels.get(i), output);
    }

    log.info(eval.stats());
    log.info("****************Example finished********************");
}

From source file: com.heatonresearch.aifh.examples.ann.LearnAutoMPGBackprop.java

License: Apache License

/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 0.01;

        // Setup training data.
        final InputStream istream = LearnAutoMPGBackprop.class.getResourceAsStream("/auto-mpg.data.csv");
        if (istream == null) {
            System.out.println("Cannot access data set, make sure the resources are available.");
            System.exit(1);
        }
        final NormalizeDataSet ds = NormalizeDataSet.load(istream);
        istream.close();

        // The following ranges are set up for the Auto MPG data set. If you wish to normalize
        // other files, you will need to modify the function calls below accordingly.

        // First remove some columns that we will not use:
        ds.deleteColumn(8); // Car name
        ds.deleteColumn(7); // Car origin
        ds.deleteColumn(6); // Year
        ds.deleteUnknowns();

        ds.normalizeZScore(1);
        ds.normalizeZScore(2);
        ds.normalizeZScore(3);
        ds.normalizeZScore(4);
        ds.normalizeZScore(5);

        DataSet next = ds.extractSupervised(1, 4, 0, 1);
        next.shuffle();

        // Training and validation data split
        int splitTrainNum = (int) (next.numExamples() * .75);
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet trainSet = testAndTrain.getTrain();
        DataSet validationSet = testAndTrain.getTest();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(), trainSet.numExamples());

        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationSet.numExamples());

        // Create neural network.
        int numInputs = next.numInputs();
        int numOutputs = next.numOutcomes();
        int numHiddenNodes = 50;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9).list(2)
                .layer(0,
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(1,
                        new OutputLayer.Builder(LossFunction.MSE).weightInit(WeightInit.XAVIER)
                                .activation("identity").nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(500)) //Max of 500 epochs
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(25))
                .evaluateEveryNEpochs(1).scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
                .modelSaver(saver).build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            System.out.println(features + ":Prediction(" + predicted + "):Actual(" + labels + ")");
        }

    } catch (Exception ex) {
        ex.printStackTrace();
    }
}

From source file: com.heatonresearch.aifh.examples.ann.LearnDigitsBackprop.java

License: Apache License

/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols() * trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 200;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9).regularization(true).dropOut(0.50).list(2)
                .layer(0,
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(1,
                        new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                                .weightInit(WeightInit.XAVIER).activation("softmax").nIn(numHiddenNodes)
                                .nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1).scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
                .modelSaver(saver).build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }

}

From source file: com.heatonresearch.aifh.examples.ann.LearnDigitsConv.java

License: Apache License

/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;
        int channels = 1;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        int numOutputs = 10;

        // Create neural network.
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .regularization(true).l2(0.0005).learningRate(0.01).weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.NESTEROVS)
                .momentum(0.9).list(4)
                .layer(0,
                        new ConvolutionLayer.Builder(5, 5).nIn(channels).stride(1, 1).nOut(20).dropOut(0.5)
                                .activation("relu").build())
                .layer(1,
                        new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
                                .stride(2, 2).build())
                .layer(2, new DenseLayer.Builder().activation("relu").nOut(500).build())
                .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10)
                        .activation("softmax").build())
                .backprop(true).pretrain(false);

        // Configure input pre-processors for 28x28 single-channel images
        new ConvolutionLayerSetup(builder, 28, 28, 1);
        MultiLayerConfiguration conf = builder.build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1).scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
                .modelSaver(saver).build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }

}

From source file: com.heatonresearch.aifh.examples.ann.LearnDigitsDropout.java

License: Apache License

/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols() * trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 100;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9).list(2)
                .layer(0,
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(1,
                        new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                                .weightInit(WeightInit.XAVIER).activation("softmax").nIn(numHiddenNodes)
                                .nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1).scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
                .modelSaver(saver).build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }

}

From source file: com.heatonresearch.aifh.examples.ann.LearnIrisBackprop.java

License: Apache License

/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 0.1;
        int splitTrainNum = (int) (150 * .75);

        int numInputs = 4;
        int numOutputs = 3;
        int numHiddenNodes = 50;

        // Setup training data.
        final InputStream istream = LearnIrisBackprop.class.getResourceAsStream("/iris.csv");
        if (istream == null) {
            System.out.println("Cannot access data set, make sure the resources are available.");
            System.exit(1);
        }
        final NormalizeDataSet ds = NormalizeDataSet.load(istream);
        final CategoryMap species = ds.encodeOneOfN(4); // species is column 4
        istream.close();

        DataSet next = ds.extractSupervised(0, 4, 4, 3);
        next.shuffle();

        // Training and validation data split
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet trainSet = testAndTrain.getTrain();
        DataSet validationSet = testAndTrain.getTest();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(), trainSet.numExamples());

        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationSet.numExamples());

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9).list(2)
                .layer(0,
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(1,
                        new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                                .weightInit(WeightInit.XAVIER).activation("softmax").nIn(numHiddenNodes)
                                .nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(500)) //Max of 500 epochs
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(25))
                .evaluateEveryNEpochs(1).scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
                .modelSaver(saver).build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            System.out.println(features + ":Prediction(" + findSpecies(predicted, species) + "):Actual("
                    + findSpecies(labels, species) + ")" + predicted);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }
}