Example usage for org.deeplearning4j.nn.weights.WeightInit.XAVIER

List of usage examples for org.deeplearning4j.nn.weights.WeightInit.XAVIER

Introduction

On this page you can find example usages of org.deeplearning4j.nn.weights.WeightInit.XAVIER.

Prototype

WeightInit.XAVIER

To view the source code for org.deeplearning4j.nn.weights.WeightInit.XAVIER, click the Source Link.

Usage

From source file:aiLogicImplementation.RNNBasic.java

License:Apache License

/**
 * Builds and initialises the recurrent network, attaches the training listeners and assembles
 * the training {@code DataSet}.
 *
 * <p>One training row is written for every combination of maze cell, player and food. Each row
 * has {@code NUMBER_OF_FEATURE_INPUT} input columns (0-13 are filled below) and
 * {@code NUMBER_OF_FEATURE_OUTPUT} label columns (0-3: the north/south/east/west movement flags
 * reported by {@code PositionMoving} for the cell).
 *
 * @param statsStorage storage that the {@code StatsListener} reports training statistics to
 */
public void create(StatsStorage statsStorage) {
    NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
    builder.iterations(1000);
    builder.learningRate(0.01);
    builder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
    builder.seed(123);
    builder.biasInit(0.2);
    builder.miniBatch(true);
    builder.updater(new Nesterovs(0.01));
    builder.weightInit(WeightInit.XAVIER);
    builder.regularization(true);
    builder.l2(0.001);

    ListBuilder listBuilder = builder.list();

    // Stack of GravesLSTM layers. Every hidden layer uses TANH (the previous if/else selected
    // TANH in both branches, so the branch was redundant).
    for (int i = 0; i < HIDDEN_LAYER_CONT; i++) {
        GravesLSTM.Builder hiddenLayerBuilder = new GravesLSTM.Builder();
        hiddenLayerBuilder.nIn(i == 0 ? NUMBER_OF_FEATURE_INPUT : HIDDEN_LAYER_WIDTH);
        hiddenLayerBuilder.nOut(HIDDEN_LAYER_WIDTH);
        hiddenLayerBuilder.activation(Activation.TANH);
        listBuilder.layer(i, hiddenLayerBuilder.build());
    }

    // Output layer: L2 loss with RELU activation over the movement-flag labels.
    RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.L2);
    outputLayerBuilder.activation(Activation.RELU);
    outputLayerBuilder.nIn(HIDDEN_LAYER_WIDTH);
    outputLayerBuilder.nOut(NUMBER_OF_FEATURE_OUTPUT);
    outputLayerBuilder.weightInit(WeightInit.XAVIER);
    listBuilder.layer(HIDDEN_LAYER_CONT, outputLayerBuilder.build());

    listBuilder.pretrain(false);
    listBuilder.backprop(true);

    // Create the network.
    MultiLayerConfiguration conf = listBuilder.build();
    net = new MultiLayerNetwork(conf);
    net.init();

    // StatsListener feeds statsStorage; ScoreIterationListener logs the score. Passed as varargs
    // rather than through a raw-typed List.
    net.setListeners(new StatsListener(statsStorage, configuration.getListenerFrequency()),
            new ScoreIterationListener(configuration.getListenerFrequency()));

    /*
     * CREATE OUR TRAINING DATA
     */
    // Count cells row by row. The loops below visit maze[i].length columns per row, so sizing
    // the matrices as maze.length * maze.length only worked for square mazes.
    int cellCount = 0;
    for (int row = 0; row < maze.length; row++) {
        cellCount += maze[row].length;
    }
    int exampleCount = cellCount * hibertMaze.getHibertMazeGUI().getAlPlayers().size()
            * hibertMaze.getHibertMazeGUI().getAlFoods().size();
    input = Nd4j.zeros(exampleCount, NUMBER_OF_FEATURE_INPUT);
    labels = Nd4j.zeros(exampleCount, NUMBER_OF_FEATURE_OUTPUT);
    positionMoving = new PositionMoving(maze);
    int counter = 0;
    for (int i = 0; i < maze.length; i++) {
        for (int j = 0; j < maze[i].length; j++) {
            int mazeFieldValue = maze[i][j];
            for (Player player : hibertMaze.getHibertMazeGUI().getAlPlayers()) {
                for (Food food : hibertMaze.getHibertMazeGUI().getAlFoods()) {
                    Position position = new Position(i, j);
                    input.putScalar(counter, 0, i);
                    input.putScalar(counter, 1, j);
                    input.putScalar(counter, 2, mazeFieldValue);
                    input.putScalar(counter, 3, player.getPlayerNumber());
                    // NOTE(review): features 4 and 5 repeat the cell coordinates already stored
                    // in features 0 and 1; perhaps the player's position was intended — confirm.
                    input.putScalar(counter, 4, i);
                    input.putScalar(counter, 5, j);
                    input.putScalar(counter, 6, food.getFoodNumber());
                    input.putScalar(counter, 7, food.getPosition().getX());
                    input.putScalar(counter, 8, food.getPosition().getY());
                    input.putScalar(counter, 9, positionMoving.moveNorth(position) ? 1.00d : 0.00d);
                    input.putScalar(counter, 10, positionMoving.moveSouth(position) ? 1.00d : 0.00d);
                    input.putScalar(counter, 11, positionMoving.moveEast(position) ? 1.00d : 0.00d);
                    input.putScalar(counter, 12, positionMoving.moveWest(position) ? 1.00d : 0.00d);
                    // 1 when the player stands on this food's position.
                    input.putScalar(counter, 13,
                            player.getPosition().getX() == food.getPosition().getX()
                                    && player.getPosition().getY() == food.getPosition().getY() ? 1.00d
                                            : 0.00d);

                    labels.putScalar(counter, 0, positionMoving.moveNorth(position) ? 1.00d : 0.00d);
                    labels.putScalar(counter, 1, positionMoving.moveSouth(position) ? 1.00d : 0.00d);
                    labels.putScalar(counter, 2, positionMoving.moveEast(position) ? 1.00d : 0.00d);
                    labels.putScalar(counter, 3, positionMoving.moveWest(position) ? 1.00d : 0.00d);
                    counter = counter + 1;
                }
            }
        }
    }
    trainingData = new DataSet(input, labels);
}

From source file:cnn.image.classification.CNNImageClassification.java

/**
 * Trains a small LeNet-style convolutional network on the images under {@code train1/} and
 * prints its evaluation statistics.
 *
 * <p>Each image is labelled with the name of its parent directory and resized on read to
 * {@code height x width x nChannels}. Pixel values are scaled to [0, 1]. The files are split
 * 80% / 20% into disjoint training and test sets.
 *
 * @param args not used
 * @throws Exception if the image files cannot be read
 */
public static void main(String[] args) throws Exception {
    int nChannels = 3;
    int outputNum = 10;
    int batchSize = 10;
    int nEpochs = 20;
    int iterations = 1;
    int seed = 123;
    int height = 32;
    int width = 32;
    Random randNumGen = new Random(seed);
    System.out.println("Load data....");

    File parentDir = new File("train1/");
    FileSplit filesInDir = new FileSplit(parentDir, allowedExtensions, randNumGen);
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
    BalancedPathFilter pathFilter = new BalancedPathFilter(randNumGen, allowedExtensions, labelMaker);

    // Split the image files into train (80%) and test (20%) with ONE sample call so the two sets
    // are disjoint. Sampling (100, 0) and (0, 100) separately put every file in both sets, which
    // meant the model was evaluated on its own training data.
    InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter, 80, 20);
    InputSplit trainData = filesInDirSplit[0];
    InputSplit testData = filesInDirSplit[1];

    System.out.println("train = " + trainData.length());
    System.out.println("test = " + testData.length());

    // One record reader per iterator; each resizes images to height x width on read.
    // Re-initialising a single shared reader with testData would also redirect the training
    // iterator (which wraps the same reader) to the test images.
    ImageRecordReader trainRecordReader = new ImageRecordReader(height, width, nChannels, labelMaker);
    trainRecordReader.initialize(trainData);
    DataSetIterator dataIterTrain = new RecordReaderDataSetIterator(trainRecordReader, batchSize, 1, outputNum);

    ImageRecordReader testRecordReader = new ImageRecordReader(height, width, nChannels, labelMaker);
    testRecordReader.initialize(testData);
    DataSetIterator dataIterTest = new RecordReaderDataSetIterator(testRecordReader, batchSize, 1, outputNum);

    DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
    dataIterTrain.setPreProcessor(scaler);
    dataIterTest.setPreProcessor(scaler);

    System.out.println("Build model....");
    MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed)
            .iterations(iterations).regularization(true).l2(0.0005)
            .learningRate(0.001)
            .weightInit(WeightInit.XAVIER).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.NESTEROVS).momentum(0.9).list()
            .layer(0,
                    new ConvolutionLayer.Builder(5, 5).nIn(nChannels).stride(1, 1).nOut(20)
                            .activation("identity").build())
            .layer(1,
                    new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2)
                            .build())
            .layer(2, new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50).activation("identity").build())
            .layer(3,
                    new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                            .kernelSize(2, 2).stride(2, 2).build())
            .layer(4, new DenseLayer.Builder().activation("relu").nOut(500).build())
            .layer(5,
                    new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum)
                            .activation("softmax").build())
            // The convolutional input type supplies the nIn values omitted on layers 2, 4 and 5.
            .setInputType(InputType.convolutional(height, width, nChannels))
            .backprop(true).pretrain(false);

    MultiLayerConfiguration conf = builder.build();
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();

    System.out.println("Train model....");
    model.setListeners(new ScoreIterationListener(1));

    // Replays the training iterator nEpochs times.
    MultipleEpochsIterator trainIter = new MultipleEpochsIterator(nEpochs, dataIterTrain, 2);
    model.fit(trainIter);

    System.out.println("Evaluate model....");
    Evaluation eval1 = model.evaluate(dataIterTest);
    System.out.println(eval1.stats());

    System.out.println("****************Example finished********************");
}

From source file:com.circle_technologies.cnn4j.predictive.provider.network.SimpleHiddenLayerNetworkProvider.java

License:Apache License

/**
 * Builds an (uninitialised) feed-forward network: a ReLU input layer, {@code mHiddenLayerCount}
 * hidden dense layers using the provider's activation function, and an output layer. Every
 * dense layer maps {@code inputs -> inputs}.
 *
 * @param inputs  width of the input and of every hidden layer
 * @param outputs number of output units
 * @return a network for the configuration; {@code init()} has not been called
 */
@Override
public MultiLayerNetwork provideNetwork(int inputs, int outputs) {
    // Shared defaults: Xavier weights, SGD with the provider's learning rate, and the provider's
    // activation function for layers that do not override it.
    NeuralNetConfiguration.ListBuilder layers = new NeuralNetConfiguration.Builder()
            .iterations(1)
            .weightInit(WeightInit.XAVIER)
            .activation(getActivationFunction())
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(getLearningRate())
            .list();

    // Layer 0 always uses ReLU, independent of getActivationFunction().
    layers = layers.layer(0, new DenseLayer.Builder()
            .nIn(inputs)
            .nOut(inputs)
            .activation("relu")
            .build());

    // Hidden layers occupy indices 1..mHiddenLayerCount.
    for (int hidden = 1; hidden <= mHiddenLayerCount; hidden++) {
        layers = layers.layer(hidden, new DenseLayer.Builder()
                .nIn(inputs)
                .nOut(inputs)
                .activation(getActivationFunction())
                .build());
    }

    int outputIndex = mHiddenLayerCount + 1;
    OutputLayer outputLayer = new OutputLayer.Builder().nIn(inputs).nOut(outputs).build();
    MultiLayerConfiguration configuration = layers
            .layer(outputIndex, outputLayer)
            .backprop(true)
            .pretrain(false)
            .build();

    return new MultiLayerNetwork(configuration);
}

From source file:com.circle_technologies.cnn4j.predictive.provider.network.SimpleSingleLayerNetworkProvider.java

License:Apache License

/**
 * Builds an (uninitialised) network with one dense hidden layer ({@code inputs -> inputs}, using
 * the provider's activation function) followed by an output layer.
 *
 * @param inputs  width of the input and of the hidden layer
 * @param outputs number of output units
 * @return a network for the configuration; {@code init()} has not been called
 */
@Override
public MultiLayerNetwork provideNetwork(int inputs, int outputs) {
    // Shared defaults applied to every layer unless overridden.
    NeuralNetConfiguration.Builder defaults = new NeuralNetConfiguration.Builder()
            .iterations(1)
            .weightInit(WeightInit.XAVIER)
            .activation(getActivationFunction())
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(getLearningRate());

    DenseLayer hiddenLayer = new DenseLayer.Builder()
            .nIn(inputs)
            .nOut(inputs)
            .activation(getActivationFunction())
            .build();
    OutputLayer outputLayer = new OutputLayer.Builder().nIn(inputs).nOut(outputs).build();

    MultiLayerConfiguration configuration = defaults.list()
            .layer(0, hiddenLayer)
            .layer(1, outputLayer)
            .backprop(true)
            .pretrain(false)
            .build();

    return new MultiLayerNetwork(configuration);
}

From source file:com.example.android.displayingbitmaps.ui.ImageGridActivity.java

License:Apache License

/**
 * Trains a two-layer MLP (784 -> 1000 ReLU -> 10 softmax) on MNIST and logs its evaluation.
 *
 * <p>Each mini-batch of {@code batchSize} examples is split with {@code splitTestAndTrain}:
 * {@code splitTrainNum} examples (80% of the batch) are used for fitting and the remainder is
 * kept for the final evaluation.
 *
 * @throws Exception if the MNIST data cannot be loaded
 */
public void trainMLP() throws Exception {
    Nd4j.ENFORCE_NUMERICAL_STABILITY = true;
    final int numRows = 28;
    final int numColumns = 28;
    int outputNum = 10;
    int numSamples = 10000;
    int batchSize = 500;
    int iterations = 10;
    int seed = 123;
    // Keep the listener frequency positive even if iterations drops below 5.
    int listenerFreq = Math.max(1, iterations / 5);
    int splitTrainNum = (int) (batchSize * .8);
    List<INDArray> testInput = new ArrayList<>();
    List<INDArray> testLabels = new ArrayList<>();

    log.info("Load data....");
    DataSetIterator mnistIter = new MnistDataSetIterator(batchSize, numSamples, true);

    log.info("Build model....");
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(iterations)
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).learningRate(1e-1f)
            .momentum(0.5).momentumAfter(Collections.singletonMap(3, 0.9)).useDropConnect(true).list(2)
            .layer(0,
                    new DenseLayer.Builder().nIn(numRows * numColumns).nOut(1000).activation("relu")
                            .weightInit(WeightInit.XAVIER).build())
            .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).nIn(1000).nOut(outputNum)
                    .activation("softmax").weightInit(WeightInit.XAVIER).build())
            .build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    // Registered once; it was previously set twice with an identical listener.
    model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(listenerFreq)));

    log.info("Train model....");
    while (mnistIter.hasNext()) {
        DataSet mnist = mnistIter.next();
        SplitTestAndTrain trainTest = mnist.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet trainInput = trainTest.getTrain(); // feature matrix and labels for training
        // Hold out the rest of the batch for evaluation.
        testInput.add(trainTest.getTest().getFeatureMatrix());
        testLabels.add(trainTest.getTest().getLabels());
        model.fit(trainInput);
    }

    log.info("Evaluate model....");
    Evaluation eval = new Evaluation(outputNum);
    for (int i = 0; i < testInput.size(); i++) {
        INDArray output = model.output(testInput.get(i));
        eval.eval(testLabels.get(i), output);
    }

    log.info(eval.stats());
    log.info("****************Example finished********************");
}

From source file:com.heatonresearch.aifh.examples.ann.LearnAutoMPGBackprop.java

License:Apache License

/**
 * The main method. Trains a one-hidden-layer regression network on the Auto MPG data set with
 * early stopping, then prints the best model's prediction for every validation example.
 *
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 0.01;

        // Setup training data. try-with-resources closes the stream even if loading fails.
        final NormalizeDataSet ds;
        try (InputStream istream = LearnAutoMPGBackprop.class.getResourceAsStream("/auto-mpg.data.csv")) {
            if (istream == null) {
                System.out.println("Cannot access data set, make sure the resources are available.");
                System.exit(1);
            }
            ds = NormalizeDataSet.load(istream);
        }

        // The following ranges are set up for the Auto MPG data set. To normalize other files,
        // the calls below need to be adapted.

        // First remove the columns that are not used:
        ds.deleteColumn(8); // Car name
        ds.deleteColumn(7); // Car origin
        ds.deleteColumn(6); // Year
        ds.deleteUnknowns();

        ds.normalizeZScore(1);
        ds.normalizeZScore(2);
        ds.normalizeZScore(3);
        ds.normalizeZScore(4);
        ds.normalizeZScore(5);

        DataSet next = ds.extractSupervised(1, 4, 0, 1);
        next.shuffle();

        // Training and validation data split (75% / 25%).
        int splitTrainNum = (int) (next.numExamples() * .75);
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet trainSet = testAndTrain.getTrain();
        DataSet validationSet = testAndTrain.getTest();

        // Full-batch iterators: each batch holds the whole set.
        DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(), trainSet.numExamples());
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationSet.numExamples());

        // Create neural network.
        int numInputs = next.numInputs();
        int numOutputs = next.numOutcomes();
        int numHiddenNodes = 50;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9).list(2)
                .layer(0,
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(1,
                        new OutputLayer.Builder(LossFunction.MSE).weightInit(WeightInit.XAVIER)
                                .activation("identity").nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        // Define when to stop training: after at most 500 epochs, or when the validation score
        // has not improved for 25 epochs. Both conditions go into ONE call, because a second
        // epochTerminationConditions(...) call replaces the list set by the first.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(500),
                        new ScoreImprovementEpochTerminationCondition(25))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) // validation-set score
                .modelSaver(saver).build();
        // The trainer is given conf, not a network instance, so no separate model is built here.
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        MultiLayerNetwork model = saver.getBestModel();

        // Evaluate: print every validation example with its prediction.
        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            System.out.println(features + ":Prediction(" + predicted + "):Actual(" + labels + ")");
        }

    } catch (Exception ex) {
        ex.printStackTrace();
    }
}

From source file:com.heatonresearch.aifh.examples.ann.LearnDigitsBackprop.java

License:Apache License

/**
 * The main method. Trains a one-hidden-layer MLP with dropout on MNIST using early stopping,
 * then prints evaluation statistics of the best model on the validation set.
 *
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        // NOTE(review): getNumRows() looks like the image height, not an example count, so this
        // batch size is presumably unintended; it only affects how validation is batched.
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        // Feature width of the first example in each set.
        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols() * trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 200;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9).regularization(true).dropOut(0.50).list(2)
                .layer(0,
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(1,
                        new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                                .weightInit(WeightInit.XAVIER).activation("softmax").nIn(numHiddenNodes)
                                .nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        // Define when we want to stop training: when the validation score has not improved
        // for 5 epochs.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) // validation-set score
                .modelSaver(saver).build();
        // The trainer is given conf, not a network instance, so no separate model is built here.
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        MultiLayerNetwork model = saver.getBestModel();

        // Evaluate batch by batch over the validation iterator instead of one forward pass
        // per example.
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();
        while (validationSetIterator.hasNext()) {
            DataSet batch = validationSetIterator.next();
            INDArray predicted = model.output(batch.getFeatureMatrix(), false);
            eval.eval(batch.getLabels(), predicted);
        }

        // Print the evaluation statistics.
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }

}

From source file:com.heatonresearch.aifh.examples.ann.LearnDigitsConv.java

License:Apache License

/**
 * The main method. Trains a small convolutional network on MNIST using early stopping, then
 * prints evaluation statistics of the best model on the validation set.
 *
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int batchSize = 500;
        int channels = 1;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        // NOTE(review): getNumRows() looks like the image height, not an example count, so this
        // batch size is presumably unintended; it only affects how validation is batched.
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        int numOutputs = 10;

        // Create neural network. The learning rate and output width use the variables declared
        // above instead of repeating their literal values.
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .regularization(true).l2(0.0005).learningRate(learningRate).weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.NESTEROVS)
                .momentum(0.9).list(4)
                .layer(0,
                        new ConvolutionLayer.Builder(5, 5).nIn(channels).stride(1, 1).nOut(20).dropOut(0.5)
                                .activation("relu").build())
                .layer(1,
                        new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
                                .stride(2, 2).build())
                .layer(2, new DenseLayer.Builder().activation("relu").nOut(500).build())
                .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(numOutputs).activation("softmax").build())
                .backprop(true).pretrain(false);

        // Configure the convolutional input shape from the reader's image size and channel count.
        new ConvolutionLayerSetup(builder, trainingReader.getNumRows(), trainingReader.getNumCols(), channels);
        MultiLayerConfiguration conf = builder.build();

        // Define when we want to stop training: when the validation score has not improved
        // for 5 epochs.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) // validation-set score
                .modelSaver(saver).build();
        // The trainer is given conf, not a network instance, so no separate model is built here.
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        MultiLayerNetwork model = saver.getBestModel();

        // Evaluate batch by batch over the validation iterator instead of one forward pass
        // per example.
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();
        while (validationSetIterator.hasNext()) {
            DataSet batch = validationSetIterator.next();
            INDArray predicted = model.output(batch.getFeatureMatrix(), false);
            eval.eval(batch.getLabels(), predicted);
        }

        // Print the evaluation statistics.
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }

}

From source file:com.heatonresearch.aifh.examples.ann.LearnDigitsDropout.java

License:Apache License

/**
 * Trains a single-hidden-layer MNIST classifier with early stopping and
 * prints evaluation statistics for the best model found.
 *
 * <p>The MNIST files are loaded from the current working directory
 * ({@code user.dir}). Training stops after {@code nEpochs} epochs, or earlier
 * if the validation loss has not improved for 5 consecutive epochs.
 *
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        // Batch the validation set by an example count, like the training set.
        // (getNumRows() is the image height in pixels, not a batch size.)
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), batchSize);

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println("Training input width: " + trainingSet.get(0).getFeatures().size(1));
        System.out.println("Validation input width: " + validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols() * trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 100;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9).list(2)
                .layer(0,
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(1,
                        new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                                .weightInit(WeightInit.XAVIER).activation("softmax").nIn(numHiddenNodes)
                                .nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training: at most nEpochs epochs, or earlier if the
        // validation loss has not improved for 5 epochs. Both conditions are passed in a
        // single call so neither replaces the other.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(
                        new org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition(nEpochs),
                        new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1).scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
                .modelSaver(saver).build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate the best model on the validation set one mini-batch at a time,
        // rather than with one forward pass per example.
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();
        while (validationSetIterator.hasNext()) {
            DataSet batch = validationSetIterator.next();
            INDArray predicted = model.output(batch.getFeatureMatrix(), false);
            eval.eval(batch.getLabels(), predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }

}

From source file: com.heatonresearch.aifh.examples.ann.LearnIrisBackprop.java

License: Apache License

/**
 * Trains a one-hidden-layer network on the Iris data set with early stopping,
 * prints each validation example's predicted and actual species, and then
 * prints overall evaluation statistics for the best model found.
 *
 * <p>The data is read from the classpath resource {@code /iris.csv}. 75% of the
 * examples are used for training and the remainder for validation.
 *
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 0.1;
        double trainFraction = 0.75;

        int numInputs = 4;
        int numOutputs = 3;
        int numHiddenNodes = 50;

        // Setup training data. try-with-resources closes the stream even if loading fails.
        final NormalizeDataSet ds;
        final CategoryMap species;
        try (InputStream istream = LearnIrisBackprop.class.getResourceAsStream("/iris.csv")) {
            if (istream == null) {
                System.out.println("Cannot access data set, make sure the resources are available.");
                System.exit(1);
            }
            ds = NormalizeDataSet.load(istream);
            species = ds.encodeOneOfN(4); // species is column 4
        }

        DataSet next = ds.extractSupervised(0, 4, 4, 3);
        next.shuffle();

        // Training and validation data split, sized from the actual number of examples.
        int splitTrainNum = (int) (next.numExamples() * trainFraction);
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet trainSet = testAndTrain.getTrain();
        DataSet validationSet = testAndTrain.getTest();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(), trainSet.numExamples());

        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationSet.numExamples());

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9).list(2)
                .layer(0,
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(1,
                        new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                                .weightInit(WeightInit.XAVIER).activation("softmax").nIn(numHiddenNodes)
                                .nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training. Both conditions are passed in ONE call:
        // a second epochTerminationConditions(...) call replaces the first list, which
        // silently dropped the epoch cap.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(
                        new MaxEpochsTerminationCondition(500), // Max of 500 epochs
                        new ScoreImprovementEpochTerminationCondition(25)) // or 25 epochs without improvement
                .evaluateEveryNEpochs(1).scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
                .modelSaver(saver).build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate each validation example, printing the network's prediction next to the true label.
        Evaluation eval = new Evaluation(numOutputs);

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            // "Prediction" comes from the network output; "Actual" comes from the true labels.
            System.out.println(features + ":Prediction(" + findSpecies(predicted, species) + "):Actual("
                    + findSpecies(labels, species) + ")" + predicted);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }
}