Example usage for org.deeplearning4j.nn.multilayer MultiLayerNetwork clear

List of usage examples for org.deeplearning4j.nn.multilayer MultiLayerNetwork clear

Introduction

On this page you can find an example usage for org.deeplearning4j.nn.multilayer MultiLayerNetwork clear.

Prototype

public void clear() 

Source Link

Document

Clear the inputs.

Usage

From source file:org.ensor.fftmusings.rnn.qft.SampleLSTM.java

/**
 * Restores a previously serialized {@link MultiLayerNetwork} from disk and prepares it
 * for reuse.
 *
 * @param modelFilename file the model was previously written to via ModelSerializer
 * @return the restored network with cleared inputs and a fresh score listener attached
 * @throws IOException if the model file cannot be read or deserialized
 */
public static MultiLayerNetwork load(File modelFilename) throws IOException {
    MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(modelFilename);
    // Drop any cached inputs/labels retained from the training run that saved the model.
    network.clear();
    network.setListeners(new ScoreIterationListener());
    return network;
}

From source file:org.ensor.fftmusings.rnn.RNNFactory.java

/**
 * Loads an existing network from {@code modelFilename} if the file exists; otherwise
 * builds a new 2x GravesLSTM + RnnOutputLayer network sized from {@code iter}, writes it
 * to {@code modelFilename}, and returns it.
 *
 * @param modelFilename location of the serialized model; created if absent
 * @param iter          character iterator supplying input column count and outcome count
 * @return a network ready for training/inference, with a score listener attached
 * @throws IOException if reading or writing the model file fails
 */
public static MultiLayerNetwork create(File modelFilename, CharacterIterator iter) throws IOException {

    if (modelFilename.exists()) {
        // Reuse the saved model: restore, then clear inputs cached from the prior run.
        MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFilename);
        net.clear();
        net.setListeners(new ScoreIterationListener(System.out));
        return net;
    }

    int nOut = iter.totalOutcomes();

    //Set up network configuration:
    // NOTE(review): lstmLayerSize is not defined in this snippet — presumably a class
    // field of RNNFactory; confirm against the enclosing class.
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).learningRate(0.1)
            .rmsDecay(0.95).seed(12345).regularization(true).l2(0.001).list()
            // Layer 0: LSTM from input features to hidden size, tanh, uniform(-0.08, 0.08) init.
            .layer(0, new GravesLSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize)
                    .updater(Updater.RMSPROP).activation(Activation.TANH).weightInit(WeightInit.DISTRIBUTION)
                    .dist(new UniformDistribution(-0.08, 0.08)).build())
            // Layer 1: second stacked LSTM, hidden-to-hidden.
            .layer(1,
                    new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize).updater(Updater.RMSPROP)
                            .activation(Activation.TANH).weightInit(WeightInit.DISTRIBUTION)
                            .dist(new UniformDistribution(-0.08, 0.08)).build())
            // Layer 2: per-timestep classification head.
            .layer(2,
                    new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX) //MCXENT + softmax for classification
                            .updater(Updater.RMSPROP).nIn(lstmLayerSize).nOut(nOut)
                            .weightInit(WeightInit.DISTRIBUTION).dist(new UniformDistribution(-0.08, 0.08))
                            .build())
            .pretrain(false).backprop(true).backpropType(BackpropType.TruncatedBPTT).build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setListeners(new ScoreIterationListener(System.out));

    // Persist the freshly built network (true = save updater state) so the next call
    // takes the restore path above.
    ModelSerializer.writeModel(net, modelFilename, true);

    return net;
}