Example usage for org.deeplearning4j.nn.weights WeightInit DISTRIBUTION

List of usage examples for org.deeplearning4j.nn.weights WeightInit DISTRIBUTION

Introduction

On this page you can find example usage for org.deeplearning4j.nn.weights WeightInit DISTRIBUTION.

Prototype

WeightInit DISTRIBUTION

To view the source code for org.deeplearning4j.nn.weights WeightInit DISTRIBUTION, click the Source Link.

Usage

From source file:com.javafxpert.neuralnetviz.scenario.TicTacToe.java

License:Apache License

/**
 * Builds and trains a tic-tac-toe move-prediction network from the CSV data set
 * at src/main/resources/classification/tic_tac_toe_all.csv, then runs a single
 * sample prediction and returns the trained model.
 *
 * @param webSocketSession session used by the ModelListener to stream training progress
 * @return the trained network
 * @throws Exception if the CSV file cannot be read or training fails
 */
public static MultiLayerNetworkEnhanced buildNetwork(WebSocketSession webSocketSession) throws Exception {

    // First: get the dataset using the record reader. CSVRecordReader handles loading/parsing.
    int numLinesToSkip = 0;
    String delimiter = ",";
    org.datavec.api.records.reader.RecordReader recordReader = new org.datavec.api.records.reader.impl.csv.CSVRecordReader(
            numLinesToSkip, delimiter);
    recordReader.initialize(new org.datavec.api.split.FileSplit(
            new File("src/main/resources/classification/tic_tac_toe_all.csv")));

    // Second: the RecordReaderDataSetIterator handles conversion to DataSet objects.
    int labelIndex = 0; // 28 values in each row; the label is the 1st value (index 0)
    int numClasses = 9; // 9 classes (a move for X in each square), integer values 0 - 8

    // TODO: Ascertain best batch size for large datasets.
    // All examples are loaded into a single DataSet (not recommended for large data sets).
    int batchSize = 4520;

    DataSetIterator iterator = new org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator(recordReader,
            batchSize, labelIndex, numClasses);
    DataSet allData = iterator.next();
    allData.shuffle();
    // NOTE(review): no train/test split or normalization is applied here; the
    // model is trained and "evaluated" on the same full data set.

    final int numInputs = 27; // 9 squares x 3 one-hot states (empty / X / O)
    int outputNum = 9;
    int iterations = 10000;
    long seed = 6; // fixed seed so runs are repeatable

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(iterations)
            .activation("tanh").weightInit(WeightInit.XAVIER).learningRate(1.9).useDropConnect(false)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).biasInit(0)
            .regularization(true).l2(1e-4).list()
            .layer(0,
                    new DenseLayer.Builder().nIn(numInputs).nOut(54).weightInit(WeightInit.DISTRIBUTION)
                            .activation("sigmoid").build())
            .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).activation("softmax").nIn(54)
                    .nOut(outputNum).build())
            .backprop(true).pretrain(false).build();

    String[] inputFeatureNames = { "a:_", "a: X", "a: O", "b:_", "b: X", "b: O", "c:_", "c: X", "c: O", "d:_",
            "d: X", "d: O", "e:_", "e: X", "e: O", "f:_", "f: X", "f: O", "g:_", "g: X", "g: O", "h:_", "h: X",
            "h: O", "i:_", "i: X", "i: O" };
    String[] outputLabelNames = { "cell a", "cell b", "cell c", "cell d", "cell e", "cell f", "cell g",
            "cell h", "cell i" };
    MultiLayerNetworkEnhanced model = new MultiLayerNetworkEnhanced(conf, inputFeatureNames, outputLabelNames);
    model.init();
    // Stream the score to the web socket every 500 parameter updates.
    model.setListeners(new ModelListener(500, webSocketSession));

    model.fit(allData);

    // Make a prediction for one board position (one-hot encoded, 3 values per square).
    // Expected output: 4
    double[] exampleFeatures = { 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0,
            0 };
    INDArray example = Nd4j.zeros(1, 27);
    for (int i = 0; i < exampleFeatures.length; i++) {
        example.putScalar(new int[] { 0, i }, exampleFeatures[i]);
    }
    int[] prediction = model.predict(example);
    System.out.println("prediction for ???: " + prediction[0]);

    System.out.println("****************Example finished********************");

    return model;
}

From source file:com.javafxpert.neuralnetviz.scenario.XorExample.java

License:Apache License

/**
 * Builds and trains a minimal feed-forward network (2 inputs, 2 hidden, 2
 * outputs) that learns the XOR function from its four truth-table rows,
 * prints evaluation stats, and makes one sample prediction.
 *
 * @param webSocketSession session used by the ModelListener to stream training progress
 * @return the trained network
 * @throws Exception if training fails
 */
public static MultiLayerNetworkEnhanced buildNetwork(WebSocketSession webSocketSession) throws Exception {

    // XOR truth table: 4 samples, 2 input neurons each. Labels are one-hot:
    // output neuron 0 fires for "false", output neuron 1 fires for "true".
    INDArray input = Nd4j.zeros(4, 2);
    INDArray labels = Nd4j.zeros(4, 2);
    double[][] xorInputs = { { 0, 0 }, { 1, 0 }, { 0, 1 }, { 1, 1 } };
    double[][] xorLabels = { { 1, 0 }, { 0, 1 }, { 0, 1 }, { 1, 0 } };
    for (int row = 0; row < xorInputs.length; row++) {
        for (int col = 0; col < 2; col++) {
            input.putScalar(new int[] { row, col }, xorInputs[row][col]);
            labels.putScalar(new int[] { row, col }, xorLabels[row][col]);
        }
    }
    DataSet ds = new DataSet(input, labels);

    // Network configuration. Values found empirically: this tiny data set
    // needs many iterations (or a higher learning rate) to converge.
    NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
    builder.iterations(10000);
    builder.learningRate(0.1);
    // fixed seed for the random generator so every run brings the same
    // results - may not hold if you do something like ds.shuffle()
    builder.seed(123);
    // drop-connect is not useful for a network this small
    builder.useDropConnect(false);
    // standard algorithm for moving on the error plane; LINE_GRADIENT_DESCENT
    // or CONJUGATE_GRADIENT can do the job too - an empirical choice
    builder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
    // init the bias with 0 - empirical value, too
    builder.biasInit(0);
    // the data set is smaller than a typical mini batch, so train full-batch
    builder.miniBatch(false);

    // Two layers: one hidden layer plus the output layer (input layer is implicit).
    ListBuilder listBuilder = builder.list();

    // Hidden layer: 2 in -> 2 out, sigmoid activation (caps output to [0, 1]),
    // weights randomly initialized from a uniform distribution over [0, 1].
    DenseLayer.Builder hiddenLayerBuilder = new DenseLayer.Builder();
    hiddenLayerBuilder.nIn(2);
    hiddenLayerBuilder.nOut(2);
    hiddenLayerBuilder.activation("sigmoid");
    hiddenLayerBuilder.weightInit(WeightInit.DISTRIBUTION);
    hiddenLayerBuilder.dist(new UniformDistribution(0, 1));
    listBuilder.layer(0, hiddenLayerBuilder.build());

    // Output layer: MCXENT or NEGATIVELOGLIKELIHOOD both work for this
    // example; nIn must match the neuron count of the previous layer.
    Builder outputLayerBuilder = new Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD);
    outputLayerBuilder.nIn(2);
    outputLayerBuilder.nOut(2);
    outputLayerBuilder.activation("sigmoid");
    outputLayerBuilder.weightInit(WeightInit.DISTRIBUTION);
    outputLayerBuilder.dist(new UniformDistribution(0, 1));
    listBuilder.layer(1, outputLayerBuilder.build());

    // No pretrain phase for this network; plain backprop training.
    // (pretrain(true) is for autoencoders/RBM-style pretrain/finetune.)
    listBuilder.pretrain(false);
    listBuilder.backprop(true);

    // Build the configuration; init() below checks it is configured correctly.
    MultiLayerConfiguration conf = listBuilder.build();

    String[] inputFeatureNames = { "true (1) or false (0)", "true (1) or false (0)" };
    String[] outputLabelNames = { "false", "true" };
    MultiLayerNetworkEnhanced net = new MultiLayerNetworkEnhanced(conf, inputFeatureNames, outputLabelNames);
    net.init();

    // Stream the error to the web socket every 100 parameter updates.
    net.setListeners(new ModelListener(100, webSocketSession));

    // Here the actual learning takes place.
    net.fit(ds);

    // Evaluation prints stats on how often the right output neuron had the
    // highest value.
    INDArray output = net.output(ds.getFeatureMatrix());
    Evaluation eval = new Evaluation(2);
    eval.eval(ds.getLabels(), output);
    System.out.println(eval.stats());

    // Make a prediction for input (0, 1); XOR of these is true (class 1).
    INDArray example = Nd4j.zeros(1, 2);
    example.putScalar(new int[] { 0, 0 }, 0);
    example.putScalar(new int[] { 0, 1 }, 1);

    int[] prediction = net.predict(example);

    System.out.println("prediction for 0, 1: " + prediction[0]);

    return net;
}

From source file:org.eigengo.rsa.identity.v100.AlexNet.java

License:Open Source License

/**
 * Builds an AlexNet-style convolutional network configuration: five
 * convolution layers interleaved with local-response normalization and max
 * pooling, two dropout dense layers, and a softmax output.
 * <p>
 * Relies on enclosing-class fields not visible here: seed, iterations,
 * channels, numLabels, height and width — confirm their values on the class.
 *
 * @return the assembled multi-layer configuration
 */
public MultiLayerConfiguration conf() {
    double nonZeroBias = 1; // non-zero bias init for conv layers 2/4/5 and the dense layers
    double dropOut = 0.5; // dropout probability applied to the two fully-connected layers
    SubsamplingLayer.PoolingType poolingType = SubsamplingLayer.PoolingType.MAX;

    // TODO split and link kernel maps on GPUs - 2nd, 4th, 5th convolution should only connect maps on the same gpu, 3rd connects to all in 2nd
    MultiLayerConfiguration.Builder conf = new NeuralNetConfiguration.Builder().seed(seed)
            .weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0.0, 0.01)).activation("relu")
            .updater(Updater.NESTEROVS).iterations(iterations)
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) // normalize to prevent vanishing or exploding gradients
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(1e-2)
            .biasLearningRate(1e-2 * 2).learningRateDecayPolicy(LearningRatePolicy.Step).lrPolicyDecayRate(0.1)
            .lrPolicySteps(100000).regularization(true).l2(5 * 1e-4).momentum(0.9).miniBatch(false).list()
            // cnn1: 11x11 kernel, stride 4, padding 3
            .layer(0,
                    new ConvolutionLayer.Builder(new int[] { 11, 11 }, new int[] { 4, 4 }, new int[] { 3, 3 })
                            .name("cnn1").nIn(channels).nOut(96).build())
            .layer(1, new LocalResponseNormalization.Builder().name("lrn1").build())
            .layer(2,
                    new SubsamplingLayer.Builder(poolingType, new int[] { 3, 3 }, new int[] { 2, 2 })
                            .name("maxpool1").build())
            // cnn2: 5x5 kernel, stride 1, padding 2, non-zero bias
            .layer(3,
                    new ConvolutionLayer.Builder(new int[] { 5, 5 }, new int[] { 1, 1 }, new int[] { 2, 2 })
                            .name("cnn2").nOut(256).biasInit(nonZeroBias).build())
            .layer(4,
                    new LocalResponseNormalization.Builder().name("lrn2").k(2).n(5).alpha(1e-4).beta(0.75)
                            .build())
            .layer(5,
                    new SubsamplingLayer.Builder(poolingType, new int[] { 3, 3 }, new int[] { 2, 2 })
                            .name("maxpool2").build())
            // cnn3-cnn5: 3x3 kernels, stride 1, padding 1
            .layer(6,
                    new ConvolutionLayer.Builder(new int[] { 3, 3 }, new int[] { 1, 1 }, new int[] { 1, 1 })
                            .name("cnn3").nOut(384).build())
            .layer(7,
                    new ConvolutionLayer.Builder(new int[] { 3, 3 }, new int[] { 1, 1 }, new int[] { 1, 1 })
                            .name("cnn4").nOut(384).biasInit(nonZeroBias).build())
            .layer(8,
                    new ConvolutionLayer.Builder(new int[] { 3, 3 }, new int[] { 1, 1 }, new int[] { 1, 1 })
                            .name("cnn5").nOut(256).biasInit(nonZeroBias).build())
            .layer(9,
                    new SubsamplingLayer.Builder(poolingType, new int[] { 3, 3 }, new int[] { 2, 2 })
                            .name("maxpool3").build())
            // ffn1/ffn2: fully-connected layers with dropout and Gaussian weight init
            .layer(10,
                    new DenseLayer.Builder().name("ffn1").nOut(4096).dist(new GaussianDistribution(0, 0.005))
                            .biasInit(nonZeroBias).dropOut(dropOut).build())
            .layer(11,
                    new DenseLayer.Builder().name("ffn2").nOut(4096).dist(new GaussianDistribution(0, 0.005))
                            .biasInit(nonZeroBias).dropOut(dropOut).build())
            .layer(12,
                    new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).name("output")
                            .nOut(numLabels).activation("softmax").build())
            .backprop(true).pretrain(false).cnnInputSize(height, width, channels);

    return conf.build();
}

From source file:org.ensor.fftmusings.autoencoder.RNNTrainer2.java

/**
 * Trains a two-layer GravesLSTM time-series model on top of a previously
 * trained stacked autoencoder ("stack.rnn"), evaluating and re-serializing
 * the model to "stack-timeseries.rnn" after every epoch.
 *
 * @param args unused
 * @throws Exception if the autoencoder cannot be restored or training fails
 */
public static void main(String[] args) throws Exception {

    MultiLayerNetwork stackedAutoencoder = ModelSerializer.restoreMultiLayerNetwork("stack.rnn");

    Random rng = new Random();

    RNNIterator iter = new RNNIterator(stackedAutoencoder, rng, 100, 100, System.out);

    int labels = iter.inputColumns();
    int lstmLayerSize = 200;
    int bttLength = 50; // truncated-BPTT window length

    // Set up network configuration. Fix: the original chain called
    // iterations(1) twice; the redundant second call has been removed.
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).learningRate(0.1)
            .rmsDecay(0.95).seed(12345).regularization(true).l2(0.001).list()
            .layer(0,
                    new GravesLSTM.Builder().nIn(labels).nOut(lstmLayerSize).updater(Updater.RMSPROP)
                            .activation(Activation.TANH).weightInit(WeightInit.DISTRIBUTION)
                            .dist(new UniformDistribution(-0.08, 0.08)).build())
            .layer(1,
                    new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize).updater(Updater.RMSPROP)
                            .activation(Activation.TANH).weightInit(WeightInit.DISTRIBUTION)
                            .dist(new UniformDistribution(-0.08, 0.08)).build())
            .layer(2,
                    new RnnOutputLayer.Builder().nIn(lstmLayerSize).nOut(labels).lossFunction(LossFunction.MSE)
                            .updater(Updater.RMSPROP).weightInit(WeightInit.DISTRIBUTION)
                            .dist(new UniformDistribution(-0.08, 0.08)).build())
            .pretrain(false).backprop(true).backpropType(BackpropType.TruncatedBPTT)
            .tBPTTForwardLength(bttLength).tBPTTBackwardLength(bttLength).build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(new ScoreIterationListener(System.out));

    // Train for 300 epochs, checkpointing the model after each one.
    for (int epoch = 0; epoch < 300; epoch++) {
        model.fit(iter);
        iter.reset();
        evaluateModel(model, stackedAutoencoder, rng, epoch);
        ModelSerializer.writeModel(model, "stack-timeseries.rnn", true);
    }
}

From source file:org.ensor.fftmusings.rnn.qft.SampleLSTM.java

/**
 * Returns the LSTM network stored at modelFilename when one exists;
 * otherwise builds a fresh two-layer GravesLSTM network sized from the
 * iterator, serializes it to modelFilename, and returns it.
 *
 * @param modelFilename file the model is loaded from / saved to
 * @param iter data iterator that determines input and output sizes
 * @return the loaded or freshly built network
 * @throws IOException if model serialization fails
 */
public static MultiLayerNetwork create(File modelFilename, DataSetIterator iter) throws IOException {

    // Reuse a previously serialized model when one is available.
    if (modelFilename.exists()) {
        return load(modelFilename);
    }

    int outputCount = iter.totalOutcomes();

    // Fresh configuration: two stacked LSTM layers feeding a softmax output.
    MultiLayerConfiguration networkConf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .learningRate(0.01)
            .rmsDecay(0.95)
            .seed(12345)
            .regularization(true)
            .l2(0.001)
            .list()
            .layer(0, new GravesLSTM.Builder()
                    .nIn(iter.inputColumns())
                    .nOut(lstmLayerSize)
                    .updater(Updater.RMSPROP)
                    .activation(Activation.TANH)
                    .weightInit(WeightInit.DISTRIBUTION)
                    .dist(new UniformDistribution(-0.08, 0.08))
                    .build())
            .layer(1, new GravesLSTM.Builder()
                    .nIn(lstmLayerSize)
                    .nOut(lstmLayerSize)
                    .updater(Updater.RMSPROP)
                    .activation(Activation.TANH)
                    .weightInit(WeightInit.DISTRIBUTION)
                    .dist(new UniformDistribution(-0.08, 0.08))
                    .build())
            .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX) // MCXENT + softmax for classification
                    .updater(Updater.RMSPROP)
                    .nIn(lstmLayerSize)
                    .nOut(outputCount)
                    .weightInit(WeightInit.DISTRIBUTION)
                    .dist(new UniformDistribution(-0.08, 0.08))
                    .build())
            .pretrain(false)
            .backprop(true)
            .build();

    MultiLayerNetwork network = new MultiLayerNetwork(networkConf);
    network.init();
    network.setListeners(new ScoreIterationListener());

    // Persist so subsequent calls take the fast path above.
    ModelSerializer.writeModel(network, modelFilename, true);

    return network;
}

From source file:org.ensor.fftmusings.rnn.RNNFactory.java

/**
 * Returns a character-level LSTM network: restored from modelFilename when it
 * already exists, otherwise freshly built from the iterator's dimensions,
 * written to modelFilename, and returned.
 *
 * @param modelFilename file the model is restored from / written to
 * @param iter character iterator that determines input and output sizes
 * @return the restored or freshly built network
 * @throws IOException if model (de)serialization fails
 */
public static MultiLayerNetwork create(File modelFilename, CharacterIterator iter) throws IOException {

    // Fast path: restore a previously trained model from disk.
    if (modelFilename.exists()) {
        MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(modelFilename);
        restored.clear();
        restored.setListeners(new ScoreIterationListener(System.out));
        return restored;
    }

    int outputCount = iter.totalOutcomes();

    // Two stacked GravesLSTM layers feeding a softmax RNN output layer,
    // trained with truncated back-propagation through time.
    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .learningRate(0.1)
            .rmsDecay(0.95)
            .seed(12345)
            .regularization(true)
            .l2(0.001)
            .list()
            .layer(0, new GravesLSTM.Builder()
                    .nIn(iter.inputColumns())
                    .nOut(lstmLayerSize)
                    .updater(Updater.RMSPROP)
                    .activation(Activation.TANH)
                    .weightInit(WeightInit.DISTRIBUTION)
                    .dist(new UniformDistribution(-0.08, 0.08))
                    .build())
            .layer(1, new GravesLSTM.Builder()
                    .nIn(lstmLayerSize)
                    .nOut(lstmLayerSize)
                    .updater(Updater.RMSPROP)
                    .activation(Activation.TANH)
                    .weightInit(WeightInit.DISTRIBUTION)
                    .dist(new UniformDistribution(-0.08, 0.08))
                    .build())
            .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX) // MCXENT + softmax for classification
                    .updater(Updater.RMSPROP)
                    .nIn(lstmLayerSize)
                    .nOut(outputCount)
                    .weightInit(WeightInit.DISTRIBUTION)
                    .dist(new UniformDistribution(-0.08, 0.08))
                    .build())
            .pretrain(false)
            .backprop(true)
            .backpropType(BackpropType.TruncatedBPTT)
            .build();

    MultiLayerNetwork freshNet = new MultiLayerNetwork(configuration);
    freshNet.init();
    freshNet.setListeners(new ScoreIterationListener(System.out));

    // Serialize so the next call takes the fast path above.
    ModelSerializer.writeModel(freshNet, modelFilename, true);

    return freshNet;
}

From source file:org.wso2.carbon.ml.rest.api.neuralNetworks.FeedForwardNetwork.java

License:Open Source License

/**
 * method to map user selected WeightInit Algorithm to WeightInit object.
 * @param weightinit/*from w w w  .  j a  v a2 s  .c o m*/
 * @return an WeightInit object .
 */
/**
 * Maps a user-selected weight initialization algorithm name to the
 * corresponding {@link WeightInit} constant.
 *
 * @param weightinit the user-selected algorithm name (e.g. "Xavier")
 * @return the matching WeightInit constant, or null when the name is unrecognized
 */
WeightInit mapWeightInit(String weightinit) {

    if (weightinit.equals("Distribution")) {
        return WeightInit.DISTRIBUTION;
    }
    if (weightinit.equals("Normalized")) {
        return WeightInit.NORMALIZED;
    }
    if (weightinit.equals("Size")) {
        return WeightInit.SIZE;
    }
    if (weightinit.equals("Uniform")) {
        return WeightInit.UNIFORM;
    }
    if (weightinit.equals("Vi")) {
        return WeightInit.VI;
    }
    if (weightinit.equals("Zero")) {
        return WeightInit.ZERO;
    }
    if (weightinit.equals("Xavier")) {
        return WeightInit.XAVIER;
    }
    if (weightinit.equals("RELU")) {
        return WeightInit.RELU;
    }

    // Unknown name: the caller is expected to handle null.
    return null;
}