Example usage for org.deeplearning4j.nn.multilayer MultiLayerNetwork setUpdater

List of usage examples for org.deeplearning4j.nn.multilayer MultiLayerNetwork setUpdater

Introduction

This page lists example usages of org.deeplearning4j.nn.multilayer.MultiLayerNetwork#setUpdater.

Prototype

public void setUpdater(Updater updater) 

Source Link

Document

Sets the updater for the MultiLayerNetwork.

Usage

From source file:org.audiveris.omr.classifier.ModelSystemSerializer.java

License:Open Source License

/**
 * Load a multi layer network from a file system
 *
 * @param root        the root path of file system
 * @param loadUpdater true for loading updater
 * @return the loaded multi layer network
 * @throws IOException/*from  w  w  w  .ja v a2  s . c  o m*/
 */
public static MultiLayerNetwork restoreMultiLayerNetwork(Path root, boolean loadUpdater) throws IOException {
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotOldUpdater = false;
    boolean gotUpdaterState = false;
    boolean gotPreProcessor = false;

    String json = "";
    INDArray params = null;
    Updater updater = null;
    INDArray updaterState = null;
    DataSetPreProcessor preProcessor = null;

    final Path config = root.resolve("configuration.json");

    if (Files.exists(config)) {
        //restoring configuration
        InputStream stream = Files.newInputStream(config);
        BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
        String line = "";
        StringBuilder js = new StringBuilder();

        while ((line = reader.readLine()) != null) {
            js.append(line).append("\n");
        }

        json = js.toString();

        reader.close();
        stream.close();
        gotConfig = true;
    }

    final Path coefficients = root.resolve("coefficients.bin");

    if (Files.exists(coefficients)) {
        InputStream stream = Files.newInputStream(coefficients);
        DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
        params = Nd4j.read(dis);

        dis.close();
        gotCoefficients = true;
    }

    if (loadUpdater) {
        //This can be removed a few releases after 0.4.1...
        final Path oldUpdaters = root.resolve(OLD_UPDATER_BIN);

        if (Files.exists(oldUpdaters)) {
            InputStream stream = Files.newInputStream(oldUpdaters);
            ObjectInputStream ois = new ObjectInputStream(stream);

            try {
                updater = (Updater) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }

            gotOldUpdater = true;
        }

        final Path updaterStateEntry = root.resolve(UPDATER_BIN);

        if (updaterStateEntry != null) {
            InputStream stream = Files.newInputStream(updaterStateEntry);
            DataInputStream dis = new DataInputStream(stream);
            updaterState = Nd4j.read(dis);

            dis.close();
            gotUpdaterState = true;
        }
    }

    final Path prep = root.resolve("preprocessor.bin");

    if (Files.exists(prep)) {
        InputStream stream = Files.newInputStream(prep);
        ObjectInputStream ois = new ObjectInputStream(stream);

        try {
            preProcessor = (DataSetPreProcessor) ois.readObject();
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }

        gotPreProcessor = true;
    }

    if (gotConfig && gotCoefficients) {
        MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
        MultiLayerNetwork network = new MultiLayerNetwork(confFromJson);
        network.init(params, false);

        if (gotUpdaterState && (updaterState != null)) {
            network.getUpdater().setStateViewArray(network, updaterState, false);
        } else if (gotOldUpdater && (updater != null)) {
            network.setUpdater(updater);
        }

        return network;
    } else {
        throw new IllegalStateException("Model wasn't found within file: gotConfig: [" + gotConfig
                + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
    }
}

From source file:org.knime.ext.dl4j.base.nodes.learn.AbstractDLLearnerNodeModel.java

License:Open Source License

/**
 * Attempts to transfer the {@link Updater} from one {@link MultiLayerNetwork} to another. This is important if you
 * want to further train a pretrained net as some {@link Updater}s contain a history of gradients from training.
 * <p>
 * The updater reference obtained from {@code from} is passed directly to {@code to.setUpdater}; no copy is made
 * here. If {@code from} has no updater, a warning is logged and {@code to} is left untouched.
 *
 * @param from the network to get the updater from
 * @param to the network to transfer the updater to
 */
protected void transferUpdater(final MultiLayerNetwork from, final MultiLayerNetwork to) {
    final org.deeplearning4j.nn.api.Updater updater = from.getUpdater();

    if (updater == null) {
        logger.warn("Could not transfer updater between nets as there is no updater set in the source net");
    } else {
        to.setUpdater(updater);
        logger.info("Successfully transferred updater between nets.");
    }
}

From source file:org.knime.ext.dl4j.base.util.DLModelPortObjectUtils.java

License:Open Source License

/**
 * Builds a {@link MultiLayerNetwork} from objects deserialized in the old format. Newer code relies on the dl4j
 * {@link ModelSerializer}, which performs this step implicitly.
 *
 * @param config the network configuration; when {@code null}, no network is built
 * @param updater the updater to install after {@code init()}, or {@code null} to skip this step
 * @param params the parameters to install after the updater, or {@code null} to skip this step
 * @return the built network, or {@code null} when {@code config} is {@code null}
 */
@Deprecated
private static MultiLayerNetwork buildMln(final MultiLayerConfiguration config,
        final org.deeplearning4j.nn.api.Updater updater, final INDArray params) {
    // Guard clause: nothing can be built without a configuration
    if (config == null) {
        return null;
    }

    final MultiLayerNetwork network = new MultiLayerNetwork(config);
    network.init();

    // Order matters as in the legacy path: updater first, then parameters
    if (updater != null) {
        network.setUpdater(updater);
    }
    if (params != null) {
        network.setParams(params);
    }

    return network;
}

From source file:vectorizer.ModelSerializer.java

/**
 * Restores a {@link MultiLayerNetwork} from a zip archive.
 * <p>
 * The archive must contain {@code configuration.json} and {@code coefficients.bin}. A serialized updater in
 * {@code updater.bin} is installed on the network when present.
 *
 * @param file the zip archive to read
 * @return the restored network
 * @throws IOException           if the archive or one of its entries cannot be read
 * @throws IllegalStateException if the configuration or the coefficients are missing
 */
public static MultiLayerNetwork restoreMultiLayerNetwork(File file) throws IOException {
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotUpdater = false;

    String json = "";
    INDArray params = null;
    Updater updater = null;

    // try-with-resources: the archive (and every entry stream) is closed even when a read fails
    try (ZipFile zipFile = new ZipFile(file)) {
        ZipEntry config = zipFile.getEntry("configuration.json");
        if (config != null) {
            //restoring configuration
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(config)))) {
                StringBuilder js = new StringBuilder();
                String line;
                while ((line = reader.readLine()) != null) {
                    js.append(line).append("\n");
                }
                json = js.toString();
            }
            gotConfig = true;
        }

        ZipEntry coefficients = zipFile.getEntry("coefficients.bin");
        if (coefficients != null) {
            try (DataInputStream dis = new DataInputStream(zipFile.getInputStream(coefficients))) {
                params = Nd4j.read(dis);
            }
            gotCoefficients = true;
        }

        ZipEntry updaters = zipFile.getEntry("updater.bin");
        if (updaters != null) {
            // NOTE(review): Java-native deserialization; only restore archives from trusted sources.
            try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(updaters))) {
                updater = (Updater) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
            gotUpdater = true;
        }
    }

    if (gotConfig && gotCoefficients) {
        MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
        MultiLayerNetwork network = new MultiLayerNetwork(confFromJson);
        network.init();
        network.setParameters(params);

        if (gotUpdater && updater != null) {
            network.setUpdater(updater);
        }
        return network;
    } else {
        throw new IllegalStateException("Model wasn't found within file: gotConfig: [" + gotConfig
                + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdater + "]");
    }
}