Usage examples for org.deeplearning4j.nn.multilayer.MultiLayerNetwork.init
public void init(INDArray parameters, boolean cloneParametersArray)
From source file: org.audiveris.omr.classifier.ModelSystemSerializer.java
License: Open Source License
/**
 * Load a multi-layer network from a file system.
 * <p>
 * The following entries are looked up directly under {@code root}:
 * <ul>
 * <li>{@code configuration.json}: the network configuration, as JSON text (required)</li>
 * <li>{@code coefficients.bin}: the network parameters (required)</li>
 * <li>{@code UPDATER_BIN}: the updater state (optional, read only when {@code loadUpdater})</li>
 * <li>{@code OLD_UPDATER_BIN}: a legacy Java-serialized updater (optional, read only when
 * {@code loadUpdater}, applied only if no updater state is available)</li>
 * <li>{@code preprocessor.bin}: a Java-serialized {@code DataSetPreProcessor} (optional; it is
 * read, but not attached to the returned network)</li>
 * </ul>
 * Any optional entry that does not exist is simply skipped.
 *
 * @param root        the root path of file system
 * @param loadUpdater true for loading updater
 * @return the loaded multi layer network
 * @throws IOException           if an existing entry cannot be read
 * @throws IllegalStateException if the configuration or the coefficients entry is missing
 * @throws RuntimeException      if a serialized object refers to a class that cannot be found
 */
public static MultiLayerNetwork restoreMultiLayerNetwork(Path root, boolean loadUpdater)
        throws IOException
{
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotOldUpdater = false;
    boolean gotUpdaterState = false;

    String json = "";
    INDArray params = null;
    Updater updater = null;
    INDArray updaterState = null;

    // Network configuration (JSON is read as UTF-8, not with the platform default charset)
    final Path config = root.resolve("configuration.json");
    if (Files.exists(config)) {
        json = new String(Files.readAllBytes(config), java.nio.charset.StandardCharsets.UTF_8);
        gotConfig = true;
    }

    // Network parameters
    final Path coefficients = root.resolve("coefficients.bin");
    if (Files.exists(coefficients)) {
        try (InputStream stream = Files.newInputStream(coefficients);
             DataInputStream dis = new DataInputStream(new BufferedInputStream(stream))) {
            params = Nd4j.read(dis);
        }
        gotCoefficients = true;
    }

    if (loadUpdater) {
        // Legacy updater format. This can be removed a few releases after 0.4.1...
        // NOTE(review): Java native deserialization; only acceptable for trusted model files.
        final Path oldUpdaters = root.resolve(OLD_UPDATER_BIN);
        if (Files.exists(oldUpdaters)) {
            try (InputStream stream = Files.newInputStream(oldUpdaters);
                 ObjectInputStream ois = new ObjectInputStream(stream)) {
                updater = (Updater) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
            gotOldUpdater = true;
        }

        // Path.resolve never returns null, so test for the entry's actual presence;
        // otherwise a missing entry would fail here instead of falling back to the old updater.
        final Path updaterStateEntry = root.resolve(UPDATER_BIN);
        if (Files.exists(updaterStateEntry)) {
            try (InputStream stream = Files.newInputStream(updaterStateEntry);
                 DataInputStream dis = new DataInputStream(stream)) {
                updaterState = Nd4j.read(dis);
            }
            gotUpdaterState = true;
        }
    }

    // Data-set pre-processor.
    // NOTE(review): it is deserialized (so a corrupt entry still fails here) but never attached
    // to the returned network; confirm whether callers expect it to be used.
    final Path prep = root.resolve("preprocessor.bin");
    if (Files.exists(prep)) {
        try (InputStream stream = Files.newInputStream(prep);
             ObjectInputStream ois = new ObjectInputStream(stream)) {
            final DataSetPreProcessor preProcessor = (DataSetPreProcessor) ois.readObject();
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    if (!gotConfig || !gotCoefficients) {
        throw new IllegalStateException(
                "Model wasn't found within file: gotConfig: [" + gotConfig
                        + "], gotCoefficients: [" + gotCoefficients
                        + "], gotUpdater: [" + gotUpdaterState + "]");
    }

    final MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
    final MultiLayerNetwork network = new MultiLayerNetwork(confFromJson);

    // The loaded parameters array is used as is (cloneParametersArray = false)
    network.init(params, false);

    // Prefer the current updater-state format; fall back to the legacy updater object
    if (gotUpdaterState && (updaterState != null)) {
        network.getUpdater().setStateViewArray(network, updaterState, false);
    } else if (gotOldUpdater && (updater != null)) {
        network.setUpdater(updater);
    }

    return network;
}