List of usage examples for org.deeplearning4j.nn.conf NeuralNetConfiguration fromJson
public static NeuralNetConfiguration fromJson(String json)
From source file: org.knime.ext.dl4j.base.util.DLModelPortObjectUtils.java
License: Open Source License
/** * Loads a {@link DLModelPortObject} from the specified {@link ZipInputStream}. Supports both deserialization of old * and new format./*w ww .ja v a2 s. co m*/ * * @param inStream the stream to load from * @return the loaded {@link DLModelPortObject} * @throws IOException */ @SuppressWarnings("resource") public static DLModelPortObject loadPortFromZip(final ZipInputStream inStream) throws IOException { final List<Layer> layers = new ArrayList<>(); //old model format INDArray mln_params = null; MultiLayerConfiguration mln_config = null; org.deeplearning4j.nn.api.Updater updater = null; //new model format boolean mlnLoaded = false; boolean cgLoaded = false; MultiLayerNetwork mlnFromModelSerializer = null; ComputationGraph cgFromModelSerializer = null; ZipEntry entry; while ((entry = inStream.getNextEntry()) != null) { // read layers if (entry.getName().matches("layer[0123456789]+")) { final String read = readStringFromZipStream(inStream); Layer l = NeuralNetConfiguration.fromJson(read).getLayer(); if (l instanceof BaseLayer) { BaseLayer bl = (BaseLayer) l; /* Compatibility issue between dl4j 0.6 and 0.8 due to API change. Activations changed from * Strings to an interface. Therefore, if a model was saved with 0.6 the corresponding member * of the layer object will contain null after 'NeuralNetConfiguration.fromJson'. Old method to * retrieve String representation of the activation function was removed. Therefore, we parse * the old activation from the json ourself and map it to the new Activation. 
*/ if (bl.getActivationFn() == null) { Optional<Activation> layerActivation = DL4JVersionUtils.parseLayerActivationFromJson(read); if (layerActivation.isPresent()) { bl.setActivationFn(layerActivation.get().getActivationFunction()); } } } layers.add(l); // directly read MultiLayerNetwork, new format } else if (entry.getName().matches("mln_model")) { //stream must not be closed, ModelSerializer tries to close the stream CloseShieldInputStream shieldIs = new CloseShieldInputStream(inStream); mlnFromModelSerializer = ModelSerializer.restoreMultiLayerNetwork(shieldIs, true); mlnLoaded = true; // directly read MultiLayerNetwork, new format } else if (entry.getName().matches("cg_model")) { //stream must not be closed, ModelSerializer tries to close the stream CloseShieldInputStream shieldIs = new CloseShieldInputStream(inStream); cgFromModelSerializer = ModelSerializer.restoreComputationGraph(shieldIs, true); cgLoaded = true; // read MultilayerNetworkConfig, old format } else if (entry.getName().matches("mln_config")) { final String read = readStringFromZipStream(inStream); mln_config = MultiLayerConfiguration.fromJson(read.toString()); // read params, old format } else if (entry.getName().matches("mln_params")) { try { mln_params = Nd4j.read(inStream); } catch (Exception e) { throw new IOException("Could not load network parameters. 
Please re-execute the Node.", e); } // read updater, old format } else if (entry.getName().matches("mln_updater")) { // stream must not be closed, even if an exception is thrown, because the wrapped stream must stay open final IgnoreIDObjectInputStream ois = new IgnoreIDObjectInputStream(inStream); try { updater = (org.deeplearning4j.nn.api.Updater) ois.readObject(); } catch (final ClassNotFoundException e) { throw new IOException("Problem with updater loading: " + e.getMessage(), e); } } } if (mlnLoaded) { assert (!cgLoaded); return new DLModelPortObject(layers, mlnFromModelSerializer, null); } else if (cgLoaded) { assert (!mlnLoaded); return new DLModelPortObject(layers, cgFromModelSerializer, null); } else { return new DLModelPortObject(layers, buildMln(mln_config, updater, mln_params), null); } }