Example usage for org.deeplearning4j.nn.multilayer MultiLayerNetwork setParameters

List of usage examples for org.deeplearning4j.nn.multilayer MultiLayerNetwork setParameters

Introduction

On this page you can find example usages of org.deeplearning4j.nn.multilayer MultiLayerNetwork setParameters.

Prototype

public void setParameters(INDArray params) 

Source Link

Document

See #setParams(INDArray)

Usage

From source file:examples.cnn.NetworkTrainer.java

License:Apache License

/**
 * Trains the network for {@code epochs} epochs, evaluating on {@code test} after each one.
 *
 * <p>After every epoch the accuracy on {@code test} is logged together with a few sample
 * predictions and the confusion matrix. When the accuracy beats the best value seen so far the
 * model is written to {@code workingDir} under a file named after that accuracy. Otherwise the
 * learning rate is adjusted (step decay, with resets), and if the accuracy dropped by more than
 * {@code downgradeAccuracyThreshold} the parameters are replaced with the ones returned by
 * {@code ModelLoader.load(workingDir, bestAccuracy)}. A fresh network is then built with the
 * current learning rate and seeded with the chosen parameters for the next epoch.
 *
 * @param train training data set
 * @param test  data set used to measure accuracy after each epoch
 */
public void train(JavaRDD<DataSet> train, JavaRDD<DataSet> test) {

    int batchSize = 12 * cores;
    int lrCount = 0;
    // Must be below every reachable accuracy. Double.MIN_VALUE is the smallest POSITIVE double,
    // so with it an epoch scoring 0.0 would never be recorded (or saved) as the best model.
    double bestAccuracy = Double.NEGATIVE_INFINITY;

    double learningRate = initialLearningRate;

    int trainCount = (int) train.count();
    // count() is a Spark action; evaluate it once instead of once per epoch.
    long testCount = test.count();
    log.info("Number of training images {}", trainCount);
    log.info("Number of test images {}", testCount);

    MultiLayerNetwork net = new MultiLayerNetwork(
            model.apply(learningRate, width, height, channels, numLabels));
    net.init();

    for (int i = 0; i < epochs; i++) {

        SparkDl4jMultiLayer sparkNetwork = networkToSparkNetwork.apply(net);
        final MultiLayerNetwork nn = sparkNetwork.fitDataSet(train, batchSize, trainCount, cores);
        log.info("Epoch {} completed", i);

        // Pair each predicted label with the expected label for every test example.
        JavaPairRDD<Object, Object> predictionsAndLabels = test.mapToPair(
                ds -> new Tuple2<>(label(nn.output(ds.getFeatureMatrix(), false)), label(ds.getLabels())));
        MulticlassMetrics metrics = new MulticlassMetrics(predictionsAndLabels.rdd());
        double accuracy = 1.0 * predictionsAndLabels.filter(x -> x._1.equals(x._2)).count() / testCount;
        log.info("Epoch {} accuracy {} ", i, accuracy);
        predictionsAndLabels.take(10).forEach(t -> log.info("predicted {}, label {}", t._1, t._2));
        log.info("confusionMatrix {}", metrics.confusionMatrix());

        INDArray params = nn.params();
        if (accuracy > bestAccuracy) {
            // New best: persist the model, named after its accuracy, and restart the decay counter.
            bestAccuracy = accuracy;
            try {
                ModelSerializer.writeModel(nn, new File(workingDir, Double.toString(accuracy)), false);
            } catch (IOException e) {
                // Deliberately best-effort: training continues even if the snapshot cannot be written.
                log.error("Error writing trained model", e);
            }
            lrCount = 0;
        } else {

            // Decay the learning rate every stepDecayTreshold epochs without improvement.
            if (++lrCount % stepDecayTreshold == 0) {
                learningRate *= learningRateDecayFactor;
            }
            // Too many epochs without improvement: start over from the initial learning rate.
            if (lrCount >= resetLearningRateThreshold) {
                lrCount = 0;
                learningRate = initialLearningRate;
            }
            // Learning rate decayed below the allowed minimum: start over as well.
            if (learningRate < minimumLearningRate) {
                lrCount = 0;
                learningRate = initialLearningRate;
            }
            // Accuracy regressed too far: continue from the parameters loaded for the best
            // accuracy (presumably the snapshot written above; verify against ModelLoader).
            if (bestAccuracy - accuracy > downgradeAccuracyThreshold) {
                params = ModelLoader.load(workingDir, bestAccuracy);
            }
        }
        // Rebuild the network so the (possibly changed) learning rate takes effect, then carry
        // the selected parameters over into it.
        net = new MultiLayerNetwork(model.apply(learningRate, width, height, channels, numLabels));
        net.init();
        net.setParameters(params);
        log.info("Learning rate {} for epoch {}", learningRate, i + 1);
    }
    log.info("Training completed");

}

From source file:vectorizer.ModelSerializer.java

/**
 * Restores a {@link MultiLayerNetwork} from a zip archive.
 *
 * <p>The archive is searched for three entries: {@code configuration.json} (network
 * configuration as JSON), {@code coefficients.bin} (parameters, read with {@code Nd4j.read})
 * and {@code updater.bin} (a Java-serialized {@code Updater}). The configuration and the
 * coefficients are required; the updater is applied only when present.
 *
 * @param file zip archive to read the model from
 * @return an initialized network carrying the stored parameters and, if available, the updater
 * @throws IOException           if the archive or one of its entries cannot be read
 * @throws IllegalStateException if the configuration or the coefficients entry is missing
 * @throws RuntimeException      if the class of the serialized updater cannot be found
 */
public static MultiLayerNetwork restoreMultiLayerNetwork(File file) throws IOException {
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotUpdater = false;

    String json = "";
    INDArray params = null;
    Updater updater = null;

    // try-with-resources guarantees the archive and every entry stream are closed even when
    // reading one of the entries fails.
    try (ZipFile zipFile = new ZipFile(file)) {

        ZipEntry config = zipFile.getEntry("configuration.json");
        if (config != null) {
            // NOTE(review): the reader uses the platform default charset, as before; confirm
            // which charset the writer uses before pinning one here.
            try (BufferedReader reader = new BufferedReader(
                    new InputStreamReader(zipFile.getInputStream(config)))) {
                StringBuilder js = new StringBuilder();
                String line;
                while ((line = reader.readLine()) != null) {
                    js.append(line).append("\n");
                }
                json = js.toString();
            }
            gotConfig = true;
        }

        ZipEntry coefficients = zipFile.getEntry("coefficients.bin");
        if (coefficients != null) {
            try (DataInputStream dis = new DataInputStream(zipFile.getInputStream(coefficients))) {
                params = Nd4j.read(dis);
            }
            gotCoefficients = true;
        }

        ZipEntry updaters = zipFile.getEntry("updater.bin");
        if (updaters != null) {
            // NOTE(review): native Java deserialization; only safe for archives from a trusted
            // source.
            try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(updaters))) {
                updater = (Updater) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
            gotUpdater = true;
        }
    }

    if (gotConfig && gotCoefficients) {
        MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
        MultiLayerNetwork network = new MultiLayerNetwork(confFromJson);
        network.init();
        network.setParameters(params);

        if (gotUpdater && updater != null) {
            network.setUpdater(updater);
        }
        return network;
    } else {
        throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
                + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdater + "]");
    }
}