Example usage for org.deeplearning4j.nn.multilayer.MultiLayerNetwork.init

List of usage examples for org.deeplearning4j.nn.multilayer.MultiLayerNetwork.init

Introduction

On this page you can find an example usage of org.deeplearning4j.nn.multilayer.MultiLayerNetwork.init.

Prototype

public void init(INDArray parameters, boolean cloneParametersArray) 

Source Link

Document

Initialize the MultiLayerNetwork, optionally with an existing parameters array.

Usage

From source file: org/audiveris/omr/classifier/ModelSystemSerializer.java

License: Open Source License

/**
 * Load a multi layer network from the entries found under a root path of a file system.
 * <p>
 * Entries looked up directly under {@code root}:
 * <ul>
 * <li>{@code configuration.json} – network configuration (required)</li>
 * <li>{@code coefficients.bin} – network parameters (required)</li>
 * <li>{@code OLD_UPDATER_BIN} – legacy, Java-serialized {@link Updater}
 * (optional, read only when {@code loadUpdater} is true)</li>
 * <li>{@code UPDATER_BIN} – updater state array
 * (optional, read only when {@code loadUpdater} is true)</li>
 * <li>{@code preprocessor.bin} – Java-serialized {@link DataSetPreProcessor}
 * (optional; it is deserialized but NOT attached to the returned network)</li>
 * </ul>
 * When both updater entries are present, the updater state takes precedence over the legacy
 * updater. Missing optional entries are skipped.
 * <p>
 * NOTE(review): the optional {@code .bin} objects are read with Java-native deserialization;
 * only restore models coming from trusted sources.
 *
 * @param root        the root path of file system
 * @param loadUpdater true for loading updater
 * @return the loaded multi layer network
 * @throws IOException           if reading any of the entries fails
 * @throws IllegalStateException if the configuration or the coefficients entry is missing
 * @throws RuntimeException      wrapping a {@link ClassNotFoundException} raised while
 *                               deserializing the legacy updater or the preprocessor
 */
public static MultiLayerNetwork restoreMultiLayerNetwork(Path root, boolean loadUpdater) throws IOException {
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotOldUpdater = false;
    boolean gotUpdaterState = false;

    String json = "";
    INDArray params = null;
    Updater updater = null;
    INDArray updaterState = null;

    final Path config = root.resolve("configuration.json");

    if (Files.exists(config)) {
        // Restore the configuration, one '\n'-terminated line at a time.
        // JSON is decoded as UTF-8 instead of the platform-dependent default charset.
        StringBuilder js = new StringBuilder();

        try (BufferedReader reader = Files.newBufferedReader(
                config, java.nio.charset.StandardCharsets.UTF_8)) {
            String line;

            while ((line = reader.readLine()) != null) {
                js.append(line).append("\n");
            }
        }

        json = js.toString();
        gotConfig = true;
    }

    final Path coefficients = root.resolve("coefficients.bin");

    if (Files.exists(coefficients)) {
        try (DataInputStream dis = new DataInputStream(
                new BufferedInputStream(Files.newInputStream(coefficients)))) {
            params = Nd4j.read(dis);
        }

        gotCoefficients = true;
    }

    if (loadUpdater) {
        //This can be removed a few releases after 0.4.1...
        final Path oldUpdaters = root.resolve(OLD_UPDATER_BIN);

        if (Files.exists(oldUpdaters)) {
            try (ObjectInputStream ois = new ObjectInputStream(
                    new BufferedInputStream(Files.newInputStream(oldUpdaters)))) {
                updater = (Updater) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }

            gotOldUpdater = true;
        }

        final Path updaterStateEntry = root.resolve(UPDATER_BIN);

        // Path.resolve never returns null, so check the file itself:
        // a missing (optional) updater state is skipped rather than raising NoSuchFileException.
        if (Files.exists(updaterStateEntry)) {
            try (DataInputStream dis = new DataInputStream(
                    new BufferedInputStream(Files.newInputStream(updaterStateEntry)))) {
                updaterState = Nd4j.read(dis);
            }

            gotUpdaterState = true;
        }
    }

    final Path prep = root.resolve("preprocessor.bin");

    if (Files.exists(prep)) {
        // The preprocessor is deserialized (so a corrupt or unknown-class entry still fails here)
        // but it is not attached to the returned network.
        try (ObjectInputStream ois = new ObjectInputStream(
                new BufferedInputStream(Files.newInputStream(prep)))) {
            DataSetPreProcessor preProcessor = (DataSetPreProcessor) ois.readObject();
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    if (gotConfig && gotCoefficients) {
        MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
        MultiLayerNetwork network = new MultiLayerNetwork(confFromJson);
        // params was freshly read above, so it is handed over without cloning
        network.init(params, false);

        // Prefer the updater state array; fall back to the legacy serialized updater
        if (gotUpdaterState && (updaterState != null)) {
            network.getUpdater().setStateViewArray(network, updaterState, false);
        } else if (gotOldUpdater && (updater != null)) {
            network.setUpdater(updater);
        }

        return network;
    } else {
        throw new IllegalStateException("Model wasn't found within file: gotConfig: [" + gotConfig
                + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
    }
}