Example usage for org.deeplearning4j.nn.weights WeightInit ZERO

List of usage examples for org.deeplearning4j.nn.weights WeightInit ZERO

Introduction

In this page you can find the example usage for org.deeplearning4j.nn.weights WeightInit ZERO.

Prototype

WeightInit ZERO

To view the source code for org.deeplearning4j.nn.weights WeightInit ZERO, click the Source Link below.

Click Source Link

Usage

From source file: org.wso2.carbon.ml.rest.api.neuralNetworks.FeedForwardNetwork.java

License: Open Source License

/**
 * method to map user selected WeightInit Algorithm to WeightInit object.
 * @param weightinit/*from   ww  w  .j  a v  a  2  s.co  m*/
 * @return an WeightInit object .
 */
WeightInit mapWeightInit(String weightinit) {

    WeightInit weightInitAlgo = null;

    switch (weightinit) {

    case "Distribution":
        weightInitAlgo = WeightInit.DISTRIBUTION;
        break;

    case "Normalized":
        weightInitAlgo = WeightInit.NORMALIZED;
        break;

    case "Size":
        weightInitAlgo = WeightInit.SIZE;
        break;

    case "Uniform":
        weightInitAlgo = WeightInit.UNIFORM;
        break;

    case "Vi":
        weightInitAlgo = WeightInit.VI;
        break;

    case "Zero":
        weightInitAlgo = WeightInit.ZERO;
        break;

    case "Xavier":
        weightInitAlgo = WeightInit.XAVIER;
        break;

    case "RELU":
        weightInitAlgo = WeightInit.RELU;
        break;

    default:
        weightInitAlgo = null;
        break;
    }

    return weightInitAlgo;
}

From source file: weka.classifiers.functions.DL4JHomogenousMultiLayerClassifier.java

License: Open Source License

/**
 * Generates the classifier: filters the training instances (missing-value
 * replacement, normalisation or standardisation, nominal-to-binary
 * conversion), builds one {@code NeuralNetConfiguration} per layer from the
 * configured hyper-parameters, forces the final layer to be a softmax
 * output layer, and trains a {@code MultiLayerNetwork} on the result.
 *
 * @param data set of instances serving as training data
 * @exception Exception if the classifier has not been generated successfully
 */
@Override
public void buildClassifier(Instances data) throws Exception {
    m_zeroR = null;

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    if (data.numInstances() == 0 || data.numAttributes() < 2) {
        // NOTE(review): m_zeroR is still null here (set just above), so this
        // call throws a NullPointerException. The usual Weka fallback pattern
        // is to instantiate the default model first, e.g.
        //   m_zeroR = new ZeroR(); m_zeroR.buildClassifier(data);
        // — confirm against the rest of this class and fix.
        m_zeroR.buildClassifier(data);
        return;
    }

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // Preprocessing pipeline: impute missing values, scale numeric
    // attributes, then binarise nominal attributes. The fitted filters are
    // kept in fields so the same transforms can be applied at prediction time.
    m_replaceMissing = new ReplaceMissingValues();
    m_replaceMissing.setInputFormat(data);
    data = Filter.useFilter(data, m_replaceMissing);
    if (m_standardizeInsteadOfNormalize) {
        m_normalize = new Standardize();
    } else {
        m_normalize = new Normalize();
    }
    m_normalize.setInputFormat(data);
    data = Filter.useFilter(data, m_normalize);
    m_nominalToBinary = new NominalToBinary();
    m_nominalToBinary.setInputFormat(data);
    data = Filter.useFilter(data, m_nominalToBinary);
    // Fixed seed so training is reproducible across runs.
    data.randomize(new Random(123));

    DataSet dataset = Utils.instancesToDataSet(data);

    // Template configuration shared by every layer; it is cloned once per
    // layer below, then the output layer's clone is specialised.
    m_layerConfig = new NeuralNetConfiguration();
    RandomGenerator gen = new MersenneTwister(123);
    m_layerConfig.setRng(gen);
    // Full-batch training: batch size equals the number of instances.
    m_layerConfig.setBatchSize(data.numInstances());
    m_layerConfig.setNumIterations(m_iterations);
    m_layerConfig.setLr(m_learningRate);
    m_layerConfig.setMomentum(m_momentum);
    m_layerConfig.setDropOut(m_dropOut);
    m_layerConfig.setConstrainGradientToUnitNorm(m_constrainGradientToUnitNorm);
    m_layerConfig.setUseRegularization(m_useRegularization);
    m_layerConfig.setL2(m_l2);
    m_layerConfig.setK(m_k);
    m_layerType.configureLayerFactory(m_layerConfig);
    m_activationFunction.configureActivationFunction(m_layerConfig);
    m_layerConfig.setOptimizationAlgo(m_optimizationStrategy);
    m_layerConfig.setLossFunction(m_lossFunction);
    m_layerConfig.setVisibleUnit(m_rbmVisibleUnit);
    m_layerConfig.setHiddenUnit(m_rbmHiddenUnit);

    m_layerConfig.setWeightInit(m_weightInit);

    // TODO Distributions (for sampling?). These are commons.math3 in this
    // version of DL4J but in master they have moved to
    // org.nd4j.linalg.api.rng.distribution.Distribution.

    // Input width = attribute count minus the class attribute; output width
    // = number of class values.
    m_layerConfig.setnIn(data.numAttributes() - 1);
    m_layerConfig.setnOut(data.classAttribute().numValues());

    if (m_hiddenLayerSizes == null || m_hiddenLayerSizes.length() == 0) {
        throw new Exception("Must specify at least one hidden layer");
    }
    // Hidden layer sizes come in as a comma-separated string, e.g. "10,5".
    // Note String.split drops trailing empty tokens, so an input like ","
    // can still yield zero entries — hence the second check below.
    String[] split = m_hiddenLayerSizes.split(",");
    int size = split.length;
    if (size < 1) {
        throw new Exception("Must have at least one hidden layer in the network");
    }
    int[] layerSizes = new int[size];
    size++; // +1 for the output layer
    // One cloned configuration per layer (hidden layers + output layer);
    // only hidden layers (i < size - 1) have a user-supplied size to parse.
    List<NeuralNetConfiguration> layerConfs = new ArrayList<NeuralNetConfiguration>();
    for (int i = 0; i < size; i++) {
        layerConfs.add(m_layerConfig.clone());
        if (i < size - 1) {
            try {
                layerSizes[i] = Integer.parseInt(split[i]);
            } catch (NumberFormatException e) {
                throw new Exception(e);
            }
        }
    }

    // make sure our net is a classifier! Last layer: zero-initialised
    // softmax output trained with multi-class cross-entropy.
    layerConfs.get(size - 1).setWeightInit(WeightInit.ZERO);
    layerConfs.get(size - 1).setLayerFactory(LayerFactories.getFactory(OutputLayer.class));
    layerConfs.get(size - 1).setActivationFunction(new SoftMax());
    layerConfs.get(size - 1).setLossFunction(LossFunctions.LossFunction.MCXENT);

    System.err.println("Layer conf:\n" + m_layerConfig.toString() + "\n\n");

    MultiLayerConfiguration.Builder multiBuilder = new MultiLayerConfiguration.Builder();
    multiBuilder.confs(layerConfs);
    multiBuilder.hiddenLayerSizes(layerSizes);
    multiBuilder.useDropConnect(m_useDropConnect);
    multiBuilder.pretrain(!m_dontPretrain);
    multiBuilder.useRBMPropUpAsActivations(!m_dontUseRBMPropUpAsActivations);
    multiBuilder.dampingFactor(m_dampingFactor);
    MultiLayerConfiguration multiConf = multiBuilder.build();

    System.err.println("Multilayer conf:\n" + multiConf.toString());

    m_network = new MultiLayerNetwork(multiConf);
    m_network.fit(dataset);
}