Usage examples for org.deeplearning4j.nn.multilayer.MultiLayerNetwork#setUpdater
public void setUpdater(Updater updater)
From source file: org.audiveris.omr.classifier.ModelSystemSerializer.java
License: Open Source License
/** * Load a multi layer network from a file system * * @param root the root path of file system * @param loadUpdater true for loading updater * @return the loaded multi layer network * @throws IOException/*from w w w .ja v a2 s . c o m*/ */ public static MultiLayerNetwork restoreMultiLayerNetwork(Path root, boolean loadUpdater) throws IOException { boolean gotConfig = false; boolean gotCoefficients = false; boolean gotOldUpdater = false; boolean gotUpdaterState = false; boolean gotPreProcessor = false; String json = ""; INDArray params = null; Updater updater = null; INDArray updaterState = null; DataSetPreProcessor preProcessor = null; final Path config = root.resolve("configuration.json"); if (Files.exists(config)) { //restoring configuration InputStream stream = Files.newInputStream(config); BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); String line = ""; StringBuilder js = new StringBuilder(); while ((line = reader.readLine()) != null) { js.append(line).append("\n"); } json = js.toString(); reader.close(); stream.close(); gotConfig = true; } final Path coefficients = root.resolve("coefficients.bin"); if (Files.exists(coefficients)) { InputStream stream = Files.newInputStream(coefficients); DataInputStream dis = new DataInputStream(new BufferedInputStream(stream)); params = Nd4j.read(dis); dis.close(); gotCoefficients = true; } if (loadUpdater) { //This can be removed a few releases after 0.4.1... 
final Path oldUpdaters = root.resolve(OLD_UPDATER_BIN); if (Files.exists(oldUpdaters)) { InputStream stream = Files.newInputStream(oldUpdaters); ObjectInputStream ois = new ObjectInputStream(stream); try { updater = (Updater) ois.readObject(); } catch (ClassNotFoundException e) { throw new RuntimeException(e); } gotOldUpdater = true; } final Path updaterStateEntry = root.resolve(UPDATER_BIN); if (updaterStateEntry != null) { InputStream stream = Files.newInputStream(updaterStateEntry); DataInputStream dis = new DataInputStream(stream); updaterState = Nd4j.read(dis); dis.close(); gotUpdaterState = true; } } final Path prep = root.resolve("preprocessor.bin"); if (Files.exists(prep)) { InputStream stream = Files.newInputStream(prep); ObjectInputStream ois = new ObjectInputStream(stream); try { preProcessor = (DataSetPreProcessor) ois.readObject(); } catch (ClassNotFoundException e) { throw new RuntimeException(e); } gotPreProcessor = true; } if (gotConfig && gotCoefficients) { MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json); MultiLayerNetwork network = new MultiLayerNetwork(confFromJson); network.init(params, false); if (gotUpdaterState && (updaterState != null)) { network.getUpdater().setStateViewArray(network, updaterState, false); } else if (gotOldUpdater && (updater != null)) { network.setUpdater(updater); } return network; } else { throw new IllegalStateException("Model wasn't found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]"); } }
From source file: org.knime.ext.dl4j.base.nodes.learn.AbstractDLLearnerNodeModel.java
License: Open Source License
/**
 * Copies the {@link org.deeplearning4j.nn.api.Updater} reference of one network onto another.
 * <p>
 * Useful when continuing to train a pretrained net, since some updaters keep a history of
 * past gradients. If the source net reports no updater, nothing is changed and a warning is
 * logged; otherwise the target receives the very same updater instance and a confirmation
 * is logged.
 *
 * @param from the network whose updater is read
 * @param to   the network that receives the updater
 */
protected void transferUpdater(final MultiLayerNetwork from, final MultiLayerNetwork to) {
    final org.deeplearning4j.nn.api.Updater sourceUpdater = from.getUpdater();
    if (sourceUpdater != null) {
        to.setUpdater(sourceUpdater);
        logger.info("Successfully transfered updater between nets.");
        return;
    }
    logger.warn("Could not transfer updater between nets as there is no updater set in the source net");
}
From source file: org.knime.ext.dl4j.base.util.DLModelPortObjectUtils.java
License: Open Source License
/**
 * Rebuilds a {@link MultiLayerNetwork} from objects deserialized in the old format. The dl4j
 * {@link ModelSerializer} now performs this reconstruction implicitly.
 *
 * @param config  the network configuration; when {@code null}, no network is built
 * @param updater optional updater, applied after {@code init()} when non-null
 * @param params  optional parameters, applied via {@code setParams} when non-null
 * @return the initialized network, or {@code null} if {@code config} is {@code null}
 * @deprecated only needed for models stored in the old format; use the dl4j
 *             {@link ModelSerializer} instead
 */
@Deprecated
private static MultiLayerNetwork buildMln(final MultiLayerConfiguration config,
        final org.deeplearning4j.nn.api.Updater updater, final INDArray params) {
    if (config == null) {
        return null;
    }
    final MultiLayerNetwork network = new MultiLayerNetwork(config);
    network.init();
    if (updater != null) {
        network.setUpdater(updater);
    }
    if (params != null) {
        network.setParams(params);
    }
    return network;
}
From source file: vectorizer.ModelSerializer.java
public static MultiLayerNetwork restoreMultiLayerNetwork(File file) throws IOException { ZipFile zipFile = new ZipFile(file); boolean gotConfig = false; boolean gotCoefficients = false; boolean gotUpdater = false; String json = ""; INDArray params = null;/*from w w w .ja v a 2 s.com*/ Updater updater = null; ZipEntry config = zipFile.getEntry("configuration.json"); if (config != null) { //restoring configuration InputStream stream = zipFile.getInputStream(config); BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); String line = ""; StringBuilder js = new StringBuilder(); while ((line = reader.readLine()) != null) { js.append(line).append("\n"); } json = js.toString(); reader.close(); stream.close(); gotConfig = true; } ZipEntry coefficients = zipFile.getEntry("coefficients.bin"); if (coefficients != null) { InputStream stream = zipFile.getInputStream(coefficients); DataInputStream dis = new DataInputStream(stream); params = Nd4j.read(dis); dis.close(); gotCoefficients = true; } ZipEntry updaters = zipFile.getEntry("updater.bin"); if (updaters != null) { InputStream stream = zipFile.getInputStream(updaters); ObjectInputStream ois = new ObjectInputStream(stream); try { updater = (Updater) ois.readObject(); } catch (ClassNotFoundException e) { throw new RuntimeException(e); } gotUpdater = true; } zipFile.close(); if (gotConfig && gotCoefficients) { MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json); MultiLayerNetwork network = new MultiLayerNetwork(confFromJson); network.init(); network.setParameters(params); if (gotUpdater && updater != null) { network.setUpdater(updater); } return network; } else throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdater + "]"); }