Example usage for org.deeplearning4j.nn.conf Updater ADADELTA

List of usage examples for org.deeplearning4j.nn.conf Updater ADADELTA

Introduction

On this page you can find example usages of org.deeplearning4j.nn.conf.Updater.ADADELTA.

Prototype

Updater ADADELTA

To view the source code for org.deeplearning4j.nn.conf.Updater.ADADELTA, click the source link.

Usage

From source file: gr.aueb.cs.nlp.computationgraphs.GraphConfigurations.java

License: Open Source License

/**
 * Builds an example two-layer LSTM computation graph.
 *
 * <p>The input width and the number of output labels are read from the feature vector of the
 * first element of {@code trainSet}; all other hyperparameters are hard-coded below.
 *
 * <p>Topology: {@code input -> L1 (GravesLSTM, 150 units, ADADELTA)
 * -> L2 (GravesLSTM, 200 units, ADAM) -> L3 (RnnOutputLayer)}.
 *
 * @param trainSet training words; only the first element is inspected, to size the input and
 *     output layers
 * @return the built graph configuration
 * @throws IllegalArgumentException if {@code trainSet} is {@code null} or empty
 */
public static ComputationGraphConfiguration LSTMGraph(List<Word> trainSet) {
    if (trainSet == null || trainSet.isEmpty()) {
        throw new IllegalArgumentException(
                "trainSet must contain at least one word to size the network");
    }
    // Size the network from the first sample: its feature values give nIn, its labels give nOut.
    Word first = trainSet.get(0);
    int inputFeatures = first.getFeatureVec().getValues().length;
    int outputLabels = first.getFeatureVec().getLabels().length;

    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .learningRate(0.01)
            .regularization(true) // the l1/l2 values themselves are set per layer below
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .graphBuilder()
            // "input" is an arbitrary identifier for the graph's input vertex.
            .addInputs("input")
            .addLayer("L1", new GravesLSTM.Builder()
                    .nIn(inputFeatures) // nIn must equal the sum of nOut of this layer's inputs
                    .nOut(150)
                    .biasLearningRate(0.2) // bias adapts faster here than in L2 (0.02)
                    // The original author notes RELU often works too, but needs more units per layer.
                    .activation(Activation.TANH)
                    .l1(0.3) // L1 regularization
                    .l2(0.02) // L2 regularization
                    // Dropout; the original author found it usually worked better than l1/l2.
                    .dropOut(0.3)
                    .updater(Updater.ADADELTA)
                    // NOTE(review): momentum is a parameter of momentum-based updaters such as
                    // NESTEROVS; ADADELTA presumably ignores it — confirm for the DL4J version in use.
                    .momentum(0.3)
                    .build(), "input") // the "input" vertex feeds L1
            .addLayer("L2", new GravesLSTM.Builder()
                    .nIn(150) // matches L1's nOut
                    .nOut(200)
                    .biasLearningRate(0.02)
                    .l1(0.3)
                    .l2(0.02)
                    .activation(Activation.RELU)
                    .dropOut(0.3)
                    // Adam decay rates are described in https://arxiv.org/pdf/1412.6980.pdf
                    .updater(Updater.ADAM)
                    // NOTE(review): only the mean decay is configured; if a variance decay was
                    // intended as well, it would be set via adamVarDecay.
                    .adamMeanDecay(0.2)
                    .build(), "L1") // L1 feeds L2
            // NOTE(review): no loss function or activation is set here, so builder defaults apply.
            .addLayer("L3", new RnnOutputLayer.Builder().nIn(200).nOut(outputLabels).build(), "L2")
            .setOutputs("L3") // the network outputs, in order
            .build();
    return conf;
}

From source file: org.wso2.carbon.ml.rest.api.neuralNetworks.FeedForwardNetwork.java

License: Open Source License

/**
 * Maps a user-selected updater algorithm name to the corresponding {@link Updater} constant.
 *
 * <p>Matching ignores case and surrounding whitespace, so {@code "Adam"} and {@code " adam "}
 * both map to {@link Updater#ADAM}.
 *
 * @param updater the algorithm name: one of {@code sgd}, {@code adam}, {@code adadelta},
 *     {@code nesterovs}, {@code adagrad}, {@code rmsprop}, {@code none} or {@code custom};
 *     may be {@code null}
 * @return the matching {@link Updater}, or {@code null} if {@code updater} is {@code null} or
 *     not a recognized name
 */
Updater mapUpdater(String updater) {
    // A String switch on null throws NullPointerException; treat null like any unknown name.
    if (updater == null) {
        return null;
    }
    // Locale.ROOT keeps the lower-casing independent of the JVM's default locale.
    switch (updater.trim().toLowerCase(java.util.Locale.ROOT)) {
    case "sgd":
        return Updater.SGD;
    case "adam":
        return Updater.ADAM;
    case "adadelta":
        return Updater.ADADELTA;
    case "nesterovs":
        return Updater.NESTEROVS;
    case "adagrad":
        return Updater.ADAGRAD;
    case "rmsprop":
        return Updater.RMSPROP;
    case "none":
        return Updater.NONE;
    case "custom":
        return Updater.CUSTOM;
    default:
        return null;
    }
}