Example usage for org.deeplearning4j.util ModelSerializer restoreMultiLayerNetwork

List of usage examples for org.deeplearning4j.util ModelSerializer restoreMultiLayerNetwork

Introduction

In this page you can find the example usage for org.deeplearning4j.util ModelSerializer restoreMultiLayerNetwork.

Prototype

public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull String path, boolean loadUpdater)
        throws IOException 

Source Link

Document

Load a MultiLayerNetwork model from a file

Usage

From source file:org.audiveris.omrdataset.train.Training.java

License:Open Source License

/**
 * Perform the training of the neural network.
 * <p>
 * Before training is launched, if the network model exists on disk it is reloaded, otherwise a
 * brand new one is created.
 *
 * @throws Exception in case of IO problem or interruption
 */
public void process() throws Exception {
    // Ensure the folder for mistake snapshots exists before training starts.
    Files.createDirectories(MISTAKES_PATH);

    int nChannels = 1; // Number of input channels (grayscale pixels)
    int batchSize = 64; // Batch size
    int nEpochs = 1; //3; //10; //2; // Number of training epochs
    int iterations = 1; // 2; //10; // Number of training iterations
    int seed = 123; // Fixed RNG seed for reproducible weight initialization

    // Pixel norms: restore the previously fitted normalizer from disk so that
    // training and inference use identical feature scaling.
    NormalizerStandardize normalizer = NormalizerSerializer.getDefault().restore(PIXELS_PATH.toFile());

    // Get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int labelIndex = CONTEXT_WIDTH * CONTEXT_HEIGHT; // format: all cells then label
    int numLinesToSkip = 1; // Because of header comment line
    String delimiter = ",";

    RecordReader trainRecordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    trainRecordReader.initialize(new FileSplit(FEATURES_PATH.toFile()));
    logger.info("Getting dataset from {} ...", FEATURES_PATH);

    RecordReaderDataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, batchSize,
            labelIndex, numClasses, -1);
    trainIter.setCollectMetaData(true); //Instruct the iterator to collect metadata, and store it in the DataSet objects

    // NOTE(review): the test iterator reads the SAME file as the training iterator
    // (FEATURES_PATH) — evaluation is therefore on training data; confirm intent.
    RecordReader testRecordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    testRecordReader.initialize(new FileSplit(FEATURES_PATH.toFile()));

    RecordReaderDataSetIterator testIter = new RecordReaderDataSetIterator(testRecordReader, batchSize,
            labelIndex, numClasses, -1);
    testIter.setCollectMetaData(true); //Instruct the iterator to collect metadata, and store it in the DataSet objects

    // Normalization: apply the restored pixel normalizer to every mini-batch.
    DataSetPreProcessor preProcessor = new MyPreProcessor(normalizer);
    trainIter.setPreProcessor(preProcessor);
    testIter.setPreProcessor(preProcessor);

    // Dead debug block, deliberately disabled: dumps the location of every
    // test example's metadata. Flip the condition to true to re-enable.
    if (false) {
        System.out.println("\n  +++++ Test Set Examples MetaData +++++");

        while (testIter.hasNext()) {
            DataSet ds = testIter.next();
            List<RecordMetaData> testMetaData = ds.getExampleMetaData(RecordMetaData.class);

            for (RecordMetaData recordMetaData : testMetaData) {
                System.out.println(recordMetaData.getLocation());
            }
        }

        testIter.reset();
    }

    final MultiLayerNetwork model;

    // Resume from a saved model when one exists; otherwise build from scratch.
    // loadUpdater=false: the optimizer state is not restored.
    if (Files.exists(MODEL_PATH)) {
        model = ModelSerializer.restoreMultiLayerNetwork(MODEL_PATH.toFile(), false);
        logger.info("Model restored from {}", MODEL_PATH.toAbsolutePath());
    } else {
        logger.info("Building model from scratch");

        // LeNet-like architecture: two conv+max-pool stages, one dense layer,
        // then a softmax output over numClasses.
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() //
                .seed(seed) //
                .iterations(iterations) //
                .regularization(true) //
                .l2(0.0005) //
                .learningRate(.002) // HB: was .01 initially
                //.biasLearningRate(0.02)
                //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
                .weightInit(WeightInit.XAVIER) //
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) //
                .updater(Updater.NESTEROVS).momentum(0.9) //
                .list() //
                .layer(0, new ConvolutionLayer.Builder(5, 5) //
                        .name("C0") //
                        .nIn(nChannels) //
                        .stride(1, 1) //
                        .nOut(20) //
                        .activation(Activation.IDENTITY) //
                        .build()) //
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) //
                        .name("S1") //
                        .kernelSize(2, 2) //
                        .stride(2, 2) //
                        .build()) //
                .layer(2, new ConvolutionLayer.Builder(5, 5) //
                        .name("C2") //
                        .stride(1, 1) //
                        .nOut(50) //
                        .activation(Activation.IDENTITY) //
                        .build()) //
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) //
                        .name("S3") //
                        .kernelSize(2, 2) //
                        .stride(2, 2) //
                        .build()) //
                .layer(4, new DenseLayer.Builder() //
                        .name("D4") //
                        .nOut(500) //
                        .activation(Activation.RELU) //
                        .build()) //
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) //
                        .name("O5") //
                        .nOut(numClasses) //
                        .activation(Activation.SOFTMAX) //
                        .build()) //
                .setInputType(InputType.convolutionalFlat(CONTEXT_HEIGHT, CONTEXT_WIDTH, 1));

        MultiLayerConfiguration conf = builder.build();
        model = new MultiLayerNetwork(conf);
        model.init();
    }

    // Prepare monitoring: kept in a local so the finally block can stop it.
    UIServer uiServer = null;

    try {
        // Enabled branch: attach the DL4J training UI (in-memory stats storage).
        // Flip to false to train with console score logging only.
        if (true) {
            //Initialize the user interface backend
            uiServer = UIServer.getInstance();

            //Configure where the network information (gradients, score vs. time etc) is to be stored. Here: store in memory.
            StatsStorage statsStorage = new InMemoryStatsStorage(); //Alternative: new FileStatsStorage(File), for saving and loading later

            //Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized
            uiServer.attach(statsStorage);

            //Then add the StatsListener to collect this information from the network, as it trains
            model.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(10));
        } else {
            model.setListeners(new ScoreIterationListener(10));
        }

        logger.info("Training model...");

        for (int epoch = 1; epoch <= nEpochs; epoch++) {
            // epochFolder is only consumed by the mistake-saving code below,
            // which is currently commented out — it is unused at runtime.
            Path epochFolder = Main.cli.mistakes ? MISTAKES_PATH.resolve("epoch#" + epoch) : null;
            long start = System.currentTimeMillis();
            model.fit(trainIter);

            long stop = System.currentTimeMillis();
            double dur = stop - start;
            logger.info(String.format("*** End epoch#%d, time: %.0f sec", epoch, dur / 1000));

            // Save model after EVERY epoch, so a crash loses at most one epoch.
            // saveUpdater=false matches the restore call above.
            ModelSerializer.writeModel(model, MODEL_PATH.toFile(), false);
            ModelSerializer.addNormalizerToModel(MODEL_PATH.toFile(), normalizer);
            logger.info("Model+normalizer stored as {}", MODEL_PATH.toAbsolutePath());
            //
            //                logger.info("Evaluating model...");
            //
            //                Evaluation eval = new Evaluation(OmrShapes.NAMES);
            //
            //                while (testIter.hasNext()) {
            //                    DataSet ds = testIter.next();
            //                    List<RecordMetaData> testMetaData = ds.getExampleMetaData(RecordMetaData.class);
            //                    INDArray output = model.output(ds.getFeatureMatrix(), false);
            //                    eval.eval(ds.getLabels(), output, testMetaData);
            //                }
            //
            //                System.out.println(eval.stats());
            //                testIter.reset();
            //
            //                //Get a list of prediction errors, from the Evaluation object
            //                //Prediction errors like this are only available after calling iterator.setCollectMetaData(true)
            //                List<Prediction> mistakes = eval.getPredictionErrors();
            //                logger.info("Epoch#{} Prediction Errors: {}", epoch, mistakes.size());
            //
            //                //We can also load a subset of the data, to a DataSet object:
            //                //Here we load the raw data:
            //                List<RecordMetaData> predictionErrorMetaData = new ArrayList<RecordMetaData>();
            //
            //                for (Prediction p : mistakes) {
            //                    predictionErrorMetaData.add(p.getRecordMetaData(RecordMetaData.class));
            //                }
            //
            //                List<Record> predictionErrorRawData = testRecordReader.loadFromMetaData(
            //                        predictionErrorMetaData);
            //
            //                for (int ie = 0; ie < mistakes.size(); ie++) {
            //                    Prediction p = mistakes.get(ie);
            //                    List<Writable> rawData = predictionErrorRawData.get(ie).getRecord();
            //                    saveMistake(p, rawData, epochFolder);
            //                }
            //
            //
            //                // To avoid long useless sessions...
            //                if (mistakes.isEmpty()) {
            //                    logger.info("No mistakes left, training stopped.");
            //
            //                    break;
            //                }
        }
    } finally {
        // Stop monitoring even if training failed, so the UI server thread
        // does not keep the JVM alive.
        if (uiServer != null) {
            uiServer.stop();
        }
    }

    logger.info("****************Example finished********************");
}

From source file:org.knime.ext.dl4j.base.util.DLModelPortObjectUtils.java

License:Open Source License

/**
 * Loads a {@link DLModelPortObject} from the specified {@link ZipInputStream}. Supports both deserialization of old
 * and new format.
 *
 * @param inStream the stream to load from
 * @return the loaded {@link DLModelPortObject}
 * @throws IOException
 */
@SuppressWarnings("resource")
public static DLModelPortObject loadPortFromZip(final ZipInputStream inStream) throws IOException {
    final List<Layer> layers = new ArrayList<>();

    //old model format: configuration, parameters and updater stored as separate entries
    INDArray mln_params = null;
    MultiLayerConfiguration mln_config = null;
    org.deeplearning4j.nn.api.Updater updater = null;

    //new model format: whole network serialized by DL4J's ModelSerializer.
    //At most one of mlnLoaded/cgLoaded can become true: a zip contains either
    //an "mln_model" entry or a "cg_model" entry, not both.
    boolean mlnLoaded = false;
    boolean cgLoaded = false;
    MultiLayerNetwork mlnFromModelSerializer = null;
    ComputationGraph cgFromModelSerializer = null;

    ZipEntry entry;

    // Single pass over the zip entries; each branch must consume exactly the
    // current entry's bytes from inStream before the next getNextEntry() call.
    while ((entry = inStream.getNextEntry()) != null) {
        // read layers
        if (entry.getName().matches("layer[0123456789]+")) {
            final String read = readStringFromZipStream(inStream);
            Layer l = NeuralNetConfiguration.fromJson(read).getLayer();

            if (l instanceof BaseLayer) {
                BaseLayer bl = (BaseLayer) l;
                /* Compatibility issue between dl4j 0.6 and 0.8 due to API change. Activations changed from
                 * Strings to an interface. Therefore, if a model was saved with 0.6 the corresponding member
                 * of the layer object will contain null after 'NeuralNetConfiguration.fromJson'. Old method to
                 * retrieve String representation of the activation function was removed. Therefore, we parse
                 * the old activation from the json ourself and map it to the new Activation. */
                if (bl.getActivationFn() == null) {
                    Optional<Activation> layerActivation = DL4JVersionUtils.parseLayerActivationFromJson(read);

                    if (layerActivation.isPresent()) {
                        bl.setActivationFn(layerActivation.get().getActivationFunction());
                    }
                }
            }

            layers.add(l);

            // directly read MultiLayerNetwork, new format
        } else if (entry.getName().matches("mln_model")) {
            //stream must not be closed, ModelSerializer tries to close the stream
            CloseShieldInputStream shieldIs = new CloseShieldInputStream(inStream);
            mlnFromModelSerializer = ModelSerializer.restoreMultiLayerNetwork(shieldIs, true);
            mlnLoaded = true;

            // directly read ComputationGraph, new format
        } else if (entry.getName().matches("cg_model")) {
            //stream must not be closed, ModelSerializer tries to close the stream
            CloseShieldInputStream shieldIs = new CloseShieldInputStream(inStream);
            cgFromModelSerializer = ModelSerializer.restoreComputationGraph(shieldIs, true);
            cgLoaded = true;

            // read MultilayerNetworkConfig, old format
        } else if (entry.getName().matches("mln_config")) {

            final String read = readStringFromZipStream(inStream);
            mln_config = MultiLayerConfiguration.fromJson(read.toString());

            // read params, old format
        } else if (entry.getName().matches("mln_params")) {
            try {
                mln_params = Nd4j.read(inStream);
            } catch (Exception e) {
                throw new IOException("Could not load network parameters. Please re-execute the Node.", e);
            }

            // read updater, old format
        } else if (entry.getName().matches("mln_updater")) {
            // stream must not be closed, even if an exception is thrown, because the wrapped stream must stay open
            final IgnoreIDObjectInputStream ois = new IgnoreIDObjectInputStream(inStream);
            try {
                updater = (org.deeplearning4j.nn.api.Updater) ois.readObject();
            } catch (final ClassNotFoundException e) {
                throw new IOException("Problem with updater loading: " + e.getMessage(), e);
            }
        }
    }

    // New-format entries take precedence; otherwise rebuild the network from
    // the old-format pieces collected above.
    if (mlnLoaded) {
        assert (!cgLoaded);
        return new DLModelPortObject(layers, mlnFromModelSerializer, null);
    } else if (cgLoaded) {
        assert (!mlnLoaded);
        return new DLModelPortObject(layers, cgFromModelSerializer, null);
    } else {
        return new DLModelPortObject(layers, buildMln(mln_config, updater, mln_params), null);
    }
}

From source file:weka.classifiers.functions.Dl4jMlpClassifier.java

License:Open Source License

/**
 * Custom deserialization method.
 * <p>
 * After default field deserialization, the DL4J network is restored from the
 * remainder of the stream — but only when {@code m_replaceMissing} is non-null,
 * which signals that a trained model was written during serialization. The
 * restore runs with this class's own class loader as the thread context loader,
 * since DL4J resolves classes through it.
 *
 * @param ois the object input stream
 * @throws ClassNotFoundException
 * @throws IOException
 */
private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {

    // default deserialization of all non-transient fields
    ois.defaultReadObject();

    if (m_replaceMissing == null) {
        // no trained model was serialized; nothing more to read
        return;
    }

    final Thread currentThread = Thread.currentThread();
    final ClassLoader previousLoader = currentThread.getContextClassLoader();
    try {
        // swap in our class loader so DL4J can resolve its classes,
        // then restore the network (loadUpdater=false: skip optimizer state)
        currentThread.setContextClassLoader(this.getClass().getClassLoader());
        m_model = ModelSerializer.restoreMultiLayerNetwork(ois, false);
    } finally {
        // always restore the caller's context class loader
        currentThread.setContextClassLoader(previousLoader);
    }
}