List of usage examples for org.deeplearning4j.nn.multilayer MultiLayerNetwork getLayers
public synchronized Layer[] getLayers()
From source file:com.javafxpert.neuralnetviz.scenario.XorExample.java
License:Apache License
/**
 * Dumps a human-readable summary of the network to standard out: for each
 * layer, its parameter table followed by its weight and bias arrays.
 *
 * @param mln the network whose layers are printed
 */
static void displayNetwork(MultiLayerNetwork mln) {
    System.out.println("multiLayerNetwork:");
    for (Layer layer : mln.getLayers()) {
        System.out.println("layer # " + layer.paramTable());
        INDArray weights = layer.getParam(DefaultParamInitializer.WEIGHT_KEY);
        System.out.println("Weights: " + weights);
        INDArray bias = layer.getParam(DefaultParamInitializer.BIAS_KEY);
        System.out.println("Bias: " + bias);
    }
}
From source file:org.ensor.fftmusings.rnn.GravesLSTMCharModellingExample.java
public static void main(String[] args) throws Exception { int numEpochs = 30; //Total number of training + sample generation epochs String generationInitialization = null; //Optional character initialization; a random character is used if null int nSamplesToGenerate = 4; //Number of samples to generate after each training epoch int nCharactersToSample = 300; //Length of each sample to generate Random rng = new Random(12345); int miniBatchSize = 32; //Size of mini batch to use when training int examplesPerEpoch = 50 * miniBatchSize; //i.e., how many examples to learn on between generating samples int exampleLength = 100; //Length of each training example //Get a DataSetIterator that handles vectorization of text into something we can use to train // our GravesLSTM network. CharacterIterator iter = getShakespeareIterator(miniBatchSize, exampleLength, examplesPerEpoch); File modelFilename = new File("data/shakespere/shakespere.3.rnn"); MultiLayerNetwork net = RNNFactory.create(modelFilename, iter); //Print the number of parameters in the network (and for each layer) Layer[] layers = net.getLayers(); int totalNumParams = 0; for (int i = 0; i < layers.length; i++) { int nParams = layers[i].numParams(); System.out.println("Number of parameters in layer " + i + ": " + nParams); totalNumParams += nParams;//from w ww . j ava 2s. c o m } System.out.println("Total number of network parameters: " + totalNumParams); //Do training, and then generate and print samples from network for (int i = 0; i < numEpochs; i++) { net.fit(iter); System.out.println("--------------------"); System.out.println("Completed epoch " + i); System.out.println("Sampling characters from network given initialization \"" + (generationInitialization == null ? 
"" : generationInitialization) + "\""); for (int j = 0; j < nSamplesToGenerate; j++) { String samples = sampleCharactersFromNetwork2(generationInitialization, net, iter, rng, nCharactersToSample); System.out.println("----- Sample " + j + " -----"); System.out.println(samples); System.out.println(); } RNNFactory.persist(modelFilename, net); iter.reset(); //Reset iterator for another epoch } System.out.println("\n\nExample complete"); }
From source file:org.ensor.fftmusings.rnn2.GravesLSTMCharModellingExample.java
public static void main(String[] args) throws Exception { int lstmLayerSize = 200; //Number of units in each GravesLSTM layer int miniBatchSize = 32; //Size of mini batch to use when training int exampleLength = 1000; //Length of each training example sequence to use. This could certainly be increased int tbpttLength = 50; //Length for truncated backpropagation through time. i.e., do parameter updates ever 50 characters int numEpochs = 30; //Total number of training epochs int generateSamplesEveryNMinibatches = 10; //How frequently to generate samples from the network? 1000 characters / 50 tbptt length: 20 parameter updates per minibatch int nSamplesToGenerate = 4; //Number of samples to generate after each training epoch int nCharactersToSample = 300; //Length of each sample to generate String generationInitialization = null; //Optional character initialization; a random character is used if null // Above is Used to 'prime' the LSTM with a character sequence to continue/complete. // Initialization characters must all be in CharacterIterator.getMinimalCharacterSet() by default Random rng = new Random(12345); //Get a DataSetIterator that handles vectorization of text into something we can use to train // our GravesLSTM network. CharacterIterator iter = getShakespeareIterator(miniBatchSize, exampleLength); int nOut = iter.totalOutcomes(); //Set up network configuration: MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).learningRate(0.1) .rmsDecay(0.95).seed(12345).regularization(true).l2(0.001).weightInit(WeightInit.XAVIER) .updater(Updater.RMSPROP).list() .layer(0,//from w ww . j ava 2s . 
c o m new GravesLSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize) .activation(Activation.TANH).build()) .layer(1, new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize).activation(Activation.TANH) .build()) .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX) //MCXENT + softmax for classification .nIn(lstmLayerSize).nOut(nOut).build()) .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength) .tBPTTBackwardLength(tbpttLength).pretrain(false).backprop(true).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); net.setListeners(new ScoreIterationListener(System.out)); //Print the number of parameters in the network (and for each layer) Layer[] layers = net.getLayers(); int totalNumParams = 0; for (int i = 0; i < layers.length; i++) { int nParams = layers[i].numParams(); System.out.println("Number of parameters in layer " + i + ": " + nParams); totalNumParams += nParams; } System.out.println("Total number of network parameters: " + totalNumParams); //Do training, and then generate and print samples from network int miniBatchNumber = 0; for (int i = 0; i < numEpochs; i++) { System.out.println("Epoch number" + i); while (iter.hasNext()) { DataSet ds = iter.next(); net.fit(ds); System.out.println("Batch number " + miniBatchNumber); if (++miniBatchNumber % generateSamplesEveryNMinibatches == 0) { System.out.println("--------------------"); System.out.println("Completed " + miniBatchNumber + " minibatches of size " + miniBatchSize + "x" + exampleLength + " characters"); System.out.println("Sampling characters from network given initialization \"" + (generationInitialization == null ? 
"" : generationInitialization) + "\""); String[] samples = sampleCharactersFromNetwork(generationInitialization, net, iter, rng, nCharactersToSample, nSamplesToGenerate); for (int j = 0; j < samples.length; j++) { System.out.println("----- Sample " + j + " -----"); System.out.println(samples[j]); System.out.println(); } } } iter.reset(); //Reset iterator for another epoch } System.out.println("\n\nExample complete"); }
From source file:org.knime.ext.dl4j.base.nodes.learn.AbstractDLLearnerNodeModel.java
License:Open Source License
/**
 * Attempts to transfer weights from one {@link MultiLayerNetwork} to another. The weights will be
 * transferred layer by layer, and only between intersecting layers of both networks. Assumes that
 * the order and type of the layers to transfer the weights between is the same.
 *
 * @param from the network to get the weights from
 * @param to the network to transfer the weights to
 */
protected void transferWeights(final MultiLayerNetwork from, final MultiLayerNetwork to) {
    // Cache both layer arrays once; the original re-invoked the synchronized
    // from.getLayers() accessor on every loop iteration.
    final org.deeplearning4j.nn.api.Layer[] fromLayers = from.getLayers();
    final org.deeplearning4j.nn.api.Layer[] toLayers = to.getLayers();

    // Snapshot the source parameters up front.
    final List<INDArray> oldWeights = new ArrayList<>();
    for (final org.deeplearning4j.nn.api.Layer layer : fromLayers) {
        oldWeights.add(layer.params());
    }

    // Only the leading, intersecting layers of both networks can receive weights.
    final int layersToTransfer = Math.min(oldWeights.size(), toLayers.length);
    for (int i = 0; i < layersToTransfer; i++) {
        final org.deeplearning4j.nn.api.Layer layer = toLayers[i];
        // Shared description for the success and failure messages (built once,
        // where the original duplicated the concatenation in both branches).
        final String transferDescription = "weights from layer: " + (i + 1) + " ("
                + fromLayers[i].getClass().getName() + ") of old network to " + "layer: " + (i + 1)
                + " (" + layer.getClass().getName() + ") of new network";
        try {
            layer.setParams(oldWeights.get(i));
            logger.info("Successfully transferred " + transferDescription);
        } catch (final Exception e) {
            // setParams fails when the parameter shapes of the two layers differ;
            // log the failure and continue with the next layer.
            logger.warn("Could not transfer " + transferDescription, e);
            logger.warn("Reason: " + e.getMessage(), e);
        }
    }
}