List of usage examples for the org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder constructor (SubsamplingLayer.Builder)
public Builder(org.deeplearning4j.nn.conf.layers.PoolingType poolingType)
From source file: com.heatonresearch.aifh.examples.ann.LearnDigitsConv.java
License: Apache License
/** * The main method./*from ww w.ja v a 2s . c o m*/ * @param args Not used. */ public static void main(String[] args) { try { int seed = 43; double learningRate = 1e-2; int nEpochs = 50; int batchSize = 500; int channels = 1; // Setup training data. System.out.println("Please wait, reading MNIST training data."); String dir = System.getProperty("user.dir"); MNISTReader trainingReader = MNIST.loadMNIST(dir, true); MNISTReader validationReader = MNIST.loadMNIST(dir, false); DataSet trainingSet = trainingReader.getData(); DataSet validationSet = validationReader.getData(); DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize); DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationReader.getNumRows()); System.out.println("Training set size: " + trainingReader.getNumImages()); System.out.println("Validation set size: " + validationReader.getNumImages()); int numOutputs = 10; // Create neural network. MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).iterations(1) .regularization(true).l2(0.0005).learningRate(0.01).weightInit(WeightInit.XAVIER) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.NESTEROVS) .momentum(0.9).list(4) .layer(0, new ConvolutionLayer.Builder(5, 5).nIn(channels).stride(1, 1).nOut(20).dropOut(0.5) .activation("relu").build()) .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2) .stride(2, 2).build()) .layer(2, new DenseLayer.Builder().activation("relu").nOut(500).build()) .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10) .activation("softmax").build()) .backprop(true).pretrain(false); new ConvolutionLayerSetup(builder, 28, 28, 1); MultiLayerConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); model.setListeners(new ScoreIterationListener(1)); // Define when we 
want to stop training. EarlyStoppingModelSaver saver = new InMemoryModelSaver(); EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder() //.epochTerminationConditions(new MaxEpochsTerminationCondition(10)) .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5)) .evaluateEveryNEpochs(1).scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score .modelSaver(saver).build(); EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator); // Train and display result. EarlyStoppingResult result = trainer.fit(); System.out.println("Termination reason: " + result.getTerminationReason()); System.out.println("Termination details: " + result.getTerminationDetails()); System.out.println("Total epochs: " + result.getTotalEpochs()); System.out.println("Best epoch number: " + result.getBestModelEpoch()); System.out.println("Score at best epoch: " + result.getBestModelScore()); model = saver.getBestModel(); // Evaluate Evaluation eval = new Evaluation(numOutputs); validationSetIterator.reset(); for (int i = 0; i < validationSet.numExamples(); i++) { DataSet t = validationSet.get(i); INDArray features = t.getFeatureMatrix(); INDArray labels = t.getLabels(); INDArray predicted = model.output(features, false); eval.eval(labels, predicted); } //Print the evaluation statistics System.out.println(eval.stats()); } catch (Exception ex) { ex.printStackTrace(); } }