org.deeplearning4j.util.ModelSerializer.java Source code

Java tutorial

Introduction

Here is the source code for org.deeplearning4j.util.ModelSerializer.java

Source

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.util;

import org.nd4j.shade.guava.io.Files;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.deeplearning4j.config.DL4JSystemProperties;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerSerializer;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.primitives.Pair;

import java.io.*;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;

/**
 * Utility class suited to save/restore neural net models
 *
 * @author raver119@gmail.com
 */
@Slf4j
public class ModelSerializer {

    public static final String UPDATER_BIN = "updaterState.bin";
    public static final String NORMALIZER_BIN = "normalizer.bin";
    public static final String CONFIGURATION_JSON = "configuration.json";
    public static final String COEFFICIENTS_BIN = "coefficients.bin";
    public static final String NO_PARAMS_MARKER = "noParams.marker";
    public static final String PREPROCESSOR_BIN = "preprocessor.bin";

    private ModelSerializer() {
    }

    /**
     * Write a model to a file
     * @param model the model to write
     * @param file the file to write to
     * @param saveUpdater whether to save the updater or not
     * @throws IOException
     */
    public static void writeModel(@NonNull Model model, @NonNull File file, boolean saveUpdater)
            throws IOException {
        writeModel(model, file, saveUpdater, null);
    }

    /**
     * Write a model to a file
     * @param model the model to write
     * @param file the file to write to
     * @param saveUpdater whether to save the updater or not
     * @param dataNormalization the normalizer to save (optional)
     * @throws IOException
     */
    public static void writeModel(@NonNull Model model, @NonNull File file, boolean saveUpdater,
            DataNormalization dataNormalization) throws IOException {
        try (BufferedOutputStream stream = new BufferedOutputStream(new FileOutputStream(file))) {
            writeModel(model, stream, saveUpdater, dataNormalization);
        }
    }

    /**
     * Write a model to a file path
     * @param model the model to write
     * @param path the path to write to
     * @param saveUpdater whether to save the updater
     *                    or not
     * @throws IOException
     */
    public static void writeModel(@NonNull Model model, @NonNull String path, boolean saveUpdater)
            throws IOException {
        try (BufferedOutputStream stream = new BufferedOutputStream(new FileOutputStream(path))) {
            writeModel(model, stream, saveUpdater);
        }
    }

    /**
     * Write a model to an output stream
     * @param model the model to save
     * @param stream the output stream to write to
     * @param saveUpdater whether to save the updater for the model or not
     * @throws IOException
     */
    public static void writeModel(@NonNull Model model, @NonNull OutputStream stream, boolean saveUpdater)
            throws IOException {
        writeModel(model, stream, saveUpdater, null);
    }

    /**
     * Write a model to an output stream
     * @param model the model to save
     * @param stream the output stream to write to
     * @param saveUpdater whether to save the updater for the model or not
     * @param dataNormalization the normalizer ot save (may be null)
     * @throws IOException
     */
    public static void writeModel(@NonNull Model model, @NonNull OutputStream stream, boolean saveUpdater,
            DataNormalization dataNormalization) throws IOException {
        ZipOutputStream zipfile = new ZipOutputStream(new CloseShieldOutputStream(stream));

        // Save configuration as JSON
        String json = "";
        if (model instanceof MultiLayerNetwork) {
            json = ((MultiLayerNetwork) model).getLayerWiseConfigurations().toJson();
        } else if (model instanceof ComputationGraph) {
            json = ((ComputationGraph) model).getConfiguration().toJson();
        }

        ZipEntry config = new ZipEntry(CONFIGURATION_JSON);
        zipfile.putNextEntry(config);
        zipfile.write(json.getBytes());

        // Save parameters as binary
        ZipEntry coefficients = new ZipEntry(COEFFICIENTS_BIN);
        zipfile.putNextEntry(coefficients);
        DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(zipfile));
        INDArray params = model.params();
        if (params != null) {
            try {
                Nd4j.write(model.params(), dos);
            } finally {
                dos.flush();
            }
        } else {
            ZipEntry noParamsMarker = new ZipEntry(NO_PARAMS_MARKER);
            zipfile.putNextEntry(noParamsMarker);
        }

        if (saveUpdater) {
            INDArray updaterState = null;
            if (model instanceof MultiLayerNetwork) {
                updaterState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
            } else if (model instanceof ComputationGraph) {
                updaterState = ((ComputationGraph) model).getUpdater().getStateViewArray();
            }

            if (updaterState != null && updaterState.length() > 0) {
                ZipEntry updater = new ZipEntry(UPDATER_BIN);
                zipfile.putNextEntry(updater);

                try {
                    Nd4j.write(updaterState, dos);
                } finally {
                    dos.flush();
                }
            }
        }

        if (dataNormalization != null) {
            // now, add our normalizer as additional entry
            ZipEntry nEntry = new ZipEntry(NORMALIZER_BIN);
            zipfile.putNextEntry(nEntry);
            NormalizerSerializer.getDefault().write(dataNormalization, zipfile);
        }

        dos.close();
        zipfile.close();
    }

    /**
     * Load a multi layer network from a file
     *
     * @param file the file to load from
     * @return the loaded multi layer network
     * @throws IOException
     */
    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file) throws IOException {
        return restoreMultiLayerNetwork(file, true);
    }

    /**
     * Load a multi layer network from a file
     *
     * @param file the file to load from
     * @return the loaded multi layer network
     * @throws IOException
     */
    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file, boolean loadUpdater)
            throws IOException {
        ZipFile zipFile = new ZipFile(file);

        boolean gotConfig = false;
        boolean gotCoefficients = false;
        boolean gotUpdaterState = false;
        boolean gotPreProcessor = false;

        String json = "";
        INDArray params = null;
        Updater updater = null;
        INDArray updaterState = null;
        DataSetPreProcessor preProcessor = null;

        ZipEntry config = zipFile.getEntry(CONFIGURATION_JSON);
        if (config != null) {
            //restoring configuration

            InputStream stream = zipFile.getInputStream(config);
            BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
            String line = "";
            StringBuilder js = new StringBuilder();
            while ((line = reader.readLine()) != null) {
                js.append(line).append("\n");
            }
            json = js.toString();

            reader.close();
            stream.close();
            gotConfig = true;
        }

        ZipEntry coefficients = zipFile.getEntry(COEFFICIENTS_BIN);
        if (coefficients != null) {
            if (coefficients.getSize() > 0) {
                InputStream stream = zipFile.getInputStream(coefficients);
                DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
                params = Nd4j.read(dis);

                dis.close();
                gotCoefficients = true;
            } else {
                ZipEntry noParamsMarker = zipFile.getEntry(NO_PARAMS_MARKER);
                gotCoefficients = (noParamsMarker != null);
            }
        }

        if (loadUpdater) {
            ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
            if (updaterStateEntry != null) {
                InputStream stream = zipFile.getInputStream(updaterStateEntry);
                DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
                updaterState = Nd4j.read(dis);

                dis.close();
                gotUpdaterState = true;
            }
        }

        ZipEntry prep = zipFile.getEntry(PREPROCESSOR_BIN);
        if (prep != null) {
            InputStream stream = zipFile.getInputStream(prep);
            ObjectInputStream ois = new ObjectInputStream(stream);

            try {
                preProcessor = (DataSetPreProcessor) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }

            gotPreProcessor = true;
        }

        zipFile.close();

        if (gotConfig && gotCoefficients) {
            MultiLayerConfiguration confFromJson;
            try {
                confFromJson = MultiLayerConfiguration.fromJson(json);
            } catch (Exception e) {
                ComputationGraphConfiguration cg;
                try {
                    cg = ComputationGraphConfiguration.fromJson(json);
                } catch (Exception e2) {
                    //Invalid, and not a compgraph
                    throw new RuntimeException(
                            "Error deserializing JSON MultiLayerConfiguration. Saved model JSON is"
                                    + " not a valid MultiLayerConfiguration",
                            e);
                }
                if (cg.getNetworkInputs() != null && cg.getVertices() != null) {
                    throw new RuntimeException(
                            "Error deserializing JSON MultiLayerConfiguration. Saved model appears to be "
                                    + "a ComputationGraph - use ModelSerializer.restoreComputationGraph instead");
                } else {
                    throw e;
                }
            }

            //Handle legacy config - no network DataType in config, in beta3 or earlier
            if (params != null)
                confFromJson.setDataType(params.dataType());
            MultiLayerNetwork network = new MultiLayerNetwork(confFromJson);
            network.init(params, false);

            if (gotUpdaterState && updaterState != null) {
                network.getUpdater().setStateViewArray(network, updaterState, false);
            }
            return network;
        } else
            throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
                    + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
    }

    /**
     * Load a MultiLayerNetwork from InputStream from an input stream<br>
     * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
     *
     * @param is the inputstream to load from
     * @return the loaded multi layer network
     * @throws IOException
     * @see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean)
     */
    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater)
            throws IOException {
        checkInputStream(is);

        File tmpFile = null;
        try {
            tmpFile = tempFileFromStream(is);
            return restoreMultiLayerNetwork(tmpFile, loadUpdater);
        } finally {
            if (tmpFile != null) {
                tmpFile.delete();
            }
        }
    }

    /**
     * Restore a multi layer network from an input stream<br>
     * * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
     *
     *
     * @param is the input stream to restore from
     * @return the loaded multi layer network
     * @throws IOException
     * @see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean)
     */
    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is) throws IOException {
        return restoreMultiLayerNetwork(is, true);
    }

    /**
     * Load a MultilayerNetwork model from a file
     *
     * @param path path to the model file, to get the computation graph from
     * @return the loaded computation graph
     *
     * @throws IOException
     */
    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull String path) throws IOException {
        return restoreMultiLayerNetwork(new File(path), true);
    }

    /**
     * Load a MultilayerNetwork model from a file
     * @param path path to the model file, to get the computation graph from
     * @return the loaded computation graph
     *
     * @throws IOException
     */
    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull String path, boolean loadUpdater)
            throws IOException {
        return restoreMultiLayerNetwork(new File(path), loadUpdater);
    }

    /**
     * Restore a MultiLayerNetwork and Normalizer (if present - null if not) from the InputStream.
     * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
     *
     * @param is          Input stream to read from
     * @param loadUpdater Whether to load the updater from the model or not
     * @return Model and normalizer, if present
     * @throws IOException If an error occurs when reading from the stream
     */
    public static Pair<MultiLayerNetwork, Normalizer> restoreMultiLayerNetworkAndNormalizer(@NonNull InputStream is,
            boolean loadUpdater) throws IOException {
        checkInputStream(is);

        File tmpFile = null;
        try {
            tmpFile = tempFileFromStream(is);
            return restoreMultiLayerNetworkAndNormalizer(tmpFile, loadUpdater);
        } finally {
            if (tmpFile != null) {
                tmpFile.delete();
            }
        }
    }

    /**
     * Restore a MultiLayerNetwork and Normalizer (if present - null if not) from a File
     *
     * @param file        File to read the model and normalizer from
     * @param loadUpdater Whether to load the updater from the model or not
     * @return Model and normalizer, if present
     * @throws IOException If an error occurs when reading from the File
     */
    public static Pair<MultiLayerNetwork, Normalizer> restoreMultiLayerNetworkAndNormalizer(@NonNull File file,
            boolean loadUpdater) throws IOException {
        MultiLayerNetwork net = restoreMultiLayerNetwork(file, loadUpdater);
        Normalizer norm = restoreNormalizerFromFile(file);
        return new Pair<>(net, norm);
    }

    /**
     * Load a computation graph from a file
     * @param path path to the model file, to get the computation graph from
     * @return the loaded computation graph
     *
     * @throws IOException
     */
    public static ComputationGraph restoreComputationGraph(@NonNull String path) throws IOException {
        return restoreComputationGraph(new File(path), true);
    }

    /**
     * Load a computation graph from a file
     * @param path path to the model file, to get the computation graph from
     * @return the loaded computation graph
     *
     * @throws IOException
     */
    public static ComputationGraph restoreComputationGraph(@NonNull String path, boolean loadUpdater)
            throws IOException {
        return restoreComputationGraph(new File(path), loadUpdater);
    }

    /**
     * Load a computation graph from a InputStream
     * @param is the inputstream to get the computation graph from
     * @return the loaded computation graph
     *
     * @throws IOException
     */
    public static ComputationGraph restoreComputationGraph(@NonNull InputStream is, boolean loadUpdater)
            throws IOException {
        checkInputStream(is);

        File tmpFile = null;
        try {
            tmpFile = tempFileFromStream(is);
            return restoreComputationGraph(tmpFile, loadUpdater);
        } finally {
            if (tmpFile != null) {
                tmpFile.delete();
            }
        }
    }

    /**
     * Load a computation graph from a InputStream
     * @param is the inputstream to get the computation graph from
     * @return the loaded computation graph
     *
     * @throws IOException
     */
    public static ComputationGraph restoreComputationGraph(@NonNull InputStream is) throws IOException {
        return restoreComputationGraph(is, true);
    }

    /**
     * Load a computation graph from a file
     * @param file the file to get the computation graph from
     * @return the loaded computation graph
     *
     * @throws IOException
     */
    public static ComputationGraph restoreComputationGraph(@NonNull File file) throws IOException {
        return restoreComputationGraph(file, true);
    }

    /**
     * Restore a ComputationGraph and Normalizer (if present - null if not) from the InputStream.
     * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
     *
     * @param is          Input stream to read from
     * @param loadUpdater Whether to load the updater from the model or not
     * @return Model and normalizer, if present
     * @throws IOException If an error occurs when reading from the stream
     */
    public static Pair<ComputationGraph, Normalizer> restoreComputationGraphAndNormalizer(@NonNull InputStream is,
            boolean loadUpdater) throws IOException {
        checkInputStream(is);

        File tmpFile = null;
        try {
            tmpFile = tempFileFromStream(is);
            return restoreComputationGraphAndNormalizer(tmpFile, loadUpdater);
        } finally {
            if (tmpFile != null) {
                tmpFile.delete();
            }
        }
    }

    /**
     * Restore a ComputationGraph and Normalizer (if present - null if not) from a File
     *
     * @param file        File to read the model and normalizer from
     * @param loadUpdater Whether to load the updater from the model or not
     * @return Model and normalizer, if present
     * @throws IOException If an error occurs when reading from the File
     */
    public static Pair<ComputationGraph, Normalizer> restoreComputationGraphAndNormalizer(@NonNull File file,
            boolean loadUpdater) throws IOException {
        ComputationGraph net = restoreComputationGraph(file, loadUpdater);
        Normalizer norm = restoreNormalizerFromFile(file);
        return new Pair<>(net, norm);
    }

    /**
     * Load a computation graph from a file
     * @param file the file to get the computation graph from
     * @return the loaded computation graph
     *
     * @throws IOException
     */
    public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater)
            throws IOException {
        ZipFile zipFile = new ZipFile(file);

        boolean gotConfig = false;
        boolean gotCoefficients = false;
        boolean gotUpdaterState = false;
        boolean gotPreProcessor = false;

        String json = "";
        INDArray params = null;
        INDArray updaterState = null;
        DataSetPreProcessor preProcessor = null;

        ZipEntry config = zipFile.getEntry(CONFIGURATION_JSON);
        if (config != null) {
            //restoring configuration

            InputStream stream = zipFile.getInputStream(config);
            BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
            String line = "";
            StringBuilder js = new StringBuilder();
            while ((line = reader.readLine()) != null) {
                js.append(line).append("\n");
            }
            json = js.toString();

            reader.close();
            stream.close();
            gotConfig = true;
        }

        ZipEntry coefficients = zipFile.getEntry(COEFFICIENTS_BIN);
        if (coefficients != null) {
            if (coefficients.getSize() > 0) {
                InputStream stream = zipFile.getInputStream(coefficients);
                DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
                params = Nd4j.read(dis);

                dis.close();
                gotCoefficients = true;
            } else {
                ZipEntry noParamsMarker = zipFile.getEntry(NO_PARAMS_MARKER);
                gotCoefficients = (noParamsMarker != null);
            }
        }

        if (loadUpdater) {
            ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
            if (updaterStateEntry != null) {
                InputStream stream = zipFile.getInputStream(updaterStateEntry);
                DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
                updaterState = Nd4j.read(dis);

                dis.close();
                gotUpdaterState = true;
            }
        }

        ZipEntry prep = zipFile.getEntry(PREPROCESSOR_BIN);
        if (prep != null) {
            InputStream stream = zipFile.getInputStream(prep);
            ObjectInputStream ois = new ObjectInputStream(stream);

            try {
                preProcessor = (DataSetPreProcessor) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }

            gotPreProcessor = true;
        }

        zipFile.close();

        if (gotConfig && gotCoefficients) {
            ComputationGraphConfiguration confFromJson;
            try {
                confFromJson = ComputationGraphConfiguration.fromJson(json);
                if (confFromJson.getNetworkInputs() == null
                        && (confFromJson.getVertices() == null || confFromJson.getVertices().size() == 0)) {
                    //May be deserialized correctly, but mostly with null fields
                    throw new RuntimeException("Invalid JSON - not a ComputationGraphConfiguration");
                }
            } catch (Exception e) {
                if (e.getMessage() != null && e.getMessage().contains("registerLegacyCustomClassesForJSON")) {
                    throw e;
                }
                try {
                    MultiLayerConfiguration.fromJson(json);
                } catch (Exception e2) {
                    //Invalid, and not a compgraph
                    throw new RuntimeException(
                            "Error deserializing JSON ComputationGraphConfiguration. Saved model JSON is"
                                    + " not a valid ComputationGraphConfiguration",
                            e);
                }
                throw new RuntimeException(
                        "Error deserializing JSON ComputationGraphConfiguration. Saved model appears to be "
                                + "a MultiLayerNetwork - use ModelSerializer.restoreMultiLayerNetwork instead");
            }

            //Handle legacy config - no network DataType in config, in beta3 or earlier
            if (params != null)
                confFromJson.setDataType(params.dataType());

            ComputationGraph cg = new ComputationGraph(confFromJson);
            cg.init(params, false);

            if (gotUpdaterState && updaterState != null) {
                cg.getUpdater().setStateViewArray(updaterState);
            }
            return cg;
        } else
            throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
                    + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
    }

    /**
     *
     * @param model
     * @return
     */
    public static Task taskByModel(Model model) {
        Task task = new Task();
        try {
            task.setArchitectureType(Task.ArchitectureType.RECURRENT);
            if (model instanceof ComputationGraph) {
                task.setNetworkType(Task.NetworkType.ComputationalGraph);
                ComputationGraph network = (ComputationGraph) model;
                try {
                    if (network.getLayers() != null && network.getLayers().length > 0) {
                        for (Layer layer : network.getLayers()) {
                            if (layer.type().equals(Layer.Type.CONVOLUTIONAL)) {
                                task.setArchitectureType(Task.ArchitectureType.CONVOLUTION);
                                break;
                            } else if (layer.type().equals(Layer.Type.RECURRENT)
                                    || layer.type().equals(Layer.Type.RECURSIVE)) {
                                task.setArchitectureType(Task.ArchitectureType.RECURRENT);
                                break;
                            }
                        }
                    } else
                        task.setArchitectureType(Task.ArchitectureType.UNKNOWN);
                } catch (Exception e) {
                    // do nothing here
                }
            } else if (model instanceof MultiLayerNetwork) {
                task.setNetworkType(Task.NetworkType.MultilayerNetwork);
                MultiLayerNetwork network = (MultiLayerNetwork) model;
                try {
                    if (network.getLayers() != null && network.getLayers().length > 0) {
                        for (Layer layer : network.getLayers()) {
                            if (layer.type().equals(Layer.Type.CONVOLUTIONAL)) {
                                task.setArchitectureType(Task.ArchitectureType.CONVOLUTION);
                                break;
                            } else if (layer.type().equals(Layer.Type.RECURRENT)
                                    || layer.type().equals(Layer.Type.RECURSIVE)) {
                                task.setArchitectureType(Task.ArchitectureType.RECURRENT);
                                break;
                            }
                        }
                    } else
                        task.setArchitectureType(Task.ArchitectureType.UNKNOWN);
                } catch (Exception e) {
                    // do nothing here
                }
            }
            return task;
        } catch (Exception e) {
            task.setArchitectureType(Task.ArchitectureType.UNKNOWN);
            task.setNetworkType(Task.NetworkType.DenseNetwork);
            return task;
        }
    }

    /**
     * This method appends normalizer to a given persisted model.
     *
     * PLEASE NOTE: File should be model file saved earlier with ModelSerializer
     *
     * @param f
     * @param normalizer
     */
    public static void addNormalizerToModel(File f, Normalizer<?> normalizer) {
        File tempFile = null;
        try {
            // copy existing model to temporary file
            tempFile = DL4JFileUtils.createTempFile("dl4jModelSerializerTemp", "bin");
            tempFile.deleteOnExit();
            Files.copy(f, tempFile);
            try (ZipFile zipFile = new ZipFile(tempFile);
                    ZipOutputStream writeFile = new ZipOutputStream(
                            new BufferedOutputStream(new FileOutputStream(f)))) {
                // roll over existing files within model, and copy them one by one
                Enumeration<? extends ZipEntry> entries = zipFile.entries();
                while (entries.hasMoreElements()) {
                    ZipEntry entry = entries.nextElement();

                    // we're NOT copying existing normalizer, if any
                    if (entry.getName().equalsIgnoreCase(NORMALIZER_BIN))
                        continue;

                    log.debug("Copying: {}", entry.getName());

                    InputStream is = zipFile.getInputStream(entry);

                    ZipEntry wEntry = new ZipEntry(entry.getName());
                    writeFile.putNextEntry(wEntry);

                    IOUtils.copy(is, writeFile);
                }
                // now, add our normalizer as additional entry
                ZipEntry nEntry = new ZipEntry(NORMALIZER_BIN);
                writeFile.putNextEntry(nEntry);

                NormalizerSerializer.getDefault().write(normalizer, writeFile);
            }
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        } finally {
            if (tempFile != null) {
                tempFile.delete();
            }
        }
    }

    /**
     * Add an object to the (already existing) model file using Java Object Serialization. Objects can be restored
     * using {@link #getObjectFromFile(File, String)}
     * @param f   File to add the object to
     * @param key Key to store the object under
     * @param o   Object to store using Java object serialization
     */
    public static void addObjectToFile(@NonNull File f, @NonNull String key, @NonNull Object o) {
        Preconditions.checkState(f.exists(), "File must exist: %s", f);
        Preconditions.checkArgument(
                !(UPDATER_BIN.equalsIgnoreCase(key) || NORMALIZER_BIN.equalsIgnoreCase(key)
                        || CONFIGURATION_JSON.equalsIgnoreCase(key) || COEFFICIENTS_BIN.equalsIgnoreCase(key)
                        || NO_PARAMS_MARKER.equalsIgnoreCase(key) || PREPROCESSOR_BIN.equalsIgnoreCase(key)),
                "Invalid key: Key is reserved for internal use: \"%s\"", key);
        File tempFile = null;
        try {
            // copy existing model to temporary file
            tempFile = DL4JFileUtils.createTempFile("dl4jModelSerializerTemp", "bin");
            Files.copy(f, tempFile);
            f.delete();
            try (ZipFile zipFile = new ZipFile(tempFile);
                    ZipOutputStream writeFile = new ZipOutputStream(
                            new BufferedOutputStream(new FileOutputStream(f)))) {
                // roll over existing files within model, and copy them one by one
                Enumeration<? extends ZipEntry> entries = zipFile.entries();
                while (entries.hasMoreElements()) {
                    ZipEntry entry = entries.nextElement();

                    log.debug("Copying: {}", entry.getName());

                    InputStream is = zipFile.getInputStream(entry);

                    ZipEntry wEntry = new ZipEntry(entry.getName());
                    writeFile.putNextEntry(wEntry);

                    IOUtils.copy(is, writeFile);
                    writeFile.closeEntry();
                    is.close();
                }

                //Add new object:
                ZipEntry entry = new ZipEntry("objects/" + key);
                writeFile.putNextEntry(entry);

                try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
                        ObjectOutputStream oos = new ObjectOutputStream(baos)) {
                    oos.writeObject(o);
                    byte[] bytes = baos.toByteArray();
                    writeFile.write(bytes);
                }
                writeFile.closeEntry();

                writeFile.close();
                zipFile.close();

            }
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        } finally {
            if (tempFile != null) {
                tempFile.delete();
            }
        }
    }

    /**
     * Get an object with the specified key from the model file, that was previously added to the file using
     * {@link #addObjectToFile(File, String, Object)}
     *
     * @param f   model file to add the object to
     * @param key Key for the object
     * @param <T> Type of the object
     * @return The serialized object
     * @see #listObjectsInFile(File)
     */
    public static <T> T getObjectFromFile(@NonNull File f, @NonNull String key) {
        Preconditions.checkState(f.exists(), "File must exist: %s", f);
        Preconditions.checkArgument(
                !(UPDATER_BIN.equalsIgnoreCase(key) || NORMALIZER_BIN.equalsIgnoreCase(key)
                        || CONFIGURATION_JSON.equalsIgnoreCase(key) || COEFFICIENTS_BIN.equalsIgnoreCase(key)
                        || NO_PARAMS_MARKER.equalsIgnoreCase(key) || PREPROCESSOR_BIN.equalsIgnoreCase(key)),
                "Invalid key: Key is reserved for internal use: \"%s\"", key);

        try (ZipFile zipFile = new ZipFile(f)) {
            ZipEntry entry = zipFile.getEntry("objects/" + key);
            if (entry == null) {
                throw new IllegalStateException("No object with key \"" + key + "\" found");
            }

            Object o;
            try (ObjectInputStream ois = new ObjectInputStream(
                    new BufferedInputStream(zipFile.getInputStream(entry)))) {
                o = ois.readObject();
            }
            zipFile.close();
            return (T) o;
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException("Error reading object (key = " + key + ") from file " + f, e);
        }
    }

    /**
     * List the keys of all objects added using the method {@link #addObjectToFile(File, String, Object)}
     * @param f File previously created with ModelSerializer
     * @return List of keys that can be used with {@link #getObjectFromFile(File, String)}
     */
    public static List<String> listObjectsInFile(@NonNull File f) {
        Preconditions.checkState(f.exists(), "File must exist: %s", f);

        List<String> out = new ArrayList<>();
        try (ZipFile zipFile = new ZipFile(f)) {

            Enumeration<? extends ZipEntry> entries = zipFile.entries();
            while (entries.hasMoreElements()) {
                ZipEntry e = entries.nextElement();
                String name = e.getName();
                if (!e.isDirectory() && name.startsWith("objects/")) {
                    String s = name.substring(8);
                    out.add(s);
                }
            }
            return out;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * This method restores normalizer from a given persisted model file
     *
     * PLEASE NOTE: File should be model file saved earlier with ModelSerializer with addNormalizerToModel being called
     *
     * @param file
     * @return
     */
    public static <T extends Normalizer> T restoreNormalizerFromFile(File file) {
        try (ZipFile zipFile = new ZipFile(file)) {
            ZipEntry norm = zipFile.getEntry(NORMALIZER_BIN);

            // checking for file existence
            if (norm == null)
                return null;

            return NormalizerSerializer.getDefault().restore(zipFile.getInputStream(norm));
        } catch (Exception e) {
            log.warn("Error while restoring normalizer, trying to restore assuming deprecated format...");
            DataNormalization restoredDeprecated = restoreNormalizerFromFileDeprecated(file);

            log.warn("Recovered using deprecated method. Will now re-save the normalizer to fix this issue.");
            addNormalizerToModel(file, restoredDeprecated);

            return (T) restoredDeprecated;
        }
    }

    /**
     * This method restores the normalizer form a persisted model file.
     *
     * @param is A stream to load data from.
     * @return the loaded normalizer
     */
    public static <T extends Normalizer> T restoreNormalizerFromInputStream(InputStream is) throws IOException {
        checkInputStream(is);

        File tmpFile = null;
        try {
            tmpFile = tempFileFromStream(is);
            return restoreNormalizerFromFile(tmpFile);
        } finally {
            if (tmpFile != null) {
                tmpFile.delete();
            }
        }
    }

    /**
     * @deprecated
     *
     * This method restores normalizer from a given persisted model file serialized with Java object serialization
     *
     * @param file
     * @return
     */
    private static DataNormalization restoreNormalizerFromFileDeprecated(File file) {
        try (ZipFile zipFile = new ZipFile(file)) {
            ZipEntry norm = zipFile.getEntry(NORMALIZER_BIN);

            // checking for file existence
            if (norm == null)
                return null;

            InputStream stream = zipFile.getInputStream(norm);
            ObjectInputStream ois = new ObjectInputStream(stream);

            try {
                DataNormalization normalizer = (DataNormalization) ois.readObject();
                return normalizer;
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private static void checkInputStream(InputStream inputStream) throws IOException {

        /*
        //available method can return 0 in some cases: https://github.com/deeplearning4j/deeplearning4j/issues/4887
        int available;
        try{
        //InputStream.available(): A subclass' implementation of this method may choose to throw an IOException
        // if this input stream has been closed by invoking the close() method.
        available = inputStream.available();
        } catch (IOException e){
        throw new IOException("Cannot read from stream: stream may have been closed or is attempting to be read from" +
                "multiple times?", e);
        }
        if(available <= 0){
        throw new IOException("Cannot read from stream: stream may have been closed or is attempting to be read from" +
                "multiple times?");
        }
        */
    }

    private static void checkTempFileFromInputStream(File f) throws IOException {
        if (f.length() <= 0) {
            throw new IOException(
                    "Error reading from input stream: temporary file is empty after copying entire stream."
                            + " Stream may have been closed before reading, is attempting to be used multiple times, or does not"
                            + " point to a model file?");
        }
    }

    private static File tempFileFromStream(InputStream is) throws IOException {
        checkInputStream(is);
        String p = System.getProperty(DL4JSystemProperties.DL4J_TEMP_DIR_PROPERTY);
        File tmpFile = DL4JFileUtils.createTempFile("dl4jModelSerializer", "bin");
        try {
            tmpFile.deleteOnExit();
            BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(new FileOutputStream(tmpFile));
            IOUtils.copy(is, bufferedOutputStream);
            bufferedOutputStream.flush();
            IOUtils.closeQuietly(bufferedOutputStream);
            checkTempFileFromInputStream(tmpFile);
            return tmpFile;
        } catch (IOException e) {
            if (tmpFile != null) {
                tmpFile.delete();
            }
            throw e;
        }
    }
}