List of usage examples for org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator#reset()
@Override public void reset()
From source file: org.audiveris.omrdataset.train.Training.java
License: Open Source License
/** * Perform the training of the neural network. * <p>/*from ww w . j av a2 s . c o m*/ * Before training is launched, if the network model exists on disk it is reloaded, otherwise a * brand new one is created. * * @throws Exception in case of IO problem or interruption */ public void process() throws Exception { Files.createDirectories(MISTAKES_PATH); int nChannels = 1; // Number of input channels int batchSize = 64; // Batch size int nEpochs = 1; //3; //10; //2; // Number of training epochs int iterations = 1; // 2; //10; // Number of training iterations int seed = 123; // // Pixel norms NormalizerStandardize normalizer = NormalizerSerializer.getDefault().restore(PIXELS_PATH.toFile()); // Get the dataset using the record reader. CSVRecordReader handles loading/parsing int labelIndex = CONTEXT_WIDTH * CONTEXT_HEIGHT; // format: all cells then label int numLinesToSkip = 1; // Because of header comment line String delimiter = ","; RecordReader trainRecordReader = new CSVRecordReader(numLinesToSkip, delimiter); trainRecordReader.initialize(new FileSplit(FEATURES_PATH.toFile())); logger.info("Getting dataset from {} ...", FEATURES_PATH); RecordReaderDataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, batchSize, labelIndex, numClasses, -1); trainIter.setCollectMetaData(true); //Instruct the iterator to collect metadata, and store it in the DataSet objects RecordReader testRecordReader = new CSVRecordReader(numLinesToSkip, delimiter); testRecordReader.initialize(new FileSplit(FEATURES_PATH.toFile())); RecordReaderDataSetIterator testIter = new RecordReaderDataSetIterator(testRecordReader, batchSize, labelIndex, numClasses, -1); testIter.setCollectMetaData(true); //Instruct the iterator to collect metadata, and store it in the DataSet objects // Normalization DataSetPreProcessor preProcessor = new MyPreProcessor(normalizer); trainIter.setPreProcessor(preProcessor); testIter.setPreProcessor(preProcessor); if (false) { System.out.println("\n 
+++++ Test Set Examples MetaData +++++"); while (testIter.hasNext()) { DataSet ds = testIter.next(); List<RecordMetaData> testMetaData = ds.getExampleMetaData(RecordMetaData.class); for (RecordMetaData recordMetaData : testMetaData) { System.out.println(recordMetaData.getLocation()); } } testIter.reset(); } final MultiLayerNetwork model; if (Files.exists(MODEL_PATH)) { model = ModelSerializer.restoreMultiLayerNetwork(MODEL_PATH.toFile(), false); logger.info("Model restored from {}", MODEL_PATH.toAbsolutePath()); } else { logger.info("Building model from scratch"); MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() // .seed(seed) // .iterations(iterations) // .regularization(true) // .l2(0.0005) // .learningRate(.002) // HB: was .01 initially //.biasLearningRate(0.02) //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75) .weightInit(WeightInit.XAVIER) // .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) // .updater(Updater.NESTEROVS).momentum(0.9) // .list() // .layer(0, new ConvolutionLayer.Builder(5, 5) // .name("C0") // .nIn(nChannels) // .stride(1, 1) // .nOut(20) // .activation(Activation.IDENTITY) // .build()) // .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) // .name("S1") // .kernelSize(2, 2) // .stride(2, 2) // .build()) // .layer(2, new ConvolutionLayer.Builder(5, 5) // .name("C2") // .stride(1, 1) // .nOut(50) // .activation(Activation.IDENTITY) // .build()) // .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) // .name("S3") // .kernelSize(2, 2) // .stride(2, 2) // .build()) // .layer(4, new DenseLayer.Builder() // .name("D4") // .nOut(500) // .activation(Activation.RELU) // .build()) // .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) // .name("O5") // .nOut(numClasses) // .activation(Activation.SOFTMAX) // .build()) // .setInputType(InputType.convolutionalFlat(CONTEXT_HEIGHT, 
CONTEXT_WIDTH, 1)); MultiLayerConfiguration conf = builder.build(); model = new MultiLayerNetwork(conf); model.init(); } // Prepare monitoring UIServer uiServer = null; try { if (true) { //Initialize the user interface backend uiServer = UIServer.getInstance(); //Configure where the network information (gradients, score vs. time etc) is to be stored. Here: store in memory. StatsStorage statsStorage = new InMemoryStatsStorage(); //Alternative: new FileStatsStorage(File), for saving and loading later //Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized uiServer.attach(statsStorage); //Then add the StatsListener to collect this information from the network, as it trains model.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(10)); } else { model.setListeners(new ScoreIterationListener(10)); } logger.info("Training model..."); for (int epoch = 1; epoch <= nEpochs; epoch++) { Path epochFolder = Main.cli.mistakes ? MISTAKES_PATH.resolve("epoch#" + epoch) : null; long start = System.currentTimeMillis(); model.fit(trainIter); long stop = System.currentTimeMillis(); double dur = stop - start; logger.info(String.format("*** End epoch#%d, time: %.0f sec", epoch, dur / 1000)); // Save model ModelSerializer.writeModel(model, MODEL_PATH.toFile(), false); ModelSerializer.addNormalizerToModel(MODEL_PATH.toFile(), normalizer); logger.info("Model+normalizer stored as {}", MODEL_PATH.toAbsolutePath()); // // logger.info("Evaluating model..."); // // Evaluation eval = new Evaluation(OmrShapes.NAMES); // // while (testIter.hasNext()) { // DataSet ds = testIter.next(); // List<RecordMetaData> testMetaData = ds.getExampleMetaData(RecordMetaData.class); // INDArray output = model.output(ds.getFeatureMatrix(), false); // eval.eval(ds.getLabels(), output, testMetaData); // } // // System.out.println(eval.stats()); // testIter.reset(); // // //Get a list of prediction errors, from the Evaluation object // 
//Prediction errors like this are only available after calling iterator.setCollectMetaData(true) // List<Prediction> mistakes = eval.getPredictionErrors(); // logger.info("Epoch#{} Prediction Errors: {}", epoch, mistakes.size()); // // //We can also load a subset of the data, to a DataSet object: // //Here we load the raw data: // List<RecordMetaData> predictionErrorMetaData = new ArrayList<RecordMetaData>(); // // for (Prediction p : mistakes) { // predictionErrorMetaData.add(p.getRecordMetaData(RecordMetaData.class)); // } // // List<Record> predictionErrorRawData = testRecordReader.loadFromMetaData( // predictionErrorMetaData); // // for (int ie = 0; ie < mistakes.size(); ie++) { // Prediction p = mistakes.get(ie); // List<Writable> rawData = predictionErrorRawData.get(ie).getRecord(); // saveMistake(p, rawData, epochFolder); // } // // // // To avoid long useless sessions... // if (mistakes.isEmpty()) { // logger.info("No mistakes left, training stopped."); // // break; // } } } finally { // Stop monitoring if (uiServer != null) { uiServer.stop(); } } logger.info("****************Example finished********************"); }