Example usage for the org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator constructor

List of usage examples for the org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator constructor

Introduction

On this page you can find example usage for the org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator constructor.

Prototype

public MnistDataSetIterator(int batchSize, boolean train, int seed) throws IOException 

Source Link

Document

Constructor to get the full MNIST data set (either test or train sets) without binarization (i.e., just normalization into range of 0 to 1), with shuffling based on a random seed.

Usage

From source file: com.example.android.displayingbitmaps.ui.ImageGridActivity.java

License:Apache License

/**
 * Trains a two-layer perceptron on MNIST batches and logs its evaluation statistics.
 *
 * <p>Each batch delivered by the iterator is split into a training part, which is fitted
 * immediately, and a held-out part, which is collected and scored once all batches have
 * been consumed.
 *
 * @throws Exception propagated from the data iterator or the network calls made here
 */
public void trainMLP() throws Exception {
    Nd4j.ENFORCE_NUMERICAL_STABILITY = true;

    // Input image geometry and number of target classes.
    final int numRows = 28;
    final int numColumns = 28;
    final int outputNum = 10;

    // Data-loading and optimisation settings.
    final int numSamples = 10000;
    final int batchSize = 500;
    final int iterations = 10;
    final int seed = 123;
    final int listenerFreq = iterations / 5;
    // Number of examples per batch requested for the training partition (80% of the batch).
    final int splitTrainNum = (int) (batchSize * .8);

    // Held-out features and labels, one entry per processed batch.
    List<INDArray> testInput = new ArrayList<>();
    List<INDArray> testLabels = new ArrayList<>();

    log.info("Load data....");
    // NOTE(review): this is the (int, int, boolean) overload; the boolean is presumably the
    // "binarize" flag rather than a train/test switch - confirm against the DL4J version in use.
    DataSetIterator mnistIter = new MnistDataSetIterator(batchSize, numSamples, true);

    log.info("Build model....");
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(iterations)
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).learningRate(1e-1f)
            .momentum(0.5).momentumAfter(Collections.singletonMap(3, 0.9)).useDropConnect(true).list(2)
            // Hidden layer: flattened image -> 1000 ReLU units.
            .layer(0,
                    new DenseLayer.Builder().nIn(numRows * numColumns).nOut(1000).activation("relu")
                            .weightInit(WeightInit.XAVIER).build())
            // Output layer: 1000 -> outputNum softmax units with negative log-likelihood loss.
            .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).nIn(1000).nOut(outputNum)
                    .activation("softmax").weightInit(WeightInit.XAVIER).build())
            .build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    // Register the score listener once; a second identical registration would be redundant.
    model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(listenerFreq)));

    log.info("Train model....");
    while (mnistIter.hasNext()) {
        DataSet batch = mnistIter.next();
        // A new Random with the same seed is created for every batch, as in the original code.
        SplitTestAndTrain split = batch.splitTestAndTrain(splitTrainNum, new Random(seed));

        // Keep the held-out part of this batch for the evaluation pass below.
        DataSet heldOut = split.getTest();
        testInput.add(heldOut.getFeatureMatrix());
        testLabels.add(heldOut.getLabels());

        model.fit(split.getTrain());
    }

    log.info("Evaluate model....");
    Evaluation eval = new Evaluation(outputNum);
    for (int i = 0; i < testInput.size(); i++) {
        INDArray output = model.output(testInput.get(i));
        eval.eval(testLabels.get(i), output);
    }

    log.info(eval.stats());
    log.info("****************Example finished********************");
}