Example usage for org.deeplearning4j.nn.multilayer MultiLayerNetwork params

List of usage examples for org.deeplearning4j.nn.multilayer MultiLayerNetwork params

Introduction

In this page you can find the example usage for org.deeplearning4j.nn.multilayer MultiLayerNetwork params.

Prototype

@Override
public INDArray params() 

Source Link

Document

Returns a 1 x m row vector containing all of the network's parameters, flattened into a single INDArray.
See #getParam(String) and #paramTable() for a more useful/interpretable representation of the parameters.
Note that the returned vector is a view, not a copy: any changes made to the returned INDArray will directly modify the network's parameters.

Usage

From source file:examples.cnn.NetworkTrainer.java

License:Apache License

/**
 * Trains the network on Spark for {@code epochs} epochs, evaluating against the test set
 * after each epoch. The best-scoring model is serialized to {@code workingDir} (file named
 * by its accuracy). When accuracy stalls, the learning rate is step-decayed; when it drops
 * more than {@code downgradeAccuracyThreshold} below the best seen, parameters are restored
 * from the best checkpoint before the next epoch.
 *
 * @param train RDD of training examples
 * @param test  RDD of held-out examples used for per-epoch evaluation
 */
public void train(JavaRDD<DataSet> train, JavaRDD<DataSet> test) {

    int batchSize = 12 * cores;
    int lrCount = 0;
    // BUG FIX: was Double.MIN_VALUE, which is the smallest POSITIVE double (~4.9e-324),
    // not the most negative one. Since accuracy is in [0, 1], an epoch scoring exactly 0.0
    // would never have been checkpointed. NEGATIVE_INFINITY guarantees the first epoch wins.
    double bestAccuracy = Double.NEGATIVE_INFINITY;

    double learningRate = initialLearningRate;

    int trainCount = Long.valueOf(train.count()).intValue();
    // Counting an RDD is a full Spark action; do it once and reuse the value.
    long testCount = test.count();
    log.info("Number of training images {}", trainCount);
    log.info("Number of test images {}", testCount);

    MultiLayerNetwork net = new MultiLayerNetwork(
            model.apply(learningRate, width, height, channels, numLabels));
    net.init();

    // Per-epoch accuracy history (epoch index -> accuracy).
    Map<Integer, Double> acc = new HashMap<>();
    for (int i = 0; i < epochs; i++) {

        SparkDl4jMultiLayer sparkNetwork = networkToSparkNetwork.apply(net);
        final MultiLayerNetwork nn = sparkNetwork.fitDataSet(train, batchSize, trainCount, cores);
        log.info("Epoch {} completed", i);

        // Pair each test example's predicted label with its true label for evaluation.
        JavaPairRDD<Object, Object> predictionsAndLabels = test.mapToPair(
                ds -> new Tuple2<>(label(nn.output(ds.getFeatureMatrix(), false)), label(ds.getLabels())));
        MulticlassMetrics metrics = new MulticlassMetrics(predictionsAndLabels.rdd());
        double accuracy = 1.0 * predictionsAndLabels.filter(x -> x._1.equals(x._2)).count() / testCount;
        log.info("Epoch {} accuracy {} ", i, accuracy);
        acc.put(i, accuracy);
        predictionsAndLabels.take(10).forEach(t -> log.info("predicted {}, label {}", t._1, t._2));
        log.info("confusionMatrix {}", metrics.confusionMatrix());

        INDArray params = nn.params();
        if (accuracy > bestAccuracy) {
            // New best: checkpoint the model under a file named by its accuracy.
            bestAccuracy = accuracy;
            try {
                ModelSerializer.writeModel(nn, new File(workingDir, Double.toString(accuracy)), false);
            } catch (IOException e) {
                log.error("Error writing trained model", e);
            }
            lrCount = 0;
        } else {

            // No improvement: decay the learning rate every stepDecayTreshold stalled epochs.
            if (++lrCount % stepDecayTreshold == 0) {
                learningRate *= learningRateDecayFactor;
            }
            // Reset schedule after too many stalls or once the rate has decayed below the floor.
            if (lrCount >= resetLearningRateThreshold) {
                lrCount = 0;
                learningRate = initialLearningRate;
            }
            if (learningRate < minimumLearningRate) {
                lrCount = 0;
                learningRate = initialLearningRate;
            }
            // If accuracy regressed too far, roll parameters back to the best checkpoint.
            if (bestAccuracy - accuracy > downgradeAccuracyThreshold) {
                params = ModelLoader.load(workingDir, bestAccuracy);
            }
        }
        // Rebuild the network with the (possibly updated) learning rate and carry the
        // parameters forward into the next epoch.
        net = new MultiLayerNetwork(model.apply(learningRate, width, height, channels, numLabels));
        net.init();
        net.setParameters(params);
        log.info("Learning rate {} for epoch {}", learningRate, i + 1);
    }
    log.info("Training completed");

}

From source file:stratego.neural.net.StrategoNeuralNet.java

/**
 * @param args the command line arguments
 *///  w ww.  j  av  a2 s  .  c  om
/**
 * Entry point: loads a CSV dataset, trains the two-layer network on it, prints its
 * evaluation, persists it to disk, and verifies that the restored copy matches the
 * original in both configuration and parameters.
 *
 * @param args the command line arguments (unused)
 * @throws Exception if loading, training, or (de)serialization fails
 */
public static void main(String[] args) throws Exception {

    /********************************************
    INPUT DATASETS HERE        
    ********************************************/

    String csvPath = "src/Data/dataPoint_201.csv"; // location of the first dataset
    int labelColumn = 12; // position of the label within a data line (the label is the correct classification for the datapoint)
    double trainRatio = 0.9; // fraction of the data used for training (the remainder is the test set)
    int miniBatchSize = 35; // size of each micro-batch
    int classCount = 7; // total number of possible classifications

    /***************************************
     * LOADING OF DATASETS HAPPENS HERE
     ***************************************/
    // Parse the raw CSV file into a DataSet object.
    DataSet rawData = readCSVDataset(csvPath, miniBatchSize, labelColumn, classCount);

    /**************************************
     * BUILDING NETWORKS HAPPENS HERE
     **************************************/

    // FIRST NETWORK
    int inputNeurons = 12; // number of input neurons (always equal to the number of training variables per datapoint)
    int hiddenNeurons = 50; // number of neurons in the hidden layer
    int iterationsPerEpoch = 10; // iterations performed during each epoch
    int scoreFrequency = 1; // how many iterations between score printouts on the output terminal

    // Note: no need to set numOutput here since it equals the dataset's classCount.

    int epochCount = 50; // number of epochs to run the training for

    String name1 = "One Layer, 201 datapoints, batchsize " + miniBatchSize + " ratio " + trainRatio + " epochs: "
            + epochCount; // name used for identification
    // SECOND NETWORK
    // Interested only in the performance difference, so all settings are copied from the
    // first network and only the name changes.
    String name2 = "Two Layer, 201 datapoints, batchsize " + miniBatchSize + " ratio " + trainRatio + " epochs: "
            + epochCount;

    // THIRD NETWORK
    String name3 = "Three Layer, 201 datapoints, batchsize " + miniBatchSize + " ratio " + trainRatio + " epochs: "
            + epochCount;

    // Disabled alternative: single-hidden-layer network.
    /*
    OneLayerNetwork oneLayerNetwork = new OneLayerNetwork(inputNeurons, hiddenNeurons, classCount, iterationsPerEpoch, scoreFrequency, name1);
    List<NamedDataSet> plotData1 = oneLayerNetwork.train(rawData, trainRatio, epochCount); // trains the network and returns overfitting data for the plot
    plotDataSet(plotData1, oneLayerNetwork.getName());
    */
    TwoLayerNetwork twoLayer = new TwoLayerNetwork(inputNeurons, hiddenNeurons, classCount, iterationsPerEpoch,
            scoreFrequency, name2);
    List<NamedDataSet> overfitData = twoLayer.train(rawData, trainRatio, epochCount);
    plotDataSet(overfitData, twoLayer.getName());

    // Disabled alternative: three-layer network.
    /*
    ThreeLayerNetwork threeLayerNetwork = new ThreeLayerNetwork(inputNeurons, hiddenNeurons, classCount, iterationsPerEpoch, scoreFrequency, name3);
    List<NamedDataSet> plotData3 = threeLayerNetwork.train(rawData, trainRatio, epochCount);
    plotDataSet(plotData3, threeLayerNetwork.getName());
    */

    /***********************************************************
     * CONSOLE OUTPUT HAPPENS HERE
     ****************************************************/
    //oneLayerNetwork.evaluation();
    twoLayer.evaluation();
    // threeLayerNetwork.evaluation();

    // Persist the trained network, then restore it and verify the round trip.
    twoLayer.storeNetwork("network");

    MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork("src/NetworkFiles/network.zip");

    System.out.println("Original and restored networks: configs are equal: " + twoLayer.getNetwork()
            .getLayerWiseConfigurations().equals(restored.getLayerWiseConfigurations()));
    System.out.println("Original and restored networks: parameters are equal: "
            + twoLayer.getNetwork().params().equals(restored.params()));

}