Example usage of org.deeplearning4j.nn.api.Model#params()

List of usage examples for org.deeplearning4j.nn.api.Model#params()

Introduction

On this page you can find example usages of the params() method of the org.deeplearning4j.nn.api.Model interface.

Prototype

INDArray params();

Source Link

Document

Parameters of the model (if any)

Usage

From source file: org/audiveris/omr/classifier/ModelSystemSerializer.java

License: Open Source License

/**
 * Write a model to a file system/*w w  w  . ja  v  a 2s .c  o  m*/
 *
 * @param model       the model to save
 * @param root        the root path of file system
 * @param saveUpdater whether to save the updater for the model or not
 * @throws IOException
 */
public static void writeModel(Model model, Path root, boolean saveUpdater) throws IOException {
    {
        // save json first
        String json = "";

        if (model instanceof MultiLayerNetwork) {
            json = ((MultiLayerNetwork) model).getLayerWiseConfigurations().toJson();
        } else if (model instanceof ComputationGraph) {
            json = ((ComputationGraph) model).getConfiguration().toJson();
        }

        Path config = root.resolve("configuration.json");
        OutputStream bos = new BufferedOutputStream(Files.newOutputStream(config, CREATE));
        Writer writer = new OutputStreamWriter(bos, "UTF-8");
        writer.write(json);
        writer.close();
    }

    {
        Path coefficients = root.resolve("coefficients.bin");
        DataOutputStream dos = new DataOutputStream(
                new BufferedOutputStream(Files.newOutputStream(coefficients, CREATE)));
        Nd4j.write(model.params(), dos);
        dos.flush();
        dos.close();
    }

    if (saveUpdater) {
        INDArray updaterState = null;

        if (model instanceof MultiLayerNetwork) {
            updaterState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
        } else if (model instanceof ComputationGraph) {
            updaterState = ((ComputationGraph) model).getUpdater().getStateViewArray();
        }

        if ((updaterState != null) && (updaterState.length() > 0)) {
            Path updater = root.resolve(UPDATER_BIN);
            DataOutputStream dos = new DataOutputStream(
                    new BufferedOutputStream(Files.newOutputStream(updater, CREATE)));
            Nd4j.write(updaterState, dos);
            dos.flush();
            dos.close();
        }
    }
}

From source file: vectorizer/ModelSerializer.java

/**
 * Write a model as a zip archive to the given stream.
 * <p>
 * The archive contains two entries:
 * <ul>
 * <li>{@code configuration.json}: the layer-wise configuration, as UTF-8 JSON;</li>
 * <li>{@code coefficients.bin}: the model parameters, as returned by {@code model.params()}.</li>
 * </ul>
 * The model must be a {@code MultiLayerNetwork}; any other type fails with a
 * {@link ClassCastException}. The given stream is closed when this method returns, including
 * when it fails.
 * <p>
 * NOTE(review): {@code saveUpdater} is currently ignored and no updater state is written.
 *
 * @param model       the model to save (expected to be a {@code MultiLayerNetwork})
 * @param stream      the destination stream; it is closed by this method
 * @param saveUpdater currently unused
 * @throws IOException if writing to the stream fails
 */
public static void writeModel(Model model, OutputStream stream, boolean saveUpdater) throws IOException {
    // save json first
    final String json = ((MultiLayerNetwork) model).getLayerWiseConfigurations().toJson();

    try (ZipOutputStream zipfile = new ZipOutputStream(stream)) {
        zipfile.putNextEntry(new ZipEntry("configuration.json"));
        // Explicit UTF-8: getBytes() without a charset depends on the platform default
        writeEntry(new ByteArrayInputStream(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)),
                zipfile);

        zipfile.putNextEntry(new ZipEntry("coefficients.bin"));

        // Serialize the parameters into memory first, then copy them into the zip entry
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (DataOutputStream dos = new DataOutputStream(bos)) {
            Nd4j.write(model.params(), dos);
        }
        writeEntry(new ByteArrayInputStream(bos.toByteArray()), zipfile);

        zipfile.flush();
    }
}