org.deeplearning4j.util.ModelSerializerTest.java Source code

Java tutorial

Introduction

Here is the source code for org.deeplearning4j.util.ModelSerializerTest.java

Source

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.util;

import lombok.val;
import org.apache.commons.lang3.SerializationUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.primitives.Pair;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.util.*;

import static org.junit.Assert.*;

/**
 * Round-trip tests for {@link ModelSerializer}: writing and restoring
 * {@link MultiLayerNetwork} and {@link ComputationGraph} models through files and streams,
 * attaching and restoring normalizers, rejecting mismatched model types and re-used input
 * streams, plain Java serialization of networks, and storing extra objects in a model file.
 *
 * @author raver119@gmail.com
 */
public class ModelSerializerTest extends BaseDL4JTest {

    /** JUnit temporary-folder rule supplying the scratch files used by the tests. */
    @Rule
    public TemporaryFolder tempDir = new TemporaryFolder();

    /**
     * Writes an initialized multi-layer network to a file and checks that the restored
     * network has the same configuration JSON, parameters and updater state.
     */
    @Test
    public void testWriteMLNModel() throws Exception {
        MultiLayerNetwork net = simpleMultiLayerNetwork();
        net.init();

        File tempFile = tempDir.newFile();

        ModelSerializer.writeModel(net, tempFile, true);

        MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(tempFile);

        assertEquals(net.getLayerWiseConfigurations().toJson(), network.getLayerWiseConfigurations().toJson());
        assertEquals(net.params(), network.params());
        assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
    }

    /**
     * Writes a multi-layer network through an output stream, attaches a min/max normalizer to
     * the resulting file, and checks that both the normalizer (read from the file) and the
     * network (read from an input stream) are restored unchanged.
     */
    @Test
    public void testWriteMlnModelInputStream() throws Exception {
        MultiLayerNetwork net = simpleMultiLayerNetwork();
        net.init();

        File tempFile = tempDir.newFile();

        // Close the stream before the file is modified again by addNormalizerToModel below.
        try (FileOutputStream fos = new FileOutputStream(tempFile)) {
            ModelSerializer.writeModel(net, fos, true);
        }

        // checking adding of DataNormalization to the model file

        NormalizerMinMaxScaler scaler = new NormalizerMinMaxScaler();
        DataSetIterator iter = new IrisDataSetIterator(150, 150);
        scaler.fit(iter);

        ModelSerializer.addNormalizerToModel(tempFile, scaler);

        NormalizerMinMaxScaler restoredScaler = ModelSerializer.restoreNormalizerFromFile(tempFile);

        assertNotNull(scaler.getMax());
        assertEquals(scaler.getMax(), restoredScaler.getMax());
        assertEquals(scaler.getMin(), restoredScaler.getMin());

        MultiLayerNetwork network;
        try (InputStream fis = new FileInputStream(tempFile)) {
            network = ModelSerializer.restoreMultiLayerNetwork(fis);
        }

        assertEquals(net.getLayerWiseConfigurations().toJson(), network.getLayerWiseConfigurations().toJson());
        assertEquals(net.params(), network.params());
        assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
    }

    /**
     * Writes an initialized computation graph to a file and checks that the restored graph
     * has the same configuration JSON, parameters and updater state.
     */
    @Test
    public void testWriteCGModel() throws Exception {
        ComputationGraph cg = simpleComputationGraph();
        cg.init();

        File tempFile = tempDir.newFile();

        ModelSerializer.writeModel(cg, tempFile, true);

        ComputationGraph network = ModelSerializer.restoreComputationGraph(tempFile);

        assertEquals(cg.getConfiguration().toJson(), network.getConfiguration().toJson());
        assertEquals(cg.params(), network.params());
        assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
    }

    /**
     * Same as {@link #testWriteCGModel()}, but the graph is restored from an input stream
     * instead of directly from the file.
     */
    @Test
    public void testWriteCGModelInputStream() throws Exception {
        ComputationGraph cg = simpleComputationGraph();
        cg.init();

        File tempFile = tempDir.newFile();

        ModelSerializer.writeModel(cg, tempFile, true);

        ComputationGraph network;
        try (InputStream fis = new FileInputStream(tempFile)) {
            network = ModelSerializer.restoreComputationGraph(fis);
        }

        assertEquals(cg.getConfiguration().toJson(), network.getConfiguration().toJson());
        assertEquals(cg.params(), network.params());
        assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
    }

    /**
     * Builds a minimal data set with three feature values and three label values.
     *
     * @return a new data set holding the fixed features {1, 2, 3} and labels {4, 5, 6}
     */
    private DataSet trivialDataSet() {
        INDArray inputs = Nd4j.create(new float[] { 1.0f, 2.0f, 3.0f });
        INDArray labels = Nd4j.create(new float[] { 4.0f, 5.0f, 6.0f });
        return new DataSet(inputs, labels);
    }

    /**
     * Builds a small computation graph: input "in" -> dense layer (4 -> 2) -> softmax output
     * layer (2 -> 3), trained with SGD.
     *
     * @return the graph, not yet initialized; callers must invoke {@code init()}
     */
    private ComputationGraph simpleComputationGraph() {
        ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1))
                .graphBuilder().addInputs("in")
                .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3)
                        .activation(Activation.SOFTMAX).build(), "dense")
                .setOutputs("out").build();

        return new ComputationGraph(config);
    }

    /**
     * Builds a three-layer network: dense (5 -> 20), dense (20 -> 30) and an MSE output layer
     * (30 -> 6), with a fixed seed, L1/L2 regularization, SGD, tanh activation and Xavier
     * weight initialization.
     *
     * @return the network, not yet initialized; callers must invoke {@code init()}
     */
    private MultiLayerNetwork simpleMultiLayerNetwork() {
        int nIn = 5;
        int nOut = 6;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).l2(0.01)
                .updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list()
                .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build())
                .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build())
                .layer(2, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(30)
                        .nOut(nOut).build())
                .build();

        return new MultiLayerNetwork(conf);
    }

    /**
     * Asserts that the exception carries a message containing every given fragment.
     * A missing message is reported as an assertion failure rather than surfacing as a
     * {@link NullPointerException}.
     *
     * @param e                 the exception that was caught
     * @param expectedFragments substrings that must all occur in {@code e.getMessage()}
     */
    private static void assertMessageContains(Exception e, String... expectedFragments) {
        String msg = e.getMessage();
        assertNotNull("Exception has no message: " + e, msg);
        for (String fragment : expectedFragments) {
            assertTrue(msg, msg.contains(fragment));
        }
    }

    /**
     * Adds a fitted normalizer to a saved model file, restores it from an input stream, and
     * checks that reverting with the restored normalizer undoes the original's pre-processing.
     */
    @Test
    public void testSaveRestoreNormalizerFromInputStream() throws Exception {
        DataSet dataSet = trivialDataSet();
        NormalizerStandardize norm = new NormalizerStandardize();
        norm.fit(dataSet);

        ComputationGraph cg = simpleComputationGraph();
        cg.init();

        File tempFile = tempDir.newFile();

        ModelSerializer.writeModel(cg, tempFile, true);

        ModelSerializer.addNormalizerToModel(tempFile, norm);

        NormalizerStandardize restored;
        try (InputStream fis = new FileInputStream(tempFile)) {
            restored = ModelSerializer.restoreNormalizerFromInputStream(fis);
        }

        assertNotNull(restored);

        DataSet dataSet2 = dataSet.copy();

        norm.preProcess(dataSet2);
        assertNotEquals(dataSet.getFeatures(), dataSet2.getFeatures());

        restored.revert(dataSet2);
        assertEquals(dataSet.getFeatures(), dataSet2.getFeatures());
    }

    /**
     * Restoring a normalizer from a model file that never had one added must yield
     * {@code null}.
     */
    @Test
    public void testRestoreUnsavedNormalizerFromInputStream() throws Exception {
        ComputationGraph cg = simpleComputationGraph();
        cg.init();

        // Only the model is written; no normalizer is ever added to this file.
        File tempFile = tempDir.newFile();
        ModelSerializer.writeModel(cg, tempFile, true);

        NormalizerStandardize restored;
        try (InputStream fis = new FileInputStream(tempFile)) {
            restored = ModelSerializer.restoreNormalizerFromInputStream(fis);
        }

        assertNull(restored);
    }

    /**
     * Loading a saved computation graph as a multi-layer network must fail with a message
     * that mentions "JSON" and points to {@code restoreComputationGraph}.
     */
    @Test
    public void testInvalidLoading1() throws Exception {
        ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
                .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(2).nOut(3).build(), "dense")
                .setOutputs("out").build();

        ComputationGraph cg = new ComputationGraph(config);
        cg.init();

        File tempFile = tempDir.newFile();

        ModelSerializer.writeModel(cg, tempFile, true);

        try {
            ModelSerializer.restoreMultiLayerNetwork(tempFile);
            fail();
        } catch (Exception e) {
            assertMessageContains(e, "JSON", "restoreComputationGraph");
        }
    }

    /**
     * Loading a saved multi-layer network as a computation graph must fail with a message
     * that mentions "JSON" and points to {@code restoreMultiLayerNetwork}.
     */
    @Test
    public void testInvalidLoading2() throws Exception {
        MultiLayerNetwork net = simpleMultiLayerNetwork();
        net.init();

        File tempFile = tempDir.newFile("testInvalidLoading2.bin");

        ModelSerializer.writeModel(net, tempFile, true);

        try {
            ModelSerializer.restoreComputationGraph(tempFile);
            fail();
        } catch (Exception e) {
            assertMessageContains(e, "JSON", "restoreMultiLayerNetwork");
        }
    }

    /**
     * After a multi-layer network has been read from an input stream, further reads from
     * that same stream must be rejected; reading model and normalizer together from a fresh
     * stream must succeed.
     */
    @Test
    public void testInvalidStreamReuse() throws Exception {
        int nIn = 5;
        int nOut = 6;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).list()
                .layer(new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DataSet dataSet = trivialDataSet();
        NormalizerStandardize norm = new NormalizerStandardize();
        norm.fit(dataSet);

        File tempFile = tempDir.newFile();
        ModelSerializer.writeModel(net, tempFile, true);
        ModelSerializer.addNormalizerToModel(tempFile, norm);

        // The reuse attempts stay inside the try block so the stream is still open
        // (but already consumed) when they run.
        try (InputStream is = new FileInputStream(tempFile)) {
            ModelSerializer.restoreMultiLayerNetwork(is);

            try {
                ModelSerializer.restoreNormalizerFromInputStream(is);
                fail("Expected exception");
            } catch (Exception e) {
                assertMessageContains(e, "may have been closed");
            }

            try {
                ModelSerializer.restoreMultiLayerNetwork(is);
                fail("Expected exception");
            } catch (Exception e) {
                assertMessageContains(e, "may have been closed");
            }
        }

        //Also test reading  both model and normalizer from stream (correctly)
        try (InputStream is = new FileInputStream(tempFile)) {
            Pair<MultiLayerNetwork, Normalizer> pair =
                    ModelSerializer.restoreMultiLayerNetworkAndNormalizer(is, true);
            assertEquals(net.params(), pair.getFirst().params());
            assertNotNull(pair.getSecond());
        }
    }

    /**
     * Computation-graph counterpart of {@link #testInvalidStreamReuse()}.
     */
    @Test
    public void testInvalidStreamReuseCG() throws Exception {
        int nIn = 5;
        int nOut = 6;

        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01)
                .graphBuilder().addInputs("in").layer("0",
                        new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build(), "in")
                .setOutputs("0").build();

        ComputationGraph net = new ComputationGraph(conf);
        net.init();

        DataSet dataSet = trivialDataSet();
        NormalizerStandardize norm = new NormalizerStandardize();
        norm.fit(dataSet);

        File tempFile = tempDir.newFile();
        ModelSerializer.writeModel(net, tempFile, true);
        ModelSerializer.addNormalizerToModel(tempFile, norm);

        // The reuse attempts stay inside the try block so the stream is still open
        // (but already consumed) when they run.
        try (InputStream is = new FileInputStream(tempFile)) {
            ModelSerializer.restoreComputationGraph(is);

            try {
                ModelSerializer.restoreNormalizerFromInputStream(is);
                fail("Expected exception");
            } catch (Exception e) {
                assertMessageContains(e, "may have been closed");
            }

            try {
                ModelSerializer.restoreComputationGraph(is);
                fail("Expected exception");
            } catch (Exception e) {
                assertMessageContains(e, "may have been closed");
            }
        }

        //Also test reading  both model and normalizer from stream (correctly)
        try (InputStream is = new FileInputStream(tempFile)) {
            Pair<ComputationGraph, Normalizer> pair =
                    ModelSerializer.restoreComputationGraphAndNormalizer(is, true);
            assertEquals(net.params(), pair.getFirst().params());
            assertNotNull(pair.getSecond());
        }
    }

    /**
     * A computation graph must survive a plain Java serialization round trip and compare
     * equal to the original.
     */
    @Test
    public void testJavaSerde_1() throws Exception {
        int nIn = 5;
        int nOut = 6;

        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01)
                .graphBuilder().addInputs("in")
                .layer("0", new OutputLayer.Builder().nIn(nIn).nOut(nOut).build(), "in").setOutputs("0")
                .validateOutputLayerConfig(false).build();

        ComputationGraph net = new ComputationGraph(conf);
        net.init();

        val b = SerializationUtils.serialize(net);

        ComputationGraph restored = SerializationUtils.deserialize(b);

        assertEquals(net, restored);
    }

    /**
     * A multi-layer network must survive a plain Java serialization round trip and compare
     * equal to the original.
     */
    @Test
    public void testJavaSerde_2() throws Exception {
        int nIn = 5;
        int nOut = 6;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).list()
                .layer(0, new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        val b = SerializationUtils.serialize(net);

        MultiLayerNetwork restored = SerializationUtils.deserialize(b);

        assertEquals(net, restored);
    }

    /**
     * Stores two extra objects (a list of labels and a map holding a float array and an
     * INDArray) in a model file, reads them back, lists the stored entries, and checks that
     * the model itself can still be restored afterwards.
     */
    @Test
    public void testPutGetObject() throws Exception {

        int nIn = 5;
        int nOut = 6;

        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01)
                .graphBuilder().addInputs("in").layer("0",
                        new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build(), "in")
                .setOutputs("0").build();

        ComputationGraph net = new ComputationGraph(conf);
        net.init();

        File tempFile = tempDir.newFile();
        ModelSerializer.writeModel(net, tempFile, true);

        List<String> toWrite = Arrays.asList("zero", "one", "two");
        ModelSerializer.addObjectToFile(tempFile, "myLabels", toWrite);
        List<String> restored = ModelSerializer.getObjectFromFile(tempFile, "myLabels");
        assertEquals(toWrite, restored);

        Map<String, Object> someOtherData = new HashMap<>();
        someOtherData.put("x", new float[] { 0, 1, 2 });
        someOtherData.put("y", Nd4j.linspace(1, 10, 10));

        ModelSerializer.addObjectToFile(tempFile, "otherData.bin", someOtherData);

        Map<String, Object> dataRestored = ModelSerializer.getObjectFromFile(tempFile, "otherData.bin");
        assertEquals(someOtherData.keySet(), dataRestored.keySet());
        // Arrays do not override equals(), so the float[] entry is compared element-wise.
        assertArrayEquals((float[]) someOtherData.get("x"), (float[]) dataRestored.get("x"), 0f);
        assertEquals(someOtherData.get("y"), dataRestored.get("y"));

        List<String> entries = ModelSerializer.listObjectsInFile(tempFile);
        assertEquals(2, entries.size());
        System.out.println(entries);
        assertTrue(entries.contains("myLabels"));
        assertTrue(entries.contains("otherData.bin"));

        // Adding extra objects must not disturb the stored model.
        ComputationGraph restoredNet = ModelSerializer.restoreComputationGraph(tempFile);
        assertEquals(net.params(), restoredNet.params());
    }
}