Usage examples for org.deeplearning4j.nn.multilayer.MultiLayerNetwork#setParameters
public void setParameters(INDArray params)
From source file:examples.cnn.NetworkTrainer.java
License:Apache License
public void train(JavaRDD<DataSet> train, JavaRDD<DataSet> test) { int batchSize = 12 * cores; int lrCount = 0; double bestAccuracy = Double.MIN_VALUE; double learningRate = initialLearningRate; int trainCount = Long.valueOf(train.count()).intValue(); log.info("Number of training images {}", trainCount); log.info("Number of test images {}", test.count()); MultiLayerNetwork net = new MultiLayerNetwork( model.apply(learningRate, width, height, channels, numLabels)); net.init();//from ww w . j a va2 s . c o m Map<Integer, Double> acc = new HashMap<>(); for (int i = 0; i < epochs; i++) { SparkDl4jMultiLayer sparkNetwork = networkToSparkNetwork.apply(net); final MultiLayerNetwork nn = sparkNetwork.fitDataSet(train, batchSize, trainCount, cores); log.info("Epoch {} completed", i); JavaPairRDD<Object, Object> predictionsAndLabels = test.mapToPair( ds -> new Tuple2<>(label(nn.output(ds.getFeatureMatrix(), false)), label(ds.getLabels()))); MulticlassMetrics metrics = new MulticlassMetrics(predictionsAndLabels.rdd()); double accuracy = 1.0 * predictionsAndLabels.filter(x -> x._1.equals(x._2)).count() / test.count(); log.info("Epoch {} accuracy {} ", i, accuracy); acc.put(i, accuracy); predictionsAndLabels.take(10).forEach(t -> log.info("predicted {}, label {}", t._1, t._2)); log.info("confusionMatrix {}", metrics.confusionMatrix()); INDArray params = nn.params(); if (accuracy > bestAccuracy) { bestAccuracy = accuracy; try { ModelSerializer.writeModel(nn, new File(workingDir, Double.toString(accuracy)), false); } catch (IOException e) { log.error("Error writing trained model", e); } lrCount = 0; } else { if (++lrCount % stepDecayTreshold == 0) { learningRate *= learningRateDecayFactor; } if (lrCount >= resetLearningRateThreshold) { lrCount = 0; learningRate = initialLearningRate; } if (learningRate < minimumLearningRate) { lrCount = 0; learningRate = initialLearningRate; } if (bestAccuracy - accuracy > downgradeAccuracyThreshold) { params = ModelLoader.load(workingDir, 
bestAccuracy); } } net = new MultiLayerNetwork(model.apply(learningRate, width, height, channels, numLabels)); net.init(); net.setParameters(params); log.info("Learning rate {} for epoch {}", learningRate, i + 1); } log.info("Training completed"); }
From source file:vectorizer.ModelSerializer.java
public static MultiLayerNetwork restoreMultiLayerNetwork(File file) throws IOException { ZipFile zipFile = new ZipFile(file); boolean gotConfig = false; boolean gotCoefficients = false; boolean gotUpdater = false; String json = ""; INDArray params = null;/*from w w w. jav a 2s .c o m*/ Updater updater = null; ZipEntry config = zipFile.getEntry("configuration.json"); if (config != null) { //restoring configuration InputStream stream = zipFile.getInputStream(config); BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); String line = ""; StringBuilder js = new StringBuilder(); while ((line = reader.readLine()) != null) { js.append(line).append("\n"); } json = js.toString(); reader.close(); stream.close(); gotConfig = true; } ZipEntry coefficients = zipFile.getEntry("coefficients.bin"); if (coefficients != null) { InputStream stream = zipFile.getInputStream(coefficients); DataInputStream dis = new DataInputStream(stream); params = Nd4j.read(dis); dis.close(); gotCoefficients = true; } ZipEntry updaters = zipFile.getEntry("updater.bin"); if (updaters != null) { InputStream stream = zipFile.getInputStream(updaters); ObjectInputStream ois = new ObjectInputStream(stream); try { updater = (Updater) ois.readObject(); } catch (ClassNotFoundException e) { throw new RuntimeException(e); } gotUpdater = true; } zipFile.close(); if (gotConfig && gotCoefficients) { MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json); MultiLayerNetwork network = new MultiLayerNetwork(confFromJson); network.init(); network.setParameters(params); if (gotUpdater && updater != null) { network.setUpdater(updater); } return network; } else throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdater + "]"); }