Example usage for org.deeplearning4j.datasets.datavec RecordReaderDataSetIterator hasNext

List of usage examples for org.deeplearning4j.datasets.datavec RecordReaderDataSetIterator hasNext

Introduction

On this page you can find example usage of org.deeplearning4j.datasets.datavec RecordReaderDataSetIterator hasNext.

Prototype

@Override
    public boolean hasNext() 

Source Link

Usage

From source file: org.audiveris.omrdataset.train.Training.java

License:Open Source License

/**
 * Perform the training of the neural network.
 * <p>/*  w w  w.  j a v a2  s  .  c  o  m*/
 * Before training is launched, if the network model exists on disk it is reloaded, otherwise a
 * brand new one is created.
 *
 * @throws Exception in case of IO problem or interruption
 */
public void process() throws Exception {
    Files.createDirectories(MISTAKES_PATH);

    int nChannels = 1; // Number of input channels
    int batchSize = 64; // Batch size
    int nEpochs = 1; //3; //10; //2; // Number of training epochs
    int iterations = 1; // 2; //10; // Number of training iterations
    int seed = 123; //

    // Pixel norms
    NormalizerStandardize normalizer = NormalizerSerializer.getDefault().restore(PIXELS_PATH.toFile());

    // Get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int labelIndex = CONTEXT_WIDTH * CONTEXT_HEIGHT; // format: all cells then label
    int numLinesToSkip = 1; // Because of header comment line
    String delimiter = ",";

    RecordReader trainRecordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    trainRecordReader.initialize(new FileSplit(FEATURES_PATH.toFile()));
    logger.info("Getting dataset from {} ...", FEATURES_PATH);

    RecordReaderDataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, batchSize,
            labelIndex, numClasses, -1);
    trainIter.setCollectMetaData(true); //Instruct the iterator to collect metadata, and store it in the DataSet objects

    RecordReader testRecordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    testRecordReader.initialize(new FileSplit(FEATURES_PATH.toFile()));

    RecordReaderDataSetIterator testIter = new RecordReaderDataSetIterator(testRecordReader, batchSize,
            labelIndex, numClasses, -1);
    testIter.setCollectMetaData(true); //Instruct the iterator to collect metadata, and store it in the DataSet objects

    // Normalization
    DataSetPreProcessor preProcessor = new MyPreProcessor(normalizer);
    trainIter.setPreProcessor(preProcessor);
    testIter.setPreProcessor(preProcessor);

    if (false) {
        System.out.println("\n  +++++ Test Set Examples MetaData +++++");

        while (testIter.hasNext()) {
            DataSet ds = testIter.next();
            List<RecordMetaData> testMetaData = ds.getExampleMetaData(RecordMetaData.class);

            for (RecordMetaData recordMetaData : testMetaData) {
                System.out.println(recordMetaData.getLocation());
            }
        }

        testIter.reset();
    }

    final MultiLayerNetwork model;

    if (Files.exists(MODEL_PATH)) {
        model = ModelSerializer.restoreMultiLayerNetwork(MODEL_PATH.toFile(), false);
        logger.info("Model restored from {}", MODEL_PATH.toAbsolutePath());
    } else {
        logger.info("Building model from scratch");

        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() //
                .seed(seed) //
                .iterations(iterations) //
                .regularization(true) //
                .l2(0.0005) //
                .learningRate(.002) // HB: was .01 initially
                //.biasLearningRate(0.02)
                //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
                .weightInit(WeightInit.XAVIER) //
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) //
                .updater(Updater.NESTEROVS).momentum(0.9) //
                .list() //
                .layer(0, new ConvolutionLayer.Builder(5, 5) //
                        .name("C0") //
                        .nIn(nChannels) //
                        .stride(1, 1) //
                        .nOut(20) //
                        .activation(Activation.IDENTITY) //
                        .build()) //
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) //
                        .name("S1") //
                        .kernelSize(2, 2) //
                        .stride(2, 2) //
                        .build()) //
                .layer(2, new ConvolutionLayer.Builder(5, 5) //
                        .name("C2") //
                        .stride(1, 1) //
                        .nOut(50) //
                        .activation(Activation.IDENTITY) //
                        .build()) //
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) //
                        .name("S3") //
                        .kernelSize(2, 2) //
                        .stride(2, 2) //
                        .build()) //
                .layer(4, new DenseLayer.Builder() //
                        .name("D4") //
                        .nOut(500) //
                        .activation(Activation.RELU) //
                        .build()) //
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) //
                        .name("O5") //
                        .nOut(numClasses) //
                        .activation(Activation.SOFTMAX) //
                        .build()) //
                .setInputType(InputType.convolutionalFlat(CONTEXT_HEIGHT, CONTEXT_WIDTH, 1));

        MultiLayerConfiguration conf = builder.build();
        model = new MultiLayerNetwork(conf);
        model.init();
    }

    // Prepare monitoring
    UIServer uiServer = null;

    try {
        if (true) {
            //Initialize the user interface backend
            uiServer = UIServer.getInstance();

            //Configure where the network information (gradients, score vs. time etc) is to be stored. Here: store in memory.
            StatsStorage statsStorage = new InMemoryStatsStorage(); //Alternative: new FileStatsStorage(File), for saving and loading later

            //Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized
            uiServer.attach(statsStorage);

            //Then add the StatsListener to collect this information from the network, as it trains
            model.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(10));
        } else {
            model.setListeners(new ScoreIterationListener(10));
        }

        logger.info("Training model...");

        for (int epoch = 1; epoch <= nEpochs; epoch++) {
            Path epochFolder = Main.cli.mistakes ? MISTAKES_PATH.resolve("epoch#" + epoch) : null;
            long start = System.currentTimeMillis();
            model.fit(trainIter);

            long stop = System.currentTimeMillis();
            double dur = stop - start;
            logger.info(String.format("*** End epoch#%d, time: %.0f sec", epoch, dur / 1000));

            // Save model
            ModelSerializer.writeModel(model, MODEL_PATH.toFile(), false);
            ModelSerializer.addNormalizerToModel(MODEL_PATH.toFile(), normalizer);
            logger.info("Model+normalizer stored as {}", MODEL_PATH.toAbsolutePath());
            //
            //                logger.info("Evaluating model...");
            //
            //                Evaluation eval = new Evaluation(OmrShapes.NAMES);
            //
            //                while (testIter.hasNext()) {
            //                    DataSet ds = testIter.next();
            //                    List<RecordMetaData> testMetaData = ds.getExampleMetaData(RecordMetaData.class);
            //                    INDArray output = model.output(ds.getFeatureMatrix(), false);
            //                    eval.eval(ds.getLabels(), output, testMetaData);
            //                }
            //
            //                System.out.println(eval.stats());
            //                testIter.reset();
            //
            //                //Get a list of prediction errors, from the Evaluation object
            //                //Prediction errors like this are only available after calling iterator.setCollectMetaData(true)
            //                List<Prediction> mistakes = eval.getPredictionErrors();
            //                logger.info("Epoch#{} Prediction Errors: {}", epoch, mistakes.size());
            //
            //                //We can also load a subset of the data, to a DataSet object:
            //                //Here we load the raw data:
            //                List<RecordMetaData> predictionErrorMetaData = new ArrayList<RecordMetaData>();
            //
            //                for (Prediction p : mistakes) {
            //                    predictionErrorMetaData.add(p.getRecordMetaData(RecordMetaData.class));
            //                }
            //
            //                List<Record> predictionErrorRawData = testRecordReader.loadFromMetaData(
            //                        predictionErrorMetaData);
            //
            //                for (int ie = 0; ie < mistakes.size(); ie++) {
            //                    Prediction p = mistakes.get(ie);
            //                    List<Writable> rawData = predictionErrorRawData.get(ie).getRecord();
            //                    saveMistake(p, rawData, epochFolder);
            //                }
            //
            //
            //                // To avoid long useless sessions...
            //                if (mistakes.isEmpty()) {
            //                    logger.info("No mistakes left, training stopped.");
            //
            //                    break;
            //                }
        }
    } finally {
        // Stop monitoring
        if (uiServer != null) {
            uiServer.stop();
        }
    }

    logger.info("****************Example finished********************");
}