Example usage for org.deeplearning4j.nn.conf MultiLayerConfiguration fromJson

List of usage examples for org.deeplearning4j.nn.conf MultiLayerConfiguration fromJson

Introduction

On this page you can find example usages of org.deeplearning4j.nn.conf MultiLayerConfiguration fromJson.

Prototype

public static MultiLayerConfiguration fromJson(String json) 

Source Link

Document

Creates a neural net configuration from JSON.

Usage

From source file:org.audiveris.omr.classifier.ModelSystemSerializer.java

License:Open Source License

/**
 * Load a multi layer network from a file system.
 * <p>
 * The following entries are looked up under {@code root}: {@code configuration.json} and
 * {@code coefficients.bin} (both required), plus, when {@code loadUpdater} is true, the
 * entries named by {@code OLD_UPDATER_BIN} and {@code UPDATER_BIN} (both optional).
 * When both updater entries are present, the updater state takes precedence over the old
 * serialized updater.
 *
 * @param root        the root path of file system
 * @param loadUpdater true for loading updater
 * @return the loaded multi layer network
 * @throws IOException           if an existing entry cannot be read
 * @throws IllegalStateException if the configuration or the coefficients entry is missing
 */
public static MultiLayerNetwork restoreMultiLayerNetwork(Path root, boolean loadUpdater) throws IOException {
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotUpdaterState = false;

    String json = "";
    INDArray params = null;
    Updater updater = null;
    INDArray updaterState = null;

    final Path config = root.resolve("configuration.json");

    if (Files.exists(config)) {
        //restoring configuration
        StringBuilder js = new StringBuilder();

        // NOTE(review): the platform default charset is kept on purpose, since the writer
        // side is not visible here; switch both sides to UTF-8 together if desired.
        try (InputStream stream = Files.newInputStream(config);
                BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
            String line;

            while ((line = reader.readLine()) != null) {
                js.append(line).append("\n");
            }
        }

        json = js.toString();
        gotConfig = true;
    }

    final Path coefficients = root.resolve("coefficients.bin");

    if (Files.exists(coefficients)) {
        try (InputStream stream = Files.newInputStream(coefficients);
                DataInputStream dis = new DataInputStream(new BufferedInputStream(stream))) {
            params = Nd4j.read(dis);
        }

        gotCoefficients = true;
    }

    if (loadUpdater) {
        //This can be removed a few releases after 0.4.1...
        final Path oldUpdaters = root.resolve(OLD_UPDATER_BIN);

        if (Files.exists(oldUpdaters)) {
            try (InputStream stream = Files.newInputStream(oldUpdaters);
                    ObjectInputStream ois = new ObjectInputStream(stream)) {
                updater = (Updater) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
        }

        final Path updaterStateEntry = root.resolve(UPDATER_BIN);

        // Path.resolve never returns null, so the entry must be probed on the file system;
        // otherwise a model saved without updater state would fail to load.
        if (Files.exists(updaterStateEntry)) {
            try (InputStream stream = Files.newInputStream(updaterStateEntry);
                    DataInputStream dis = new DataInputStream(stream)) {
                updaterState = Nd4j.read(dis);
            }

            gotUpdaterState = true;
        }
    }

    final Path prep = root.resolve("preprocessor.bin");

    if (Files.exists(prep)) {
        // The pre-processor is deserialized only to validate the entry, exactly as before;
        // the resulting object is not attached to the returned network.
        try (InputStream stream = Files.newInputStream(prep);
                ObjectInputStream ois = new ObjectInputStream(stream)) {
            DataSetPreProcessor ignored = (DataSetPreProcessor) ois.readObject();
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    if (!gotConfig || !gotCoefficients) {
        throw new IllegalStateException("Model wasn't found within file: gotConfig: [" + gotConfig
                + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
    }

    MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
    MultiLayerNetwork network = new MultiLayerNetwork(confFromJson);
    network.init(params, false);

    if (updaterState != null) {
        network.getUpdater().setStateViewArray(network, updaterState, false);
    } else if (updater != null) {
        network.setUpdater(updater);
    }

    return network;
}

From source file:org.knime.ext.dl4j.base.util.DLModelPortObjectUtils.java

License:Open Source License

/**
 * Loads a {@link DLModelPortObject} from the specified {@link ZipInputStream}. Supports both deserialization of old
 * and new format./*from  w w w. ja v a 2  s. c o  m*/
 *
 * @param inStream the stream to load from
 * @return the loaded {@link DLModelPortObject}
 * @throws IOException
 */
@SuppressWarnings("resource")
public static DLModelPortObject loadPortFromZip(final ZipInputStream inStream) throws IOException {
    final List<Layer> layers = new ArrayList<>();

    //old model format
    INDArray mln_params = null;
    MultiLayerConfiguration mln_config = null;
    org.deeplearning4j.nn.api.Updater updater = null;

    //new model format
    boolean mlnLoaded = false;
    boolean cgLoaded = false;
    MultiLayerNetwork mlnFromModelSerializer = null;
    ComputationGraph cgFromModelSerializer = null;

    ZipEntry entry;

    while ((entry = inStream.getNextEntry()) != null) {
        // read layers
        if (entry.getName().matches("layer[0123456789]+")) {
            final String read = readStringFromZipStream(inStream);
            Layer l = NeuralNetConfiguration.fromJson(read).getLayer();

            if (l instanceof BaseLayer) {
                BaseLayer bl = (BaseLayer) l;
                /* Compatibility issue between dl4j 0.6 and 0.8 due to API change. Activations changed from
                 * Strings to an interface. Therefore, if a model was saved with 0.6 the corresponding member
                 * of the layer object will contain null after 'NeuralNetConfiguration.fromJson'. Old method to
                 * retrieve String representation of the activation function was removed. Therefore, we parse
                 * the old activation from the json ourself and map it to the new Activation. */
                if (bl.getActivationFn() == null) {
                    Optional<Activation> layerActivation = DL4JVersionUtils.parseLayerActivationFromJson(read);

                    if (layerActivation.isPresent()) {
                        bl.setActivationFn(layerActivation.get().getActivationFunction());
                    }
                }
            }

            layers.add(l);

            // directly read MultiLayerNetwork, new format
        } else if (entry.getName().matches("mln_model")) {
            //stream must not be closed, ModelSerializer tries to close the stream
            CloseShieldInputStream shieldIs = new CloseShieldInputStream(inStream);
            mlnFromModelSerializer = ModelSerializer.restoreMultiLayerNetwork(shieldIs, true);
            mlnLoaded = true;

            // directly read MultiLayerNetwork, new format
        } else if (entry.getName().matches("cg_model")) {
            //stream must not be closed, ModelSerializer tries to close the stream
            CloseShieldInputStream shieldIs = new CloseShieldInputStream(inStream);
            cgFromModelSerializer = ModelSerializer.restoreComputationGraph(shieldIs, true);
            cgLoaded = true;

            // read MultilayerNetworkConfig, old format
        } else if (entry.getName().matches("mln_config")) {

            final String read = readStringFromZipStream(inStream);
            mln_config = MultiLayerConfiguration.fromJson(read.toString());

            // read params, old format
        } else if (entry.getName().matches("mln_params")) {
            try {
                mln_params = Nd4j.read(inStream);
            } catch (Exception e) {
                throw new IOException("Could not load network parameters. Please re-execute the Node.", e);
            }

            // read updater, old format
        } else if (entry.getName().matches("mln_updater")) {
            // stream must not be closed, even if an exception is thrown, because the wrapped stream must stay open
            final IgnoreIDObjectInputStream ois = new IgnoreIDObjectInputStream(inStream);
            try {
                updater = (org.deeplearning4j.nn.api.Updater) ois.readObject();
            } catch (final ClassNotFoundException e) {
                throw new IOException("Problem with updater loading: " + e.getMessage(), e);
            }
        }
    }

    if (mlnLoaded) {
        assert (!cgLoaded);
        return new DLModelPortObject(layers, mlnFromModelSerializer, null);
    } else if (cgLoaded) {
        assert (!mlnLoaded);
        return new DLModelPortObject(layers, cgFromModelSerializer, null);
    } else {
        return new DLModelPortObject(layers, buildMln(mln_config, updater, mln_params), null);
    }
}

From source file:vectorizer.ModelSerializer.java

/**
 * Restores a multi layer network from a zip file.
 * <p>
 * The archive must contain the entries {@code configuration.json} and {@code coefficients.bin};
 * an {@code updater.bin} entry is optional and, when present and non-null, is set on the
 * returned network.
 *
 * @param file the zip file to read the model from
 * @return the restored multi layer network
 * @throws IOException           if the archive or one of its entries cannot be read
 * @throws IllegalStateException if the configuration or the coefficients entry is missing
 */
public static MultiLayerNetwork restoreMultiLayerNetwork(File file) throws IOException {
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotUpdater = false;

    String json = "";
    INDArray params = null;
    Updater updater = null;

    // try-with-resources guarantees the archive is released even when reading an entry fails
    try (ZipFile zipFile = new ZipFile(file)) {
        ZipEntry config = zipFile.getEntry("configuration.json");
        if (config != null) {
            //restoring configuration
            StringBuilder js = new StringBuilder();

            // NOTE(review): the platform default charset is kept on purpose, since the writer
            // side is not visible here; switch both sides to UTF-8 together if desired.
            try (InputStream stream = zipFile.getInputStream(config);
                    BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    js.append(line).append("\n");
                }
            }

            json = js.toString();
            gotConfig = true;
        }

        ZipEntry coefficients = zipFile.getEntry("coefficients.bin");
        if (coefficients != null) {
            try (InputStream stream = zipFile.getInputStream(coefficients);
                    DataInputStream dis = new DataInputStream(stream)) {
                params = Nd4j.read(dis);
            }

            gotCoefficients = true;
        }

        ZipEntry updaters = zipFile.getEntry("updater.bin");
        if (updaters != null) {
            try (InputStream stream = zipFile.getInputStream(updaters);
                    ObjectInputStream ois = new ObjectInputStream(stream)) {
                updater = (Updater) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }

            gotUpdater = true;
        }
    }

    if (!gotConfig || !gotCoefficients) {
        throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
                + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdater + "]");
    }

    MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
    MultiLayerNetwork network = new MultiLayerNetwork(confFromJson);
    network.init();
    network.setParameters(params);

    if (gotUpdater && updater != null) {
        network.setUpdater(updater);
    }

    return network;
}