List of usage examples for org.deeplearning4j.nn.conf.MultiLayerConfiguration.fromJson
public static MultiLayerConfiguration fromJson(String json)
From source file: org.audiveris.omr.classifier.ModelSystemSerializer.java
License: Open Source License
/**
 * Load a multi layer network from a file system.
 * <p>
 * {@code configuration.json} and {@code coefficients.bin} are required. When {@code loadUpdater}
 * is true, the updater state ({@code UPDATER_BIN}) and/or a legacy serialized updater
 * ({@code OLD_UPDATER_BIN}) are read if present. {@code preprocessor.bin} is read if present.
 *
 * @param root        the root path of file system
 * @param loadUpdater true for loading updater
 * @return the loaded multi layer network
 * @throws IOException           if one of the present files cannot be read
 * @throws IllegalStateException if the configuration or the coefficients file is missing
 */
public static MultiLayerNetwork restoreMultiLayerNetwork(Path root, boolean loadUpdater)
        throws IOException {
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotOldUpdater = false;
    boolean gotUpdaterState = false;
    boolean gotPreProcessor = false;
    String json = "";
    INDArray params = null;
    Updater updater = null;
    INDArray updaterState = null;
    DataSetPreProcessor preProcessor = null;

    final Path config = root.resolve("configuration.json");
    if (Files.exists(config)) {
        // Restoring configuration; every line is re-terminated with '\n'.
        // NOTE(review): decodes with the platform default charset, as before; confirm it matches
        // the charset used when the configuration was written.
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(Files.newInputStream(config)))) {
            StringBuilder js = new StringBuilder();
            String line;
            while ((line = reader.readLine()) != null) {
                js.append(line).append("\n");
            }
            json = js.toString();
        }
        gotConfig = true;
    }

    final Path coefficients = root.resolve("coefficients.bin");
    if (Files.exists(coefficients)) {
        try (DataInputStream dis = new DataInputStream(
                new BufferedInputStream(Files.newInputStream(coefficients)))) {
            params = Nd4j.read(dis);
        }
        gotCoefficients = true;
    }

    if (loadUpdater) {
        //This can be removed a few releases after 0.4.1...
        final Path oldUpdaters = root.resolve(OLD_UPDATER_BIN);
        if (Files.exists(oldUpdaters)) {
            // SECURITY: Java native deserialization; only load models from trusted locations.
            try (ObjectInputStream ois = new ObjectInputStream(Files.newInputStream(oldUpdaters))) {
                updater = (Updater) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
            gotOldUpdater = true;
        }

        final Path updaterStateEntry = root.resolve(UPDATER_BIN);
        // Path.resolve never returns null: test the file itself, otherwise a model saved without
        // updater state would fail here with NoSuchFileException.
        if (Files.exists(updaterStateEntry)) {
            try (DataInputStream dis = new DataInputStream(Files.newInputStream(updaterStateEntry))) {
                updaterState = Nd4j.read(dis);
            }
            gotUpdaterState = true;
        }
    }

    final Path prep = root.resolve("preprocessor.bin");
    if (Files.exists(prep)) {
        // NOTE(review): the pre-processor is deserialized but never attached to the returned network.
        try (ObjectInputStream ois = new ObjectInputStream(Files.newInputStream(prep))) {
            preProcessor = (DataSetPreProcessor) ois.readObject();
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
        gotPreProcessor = true;
    }

    if (!(gotConfig && gotCoefficients)) {
        throw new IllegalStateException("Model wasn't found within file: gotConfig: [" + gotConfig
                + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
    }

    MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
    MultiLayerNetwork network = new MultiLayerNetwork(confFromJson);
    network.init(params, false);

    // Prefer the modern updater state; fall back to the legacy serialized updater.
    if (gotUpdaterState && (updaterState != null)) {
        network.getUpdater().setStateViewArray(network, updaterState, false);
    } else if (gotOldUpdater && (updater != null)) {
        network.setUpdater(updater);
    }
    return network;
}
From source file: org.knime.ext.dl4j.base.util.DLModelPortObjectUtils.java
License: Open Source License
/** * Loads a {@link DLModelPortObject} from the specified {@link ZipInputStream}. Supports both deserialization of old * and new format./*from w w w. ja v a 2 s. c o m*/ * * @param inStream the stream to load from * @return the loaded {@link DLModelPortObject} * @throws IOException */ @SuppressWarnings("resource") public static DLModelPortObject loadPortFromZip(final ZipInputStream inStream) throws IOException { final List<Layer> layers = new ArrayList<>(); //old model format INDArray mln_params = null; MultiLayerConfiguration mln_config = null; org.deeplearning4j.nn.api.Updater updater = null; //new model format boolean mlnLoaded = false; boolean cgLoaded = false; MultiLayerNetwork mlnFromModelSerializer = null; ComputationGraph cgFromModelSerializer = null; ZipEntry entry; while ((entry = inStream.getNextEntry()) != null) { // read layers if (entry.getName().matches("layer[0123456789]+")) { final String read = readStringFromZipStream(inStream); Layer l = NeuralNetConfiguration.fromJson(read).getLayer(); if (l instanceof BaseLayer) { BaseLayer bl = (BaseLayer) l; /* Compatibility issue between dl4j 0.6 and 0.8 due to API change. Activations changed from * Strings to an interface. Therefore, if a model was saved with 0.6 the corresponding member * of the layer object will contain null after 'NeuralNetConfiguration.fromJson'. Old method to * retrieve String representation of the activation function was removed. Therefore, we parse * the old activation from the json ourself and map it to the new Activation. 
*/ if (bl.getActivationFn() == null) { Optional<Activation> layerActivation = DL4JVersionUtils.parseLayerActivationFromJson(read); if (layerActivation.isPresent()) { bl.setActivationFn(layerActivation.get().getActivationFunction()); } } } layers.add(l); // directly read MultiLayerNetwork, new format } else if (entry.getName().matches("mln_model")) { //stream must not be closed, ModelSerializer tries to close the stream CloseShieldInputStream shieldIs = new CloseShieldInputStream(inStream); mlnFromModelSerializer = ModelSerializer.restoreMultiLayerNetwork(shieldIs, true); mlnLoaded = true; // directly read MultiLayerNetwork, new format } else if (entry.getName().matches("cg_model")) { //stream must not be closed, ModelSerializer tries to close the stream CloseShieldInputStream shieldIs = new CloseShieldInputStream(inStream); cgFromModelSerializer = ModelSerializer.restoreComputationGraph(shieldIs, true); cgLoaded = true; // read MultilayerNetworkConfig, old format } else if (entry.getName().matches("mln_config")) { final String read = readStringFromZipStream(inStream); mln_config = MultiLayerConfiguration.fromJson(read.toString()); // read params, old format } else if (entry.getName().matches("mln_params")) { try { mln_params = Nd4j.read(inStream); } catch (Exception e) { throw new IOException("Could not load network parameters. 
Please re-execute the Node.", e); } // read updater, old format } else if (entry.getName().matches("mln_updater")) { // stream must not be closed, even if an exception is thrown, because the wrapped stream must stay open final IgnoreIDObjectInputStream ois = new IgnoreIDObjectInputStream(inStream); try { updater = (org.deeplearning4j.nn.api.Updater) ois.readObject(); } catch (final ClassNotFoundException e) { throw new IOException("Problem with updater loading: " + e.getMessage(), e); } } } if (mlnLoaded) { assert (!cgLoaded); return new DLModelPortObject(layers, mlnFromModelSerializer, null); } else if (cgLoaded) { assert (!mlnLoaded); return new DLModelPortObject(layers, cgFromModelSerializer, null); } else { return new DLModelPortObject(layers, buildMln(mln_config, updater, mln_params), null); } }
From source file: vectorizer.ModelSerializer.java
public static MultiLayerNetwork restoreMultiLayerNetwork(File file) throws IOException { ZipFile zipFile = new ZipFile(file); boolean gotConfig = false; boolean gotCoefficients = false; boolean gotUpdater = false; String json = ""; INDArray params = null;// w ww. j ava 2s . c om Updater updater = null; ZipEntry config = zipFile.getEntry("configuration.json"); if (config != null) { //restoring configuration InputStream stream = zipFile.getInputStream(config); BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); String line = ""; StringBuilder js = new StringBuilder(); while ((line = reader.readLine()) != null) { js.append(line).append("\n"); } json = js.toString(); reader.close(); stream.close(); gotConfig = true; } ZipEntry coefficients = zipFile.getEntry("coefficients.bin"); if (coefficients != null) { InputStream stream = zipFile.getInputStream(coefficients); DataInputStream dis = new DataInputStream(stream); params = Nd4j.read(dis); dis.close(); gotCoefficients = true; } ZipEntry updaters = zipFile.getEntry("updater.bin"); if (updaters != null) { InputStream stream = zipFile.getInputStream(updaters); ObjectInputStream ois = new ObjectInputStream(stream); try { updater = (Updater) ois.readObject(); } catch (ClassNotFoundException e) { throw new RuntimeException(e); } gotUpdater = true; } zipFile.close(); if (gotConfig && gotCoefficients) { MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json); MultiLayerNetwork network = new MultiLayerNetwork(confFromJson); network.init(); network.setParameters(params); if (gotUpdater && updater != null) { network.setUpdater(updater); } return network; } else throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdater + "]"); }