List of usage examples for org.deeplearning4j.nn.conf.MultiLayerConfiguration#toJson()
public String toJson()
From source file: weka.classifiers.functions.Dl4jMlpClassifier.java
License: Open Source License
/**
 * The method used to train the classifier.
 *
 * <p>Pipeline: validate the layer stack, clean and filter the training data
 * (missing-value replacement, nominal-to-binary encoding, normalization or
 * standardization), then assemble a DL4J {@code MultiLayerConfiguration} from
 * {@code m_layers}, build the network, and fit it for the configured number of
 * epochs. The context class loader is swapped to this class's loader for the
 * duration of training so DL4J resource lookups resolve correctly, and is
 * restored in the {@code finally} block.
 *
 * @param data set of instances serving as training data
 * @throws Exception if something goes wrong in the training process
 */
@Override
public void buildClassifier(Instances data) throws Exception {
    // Remember the caller's context class loader so it can be restored after training.
    ClassLoader orig = Thread.currentThread().getContextClassLoader();
    try {
        Thread.currentThread().setContextClassLoader(this.getClass().getClassLoader());

        // Can classifier handle the data?
        getCapabilities().testWithFail(data);

        // Check basic network structure: at least one layer, and the last one
        // must be an output layer (it will receive the class-count below).
        if (m_layers.length == 0) {
            throw new Exception("No layers have been added!");
        }
        if (!(m_layers[m_layers.length - 1] instanceof OutputLayer)) {
            throw new Exception("Last layer in network must be an output layer!");
        }

        // Remove instances with missing class and check that instances and
        // predictor attributes remain. Copy first so the caller's data is untouched.
        data = new Instances(data);
        data.deleteWithMissingClass();
        m_zeroR = null;
        if (data.numInstances() == 0 || data.numAttributes() < 2) {
            // Degenerate data: fall back to ZeroR (majority/mean predictor).
            m_zeroR = new ZeroR();
            m_zeroR.buildClassifier(data);
            return;
        }

        // Replace missing values (filter is kept in m_replaceMissing so the
        // same transformation can be applied at prediction time).
        m_replaceMissing = new ReplaceMissingValues();
        m_replaceMissing.setInputFormat(data);
        data = Filter.useFilter(data, m_replaceMissing);

        // Retrieve two different class values; they are used below to recover
        // the linear mapping that undoes the class normalization.
        double y0 = data.instance(0).classValue();
        int index = 1;
        while (index < data.numInstances() && data.instance(index).classValue() == y0) {
            index++;
        }
        if (index == data.numInstances()) {
            // degenerate case, all class values are equal
            // we don't want to deal with this, too much hassle
            throw new Exception("All class values are the same. At least two class values should be different");
        }
        double y1 = data.instance(index).classValue();

        // Replace nominal attributes by binary numeric attributes.
        m_nominalToBinary = new NominalToBinary();
        m_nominalToBinary.setInputFormat(data);
        data = Filter.useFilter(data, m_nominalToBinary);

        // Standardize or normalize (as requested), including the class.
        if (m_standardizeInsteadOfNormalize) {
            m_normalize = new Standardize();
            m_normalize.setOptions(new String[] { "-unset-class-temporarily" });
        } else {
            m_normalize = new Normalize();
        }
        m_normalize.setInputFormat(data);
        data = Filter.useFilter(data, m_normalize);

        // Fit the inverse linear map y = m_x0 + m_x1 * z from the transformed
        // class values (z0, z1) back to the original scale (y0, y1).
        // NOTE(review): the original asserted "no division by zero, since
        // y0 != y1 guaranteed => z0 != z1" — this holds only if the class
        // transformation above is strictly monotone; confirm for Standardize
        // with a zero-variance class.
        double z0 = data.instance(0).classValue();
        double z1 = data.instance(index).classValue();
        m_x1 = (y0 - y1) / (z0 - z1);
        m_x0 = (y0 - m_x1 * z0); // = y1 - m_x1 * z1

        // Randomize the data, just in case.
        Random rand = new Random(getSeed());
        data.randomize(rand);

        // Initialize the network configuration builder; default to SGD when
        // no optimization algorithm was configured.
        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
        if (getOptimizationAlgorithm() == null) {
            builder.setOptimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
        } else {
            builder.setOptimizationAlgo(getOptimizationAlgorithm());
        }
        builder.setSeed(rand.nextInt());

        // Construct the mlp configuration.
        ListBuilder ip = builder.list(getLayers());
        int numInputAttributes = getDataSetIterator().getNumAttributes(data);

        // Connect up the layers appropriately: the first layer receives the
        // input attributes, each later layer the units of its predecessor,
        // and the output layer is sized to the number of classes.
        for (int x = 0; x < m_layers.length; x++) {
            // Is this the first hidden layer?
            if (x == 0) {
                setNumIncoming(m_layers[x], numInputAttributes);
            } else {
                setNumIncoming(m_layers[x], getNumUnits(m_layers[x - 1]));
            }
            // Is this the output layer?
            if (x == m_layers.length - 1) {
                ((OutputLayer) m_layers[x]).setNOut(data.numClasses());
            }
            ip = ip.layer(x, m_layers[x]);
        }

        // If we have a convolutional network, declare the flat image input
        // shape so DL4J can wire the convolutional layers.
        if (getDataSetIterator() instanceof ImageDataSetIterator) {
            ImageDataSetIterator idsi = (ImageDataSetIterator) getDataSetIterator();
            ip.setInputType(
                    InputType.convolutionalFlat(idsi.getWidth(), idsi.getHeight(), idsi.getNumChannels()));
        } else if (getDataSetIterator() instanceof ConvolutionalInstancesIterator) {
            ConvolutionalInstancesIterator cii = (ConvolutionalInstancesIterator) getDataSetIterator();
            ip.setInputType(InputType.convolutionalFlat(cii.getWidth(), cii.getHeight(), cii.getNumChannels()));
        }

        // Supervised training only: no unsupervised pretraining, backprop on.
        ip = ip.pretrain(false).backprop(true);
        MultiLayerConfiguration conf = ip.build();
        if (getDebug()) {
            System.err.println(conf.toJson());
        }

        // Build the network.
        m_model = new MultiLayerNetwork(conf);
        m_model.init();
        if (getDebug()) {
            System.err.println(m_model.conf().toYaml());
        }

        // Score listener fires roughly once per epoch (every numInstances/batchSize iterations).
        ArrayList<IterationListener> listeners = new ArrayList<IterationListener>();
        listeners.add(
                new ScoreIterationListener(data.numInstances() / getDataSetIterator().getTrainBatchSize()));

        // If the log file doesn't point to a directory, set up the listener.
        if (getLogFile() != null && !getLogFile().isDirectory()) {
            int numMiniBatches = (int) Math
                    .ceil(((double) data.numInstances()) / ((double) getDataSetIterator().getTrainBatchSize()));
            listeners.add(new FileIterationListener(getLogFile().getAbsolutePath(), numMiniBatches));
        }
        m_model.setListeners(listeners);

        // Abusing the MultipleEpochsIterator because it splits the data into
        // batches.
        DataSetIterator iter = getDataSetIterator().getIterator(data, getSeed());
        for (int i = 0; i < getNumEpochs(); i++) {
            m_model.fit(iter); // Note that this calls the reset() method of the
                               // iterator
            if (getDebug()) {
                m_log.info("*** Completed epoch {} ***", i + 1);
            }
            iter.reset();
        }
    } finally {
        // Always restore the caller's context class loader.
        Thread.currentThread().setContextClassLoader(orig);
    }
}