List of usage examples for org.deeplearning4j.datasets.iterator.impl MnistDataSetIterator MnistDataSetIterator
public MnistDataSetIterator(int batch, int numExamples, boolean binarize) throws IOException
From source file:com.example.android.displayingbitmaps.ui.ImageGridActivity.java
License:Apache License
public void trainMLP() throws Exception { Nd4j.ENFORCE_NUMERICAL_STABILITY = true; final int numRows = 28; final int numColumns = 28; int outputNum = 10; int numSamples = 10000; int batchSize = 500; int iterations = 10; int seed = 123; int listenerFreq = iterations / 5; int splitTrainNum = (int) (batchSize * .8); DataSet mnist;/* w w w . j a v a 2s . c o m*/ SplitTestAndTrain trainTest; DataSet trainInput; List<INDArray> testInput = new ArrayList<>(); List<INDArray> testLabels = new ArrayList<>(); log.info("Load data...."); DataSetIterator mnistIter = new MnistDataSetIterator(batchSize, numSamples, true); log.info("Build model...."); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(iterations) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).learningRate(1e-1f) .momentum(0.5).momentumAfter(Collections.singletonMap(3, 0.9)).useDropConnect(true).list(2) .layer(0, new DenseLayer.Builder().nIn(numRows * numColumns).nOut(1000).activation("relu") .weightInit(WeightInit.XAVIER).build()) .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).nIn(1000).nOut(outputNum) .activation("softmax").weightInit(WeightInit.XAVIER).build()) .build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(listenerFreq))); log.info("Train model...."); model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(listenerFreq))); while (mnistIter.hasNext()) { mnist = mnistIter.next(); trainTest = mnist.splitTestAndTrain(splitTrainNum, new Random(seed)); // train set that is the result trainInput = trainTest.getTrain(); // get feature matrix and labels for training testInput.add(trainTest.getTest().getFeatureMatrix()); testLabels.add(trainTest.getTest().getLabels()); model.fit(trainInput); } log.info("Evaluate model...."); Evaluation eval 
= new Evaluation(outputNum); for (int i = 0; i < testInput.size(); i++) { INDArray output = model.output(testInput.get(i)); eval.eval(testLabels.get(i), output); } log.info(eval.stats()); log.info("****************Example finished********************"); }