Example usage for org.deeplearning4j.nn.conf MultiLayerConfiguration.Builder confs

List of usage examples for org.deeplearning4j.nn.conf MultiLayerConfiguration.Builder confs

Introduction

In this page you can find the example usage for org.deeplearning4j.nn.conf MultiLayerConfiguration.Builder confs.

Prototype

List confs

To view the source code for org.deeplearning4j.nn.conf MultiLayerConfiguration.Builder confs, click the Source Link below.

Usage

From source file:weka.classifiers.functions.DL4JHomogenousMultiLayerClassifier.java

License:Open Source License

/**
 * Generates the classifier: filters the training data (missing-value
 * replacement, normalization/standardization, nominal-to-binary), converts it
 * to a DL4J {@code DataSet}, builds one {@code NeuralNetConfiguration} per
 * layer, and trains a {@code MultiLayerNetwork}.
 *
 * @param data set of instances serving as training data
 * @exception Exception if the classifier has not been generated successfully
 */
@Override
public void buildClassifier(Instances data) throws Exception {
    m_zeroR = null;

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    // Degenerate data (no instances left, or only the class attribute):
    // fall back to ZeroR. BUG FIX: the original code dereferenced m_zeroR
    // while it was still null; it must be instantiated before use.
    if (data.numInstances() == 0 || data.numAttributes() < 2) {
        m_zeroR = new weka.classifiers.rules.ZeroR();
        m_zeroR.buildClassifier(data);
        return;
    }

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // Preprocessing pipeline; each filter is kept in a field so the same
    // transformation can be applied at prediction time.
    m_replaceMissing = new ReplaceMissingValues();
    m_replaceMissing.setInputFormat(data);
    data = Filter.useFilter(data, m_replaceMissing);
    if (m_standardizeInsteadOfNormalize) {
        m_normalize = new Standardize();
    } else {
        m_normalize = new Normalize();
    }
    m_normalize.setInputFormat(data);
    data = Filter.useFilter(data, m_normalize);
    m_nominalToBinary = new NominalToBinary();
    m_nominalToBinary.setInputFormat(data);
    data = Filter.useFilter(data, m_nominalToBinary);
    // Fixed seed for reproducible shuffling across runs.
    data.randomize(new Random(123));

    DataSet dataset = Utils.instancesToDataSet(data);

    // Template configuration cloned for every layer below. Full-batch
    // training: batch size equals the number of training instances.
    m_layerConfig = new NeuralNetConfiguration();
    RandomGenerator gen = new MersenneTwister(123);
    m_layerConfig.setRng(gen);
    m_layerConfig.setBatchSize(data.numInstances());
    m_layerConfig.setNumIterations(m_iterations);
    m_layerConfig.setLr(m_learningRate);
    m_layerConfig.setMomentum(m_momentum);
    m_layerConfig.setDropOut(m_dropOut);
    m_layerConfig.setConstrainGradientToUnitNorm(m_constrainGradientToUnitNorm);
    m_layerConfig.setUseRegularization(m_useRegularization);
    m_layerConfig.setL2(m_l2);
    m_layerConfig.setK(m_k);
    m_layerType.configureLayerFactory(m_layerConfig);
    m_activationFunction.configureActivationFunction(m_layerConfig);
    m_layerConfig.setOptimizationAlgo(m_optimizationStrategy);
    m_layerConfig.setLossFunction(m_lossFunction);
    m_layerConfig.setVisibleUnit(m_rbmVisibleUnit);
    m_layerConfig.setHiddenUnit(m_rbmHiddenUnit);

    m_layerConfig.setWeightInit(m_weightInit);

    // TODO Distributions (for sampling?). These are commons.math3 in this
    // version of DL4J but in master they have moved to
    // org.nd4j.linalg.api.rng.distribution.Distribution.

    // Inputs = all attributes minus the class; outputs = class values.
    m_layerConfig.setnIn(data.numAttributes() - 1);
    m_layerConfig.setnOut(data.classAttribute().numValues());

    if (m_hiddenLayerSizes == null || m_hiddenLayerSizes.length() == 0) {
        throw new Exception("Must specify at least one hidden layer");
    }
    // Comma-separated hidden layer sizes, e.g. "10,5".
    String[] split = m_hiddenLayerSizes.split(",");
    int size = split.length;
    if (size < 1) {
        throw new Exception("Must have at least one hidden layer in the network");
    }
    // layerSizes holds only the HIDDEN layer sizes (length == split.length);
    // layerConfs additionally contains a config for the output layer, hence
    // the size++ below and the i < size - 1 guard inside the loop.
    int[] layerSizes = new int[size];
    size++; // +1 for the output layer
    List<NeuralNetConfiguration> layerConfs = new ArrayList<NeuralNetConfiguration>();
    for (int i = 0; i < size; i++) {
        layerConfs.add(m_layerConfig.clone());
        if (i < size - 1) {
            try {
                layerSizes[i] = Integer.parseInt(split[i]);
            } catch (NumberFormatException e) {
                // Preserve the offending token so the user can fix the option.
                throw new Exception("Invalid hidden layer size: \"" + split[i] + "\"", e);
            }
        }
    }

    // make sure our net is a classifier! The last config becomes a softmax
    // output layer trained with multi-class cross-entropy.
    layerConfs.get(size - 1).setWeightInit(WeightInit.ZERO);
    layerConfs.get(size - 1).setLayerFactory(LayerFactories.getFactory(OutputLayer.class));
    layerConfs.get(size - 1).setActivationFunction(new SoftMax());
    layerConfs.get(size - 1).setLossFunction(LossFunctions.LossFunction.MCXENT);

    System.err.println("Layer conf:\n" + m_layerConfig.toString() + "\n\n");

    // Assemble the per-layer configs into a multi-layer network and train.
    MultiLayerConfiguration.Builder multiBuilder = new MultiLayerConfiguration.Builder();
    multiBuilder.confs(layerConfs);
    multiBuilder.hiddenLayerSizes(layerSizes);
    multiBuilder.useDropConnect(m_useDropConnect);
    multiBuilder.pretrain(!m_dontPretrain);
    multiBuilder.useRBMPropUpAsActivations(!m_dontUseRBMPropUpAsActivations);
    multiBuilder.dampingFactor(m_dampingFactor);
    MultiLayerConfiguration multiConf = multiBuilder.build();

    System.err.println("Multilayer conf:\n" + multiConf.toString());

    m_network = new MultiLayerNetwork(multiConf);
    m_network.fit(dataset);
}