Example usage for org.deeplearning4j.nn.conf MultiLayerConfiguration toJson

List of usage examples for org.deeplearning4j.nn.conf MultiLayerConfiguration toJson

Introduction

In this page you can find the example usage for org.deeplearning4j.nn.conf MultiLayerConfiguration toJson.

Prototype

public String toJson() 

Source Link

Usage

From source file: weka.classifiers.functions.Dl4jMlpClassifier.java

License: Open Source License

/**
 * The method used to train the classifier.
 *
 * <p>Validates the layer configuration, preprocesses the data (missing-value
 * replacement, nominal-to-binary conversion, normalization/standardization of
 * all attributes including the class), wires the layers into a DL4J
 * {@code MultiLayerConfiguration}, and trains the resulting network for the
 * configured number of epochs.
 *
 * @param data set of instances serving as training data
 * @throws Exception if something goes wrong in the training process
 */
@Override
public void buildClassifier(Instances data) throws Exception {
    // DL4J resolves resources through the context class loader; point it at
    // this plugin's loader for the duration of training and restore it after.
    ClassLoader orig = Thread.currentThread().getContextClassLoader();
    try {
        Thread.currentThread().setContextClassLoader(this.getClass().getClassLoader());
        // Can classifier handle the data?
        getCapabilities().testWithFail(data);

        // Check basic network structure: at least one layer, ending in an
        // output layer.
        if (m_layers.length == 0) {
            throw new Exception("No layers have been added!");
        }
        if (!(m_layers[m_layers.length - 1] instanceof OutputLayer)) {
            throw new Exception("Last layer in network must be an output layer!");
        }

        // Remove instances with missing class and check that instances and
        // predictor attributes remain. Work on a copy so the caller's data is
        // untouched.
        data = new Instances(data);
        data.deleteWithMissingClass();
        m_zeroR = null;
        if (data.numInstances() == 0 || data.numAttributes() < 2) {
            // Degenerate data: fall back to ZeroR and skip network training.
            m_zeroR = new ZeroR();
            m_zeroR.buildClassifier(data);
            return;
        }

        // Replace missing values
        m_replaceMissing = new ReplaceMissingValues();
        m_replaceMissing.setInputFormat(data);
        data = Filter.useFilter(data, m_replaceMissing);

        // Retrieve two different class values, used below to recover the
        // affine mapping the normalization filter applies to the class.
        double y0 = data.instance(0).classValue();
        int index = 1;
        while (index < data.numInstances() && data.instance(index).classValue() == y0) {
            index++;
        }
        if (index == data.numInstances()) {
            // degenerate case, all class values are equal
            // we don't want to deal with this, too much hassle
            throw new Exception("All class values are the same. At least two class values should be different");
        }
        double y1 = data.instance(index).classValue();

        // Replace nominal attributes by binary numeric attributes.
        m_nominalToBinary = new NominalToBinary();
        m_nominalToBinary.setInputFormat(data);
        data = Filter.useFilter(data, m_nominalToBinary);

        // Standardize or normalize (as requested), including the class
        if (m_standardizeInsteadOfNormalize) {
            m_normalize = new Standardize();
            m_normalize.setOptions(new String[] { "-unset-class-temporarily" });
        } else {
            m_normalize = new Normalize();
        }
        m_normalize.setInputFormat(data);
        data = Filter.useFilter(data, m_normalize);

        // Recover the inverse of the class transformation so predictions can
        // be mapped back: class' = (class - m_x0) / m_x1.
        double z0 = data.instance(0).classValue();
        double z1 = data.instance(index).classValue();
        // No division by zero: the filter applies an affine map with nonzero
        // scale (the class values are not all equal, since y0 != y1), so
        // y0 != y1 implies z0 != z1.
        m_x1 = (y0 - y1) / (z0 - z1);
        m_x0 = (y0 - m_x1 * z0); // = y1 - m_x1 * z1

        // Randomize the data, just in case
        Random rand = new Random(getSeed());
        data.randomize(rand);

        // Initialize random number generator for construction of network
        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
        if (getOptimizationAlgorithm() == null) {
            builder.setOptimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
        } else {
            builder.setOptimizationAlgo(getOptimizationAlgorithm());
        }
        builder.setSeed(rand.nextInt());

        // Construct the mlp configuration
        ListBuilder ip = builder.list(getLayers());
        int numInputAttributes = getDataSetIterator().getNumAttributes(data);

        // Connect up the layers appropriately
        for (int x = 0; x < m_layers.length; x++) {

            // Is this the first hidden layer?
            if (x == 0) {
                setNumIncoming(m_layers[x], numInputAttributes);
            } else {
                setNumIncoming(m_layers[x], getNumUnits(m_layers[x - 1]));
            }

            // Is this the output layer?
            if (x == m_layers.length - 1) {
                ((OutputLayer) m_layers[x]).setNOut(data.numClasses());
            }
            ip = ip.layer(x, m_layers[x]);
        }

        // If we have a convolutional network, tell DL4J the flat input's
        // image dimensions so it can insert the reshaping preprocessor.
        if (getDataSetIterator() instanceof ImageDataSetIterator) {
            ImageDataSetIterator idsi = (ImageDataSetIterator) getDataSetIterator();
            ip.setInputType(
                    InputType.convolutionalFlat(idsi.getWidth(), idsi.getHeight(), idsi.getNumChannels()));
        } else if (getDataSetIterator() instanceof ConvolutionalInstancesIterator) {
            ConvolutionalInstancesIterator cii = (ConvolutionalInstancesIterator) getDataSetIterator();
            ip.setInputType(InputType.convolutionalFlat(cii.getWidth(), cii.getHeight(), cii.getNumChannels()));
        }

        ip = ip.pretrain(false).backprop(true);

        MultiLayerConfiguration conf = ip.build();

        if (getDebug()) {
            System.err.println(conf.toJson());
        }

        // build the network
        m_model = new MultiLayerNetwork(conf);
        m_model.init();

        if (getDebug()) {
            System.err.println(m_model.conf().toYaml());
        }

        ArrayList<IterationListener> listeners = new ArrayList<IterationListener>();
        // Report the score roughly once per epoch. Guard against a zero
        // frequency: when the data set is smaller than one training batch,
        // the integer division would yield 0.
        listeners.add(new ScoreIterationListener(
                Math.max(1, data.numInstances() / getDataSetIterator().getTrainBatchSize())));

        // if the log file doesn't point to a directory, set up the listener
        if (getLogFile() != null && !getLogFile().isDirectory()) {
            int numMiniBatches = (int) Math
                    .ceil(((double) data.numInstances()) / ((double) getDataSetIterator().getTrainBatchSize()));
            listeners.add(new FileIterationListener(getLogFile().getAbsolutePath(), numMiniBatches));
        }

        m_model.setListeners(listeners);

        // Abusing the MultipleEpochsIterator because it splits the data into
        // batches
        DataSetIterator iter = getDataSetIterator().getIterator(data, getSeed());
        for (int i = 0; i < getNumEpochs(); i++) {
            m_model.fit(iter); // Note that this calls the reset() method of the
                               // iterator
            if (getDebug()) {
                m_log.info("*** Completed epoch {} ***", i + 1);
            }
            iter.reset();
        }
    } finally {
        // Always restore the caller's context class loader, even on failure.
        Thread.currentThread().setContextClassLoader(orig);
    }
}