List of usage examples for org.deeplearning4j.nn.conf MultiLayerConfiguration toString
@Override
public String toString()
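Before the full source-file example, a minimal, self-contained sketch of the call. This assumes a recent DL4J release (the 1.0.0-beta builder API, which differs from the legacy 0.0.3.x setter API used in the Weka example below); the class name MultiLayerConfigurationToStringExample and the layer sizes are illustrative only.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class MultiLayerConfigurationToStringExample {
    public static void main(String[] args) {
        // Build a small two-layer classification network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(123)
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(4).nOut(10)
                        .activation(Activation.RELU).build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(10).nOut(3)
                        .activation(Activation.SOFTMAX).build())
                .build();

        // toString() renders the full configuration (in recent DL4J releases
        // it delegates to toJson()), which is useful for logging and debugging.
        System.out.println(conf.toString());
    }
}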
From source file:weka.classifiers.functions.DL4JHomogenousMultiLayerClassifier.java
License: Open Source License
/**
 * Generates the classifier.
 *
 * @param data set of instances serving as training data
 * @exception Exception if the classifier has not been generated successfully
 */
@Override
public void buildClassifier(Instances data) throws Exception {
    m_zeroR = null;

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    if (data.numInstances() == 0 || data.numAttributes() < 2) {
        // not enough training data: fall back to a ZeroR model
        m_zeroR = new weka.classifiers.rules.ZeroR();
        m_zeroR.buildClassifier(data);
        return;
    }

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // preprocess: impute missing values, rescale, and binarize nominals
    m_replaceMissing = new ReplaceMissingValues();
    m_replaceMissing.setInputFormat(data);
    data = Filter.useFilter(data, m_replaceMissing);

    if (m_standardizeInsteadOfNormalize) {
        m_normalize = new Standardize();
    } else {
        m_normalize = new Normalize();
    }
    m_normalize.setInputFormat(data);
    data = Filter.useFilter(data, m_normalize);

    m_nominalToBinary = new NominalToBinary();
    m_nominalToBinary.setInputFormat(data);
    data = Filter.useFilter(data, m_nominalToBinary);

    data.randomize(new Random(123));
    DataSet dataset = Utils.instancesToDataSet(data);

    // base configuration shared by all layers
    m_layerConfig = new NeuralNetConfiguration();
    RandomGenerator gen = new MersenneTwister(123);
    m_layerConfig.setRng(gen);
    m_layerConfig.setBatchSize(data.numInstances());
    m_layerConfig.setNumIterations(m_iterations);
    m_layerConfig.setLr(m_learningRate);
    m_layerConfig.setMomentum(m_momentum);
    m_layerConfig.setDropOut(m_dropOut);
    m_layerConfig.setConstrainGradientToUnitNorm(m_constrainGradientToUnitNorm);
    m_layerConfig.setUseRegularization(m_useRegularization);
    m_layerConfig.setL2(m_l2);
    m_layerConfig.setK(m_k);
    m_layerType.configureLayerFactory(m_layerConfig);
    m_activationFunction.configureActivationFunction(m_layerConfig);
    m_layerConfig.setOptimizationAlgo(m_optimizationStrategy);
    m_layerConfig.setLossFunction(m_lossFunction);
    m_layerConfig.setVisibleUnit(m_rbmVisibleUnit);
    m_layerConfig.setHiddenUnit(m_rbmHiddenUnit);
    m_layerConfig.setWeightInit(m_weightInit);
    // TODO Distributions (for sampling?). These are commons.math3 in this
    // version of DL4J but in master they have moved to
    // org.nd4j.linalg.api.rng.distribution.Distribution.
    m_layerConfig.setnIn(data.numAttributes() - 1);
    m_layerConfig.setnOut(data.classAttribute().numValues());

    // parse the comma-separated hidden layer sizes
    if (m_hiddenLayerSizes == null || m_hiddenLayerSizes.length() == 0) {
        throw new Exception("Must specify at least one hidden layer");
    }
    String[] split = m_hiddenLayerSizes.split(",");
    int size = split.length;
    if (size < 1) {
        throw new Exception("Must have at least one hidden layer in the network");
    }
    int[] layerSizes = new int[size];
    size++; // +1 for the output layer

    // one configuration per layer, cloned from the shared base configuration
    List<NeuralNetConfiguration> layerConfs = new ArrayList<NeuralNetConfiguration>();
    for (int i = 0; i < size; i++) {
        layerConfs.add(m_layerConfig.clone());
        if (i < size - 1) {
            try {
                layerSizes[i] = Integer.parseInt(split[i]);
            } catch (NumberFormatException e) {
                throw new Exception(e);
            }
        }
    }

    // make sure our net is a classifier!
    layerConfs.get(size - 1).setWeightInit(WeightInit.ZERO);
    layerConfs.get(size - 1).setLayerFactory(LayerFactories.getFactory(OutputLayer.class));
    layerConfs.get(size - 1).setActivationFunction(new SoftMax());
    layerConfs.get(size - 1).setLossFunction(LossFunctions.LossFunction.MCXENT);

    System.err.println("Layer conf:\n" + m_layerConfig.toString() + "\n\n");

    // assemble the per-layer configurations into a multilayer network
    MultiLayerConfiguration.Builder multiBuilder = new MultiLayerConfiguration.Builder();
    multiBuilder.confs(layerConfs);
    multiBuilder.hiddenLayerSizes(layerSizes);
    multiBuilder.useDropConnect(m_useDropConnect);
    multiBuilder.pretrain(!m_dontPretrain);
    multiBuilder.useRBMPropUpAsActivations(!m_dontUseRBMPropUpAsActivations);
    multiBuilder.dampingFactor(m_dampingFactor);
    MultiLayerConfiguration multiConf = multiBuilder.build();
    System.err.println("Multilayer conf:\n" + multiConf.toString());

    m_network = new MultiLayerNetwork(multiConf);
    m_network.fit(dataset);
}
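Note how the example uses toString() purely as a debugging aid: the per-layer base configuration is printed before assembly, and the finished MultiLayerConfiguration is printed just before training. Since MultiLayerConfiguration.toString() returns the serialized form of the configuration (in recent DL4J releases it delegates to toJson()), the "Multilayer conf" output is a complete, human-readable dump of the network's hyperparameters, which makes it a convenient sanity check before calling fit().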