Example usage for org.deeplearning4j.nn.conf MultiLayerConfiguration toString

List of usage examples for org.deeplearning4j.nn.conf MultiLayerConfiguration toString

Introduction

On this page you can find the example usage for org.deeplearning4j.nn.conf MultiLayerConfiguration toString.

Prototype

@Override
    public String toString() 

Source Link

Usage

From source file:weka.classifiers.functions.DL4JHomogenousMultiLayerClassifier.java

License: Open Source License

/**
 * Generates the classifier./*from w w  w  . j  a v a 2  s.c  om*/
 *
 * @param data set of instances serving as training data
 * @exception Exception if the classifier has not been generated successfully
 */
@Override
public void buildClassifier(Instances data) throws Exception {
    m_zeroR = null;

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    if (data.numInstances() == 0 || data.numAttributes() < 2) {
        m_zeroR.buildClassifier(data);
        return;
    }

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    m_replaceMissing = new ReplaceMissingValues();
    m_replaceMissing.setInputFormat(data);
    data = Filter.useFilter(data, m_replaceMissing);
    if (m_standardizeInsteadOfNormalize) {
        m_normalize = new Standardize();
    } else {
        m_normalize = new Normalize();
    }
    m_normalize.setInputFormat(data);
    data = Filter.useFilter(data, m_normalize);
    m_nominalToBinary = new NominalToBinary();
    m_nominalToBinary.setInputFormat(data);
    data = Filter.useFilter(data, m_nominalToBinary);
    data.randomize(new Random(123));

    DataSet dataset = Utils.instancesToDataSet(data);

    m_layerConfig = new NeuralNetConfiguration();
    RandomGenerator gen = new MersenneTwister(123);
    m_layerConfig.setRng(gen);
    m_layerConfig.setBatchSize(data.numInstances());
    m_layerConfig.setNumIterations(m_iterations);
    m_layerConfig.setLr(m_learningRate);
    m_layerConfig.setMomentum(m_momentum);
    m_layerConfig.setDropOut(m_dropOut);
    m_layerConfig.setConstrainGradientToUnitNorm(m_constrainGradientToUnitNorm);
    m_layerConfig.setUseRegularization(m_useRegularization);
    m_layerConfig.setL2(m_l2);
    m_layerConfig.setK(m_k);
    m_layerType.configureLayerFactory(m_layerConfig);
    m_activationFunction.configureActivationFunction(m_layerConfig);
    m_layerConfig.setOptimizationAlgo(m_optimizationStrategy);
    m_layerConfig.setLossFunction(m_lossFunction);
    m_layerConfig.setVisibleUnit(m_rbmVisibleUnit);
    m_layerConfig.setHiddenUnit(m_rbmHiddenUnit);

    m_layerConfig.setWeightInit(m_weightInit);

    // TODO Distributions (for sampling?). These are commons.math3 in this
    // version of DL4J but in master they have moved to
    // org.nd4j.linalg.api.rng.distribution.Distribution.

    m_layerConfig.setnIn(data.numAttributes() - 1);
    m_layerConfig.setnOut(data.classAttribute().numValues());

    if (m_hiddenLayerSizes == null || m_hiddenLayerSizes.length() == 0) {
        throw new Exception("Must specify at least one hidden layer");
    }
    String[] split = m_hiddenLayerSizes.split(",");
    int size = split.length;
    if (size < 1) {
        throw new Exception("Must have at least one hidden layer in the network");
    }
    int[] layerSizes = new int[size];
    size++; // +1 for the output layer
    List<NeuralNetConfiguration> layerConfs = new ArrayList<NeuralNetConfiguration>();
    for (int i = 0; i < size; i++) {
        layerConfs.add(m_layerConfig.clone());
        if (i < size - 1) {
            try {
                layerSizes[i] = Integer.parseInt(split[i]);
            } catch (NumberFormatException e) {
                throw new Exception(e);
            }
        }
    }

    // make sure our net is a classifier!
    layerConfs.get(size - 1).setWeightInit(WeightInit.ZERO);
    layerConfs.get(size - 1).setLayerFactory(LayerFactories.getFactory(OutputLayer.class));
    layerConfs.get(size - 1).setActivationFunction(new SoftMax());
    layerConfs.get(size - 1).setLossFunction(LossFunctions.LossFunction.MCXENT);

    System.err.println("Layer conf:\n" + m_layerConfig.toString() + "\n\n");

    MultiLayerConfiguration.Builder multiBuilder = new MultiLayerConfiguration.Builder();
    multiBuilder.confs(layerConfs);
    multiBuilder.hiddenLayerSizes(layerSizes);
    multiBuilder.useDropConnect(m_useDropConnect);
    multiBuilder.pretrain(!m_dontPretrain);
    multiBuilder.useRBMPropUpAsActivations(!m_dontUseRBMPropUpAsActivations);
    multiBuilder.dampingFactor(m_dampingFactor);
    MultiLayerConfiguration multiConf = multiBuilder.build();

    System.err.println("Multilayer conf:\n" + multiConf.toString());

    m_network = new MultiLayerNetwork(multiConf);
    m_network.fit(dataset);
}