Example usage for org.deeplearning4j.nn.multilayer.MultiLayerNetwork#getLayer

List of usage examples for org.deeplearning4j.nn.multilayer.MultiLayerNetwork#getLayer

Introduction

On this page you can find example usages of org.deeplearning4j.nn.multilayer.MultiLayerNetwork#getLayer.

Prototype

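public Layer getLayer(String name)

Note: both examples below call the index-based overload, getLayer(int index), rather than the name-based one shown above.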

Source Link

Usage

From source file: org.audiveris.omr.classifier.DeepClassifier.java

License: Open Source License

/**
 * Checks whether the given model and normalization data match what this classifier expects.
 * <p>
 * The norms must have exactly one column, the model's first layer must be a convolution layer
 * whose configured {@code nIn} is 1, and the model's output layer must be an {@code OutputLayer}
 * whose configured {@code nOut} equals {@code SHAPE_COUNT}. Every mismatch is logged as a warning.
 * Layer types are tested with {@code instanceof}, so an unexpected architecture is reported as
 * incompatible instead of failing with a {@code ClassCastException}.
 *
 * @param model the network to check
 * @param norms the normalization data to check
 * @return {@code true} if both model and norms are compatible, {@code false} otherwise
 */
@Override
protected boolean isCompatible(MultiLayerNetwork model, Norms norms) {
    // Check input numbers for norms
    final int normsIn = norms.means.columns();

    if (normsIn != 1) {
        logger.warn("Incompatible norms count:{} expected:{}", normsIn, 1);

        return false;
    }

    // Check input numbers for model: first layer must be a convolution layer with nIn == 1
    final Object firstLayer = model.getLayer(0);

    if (!(firstLayer instanceof org.deeplearning4j.nn.layers.convolution.ConvolutionLayer)) {
        logger.warn("Incompatible input layer type:{} expected:{}", typeNameOf(firstLayer),
                "ConvolutionLayer");

        return false;
    }

    final Object confInput = ((org.deeplearning4j.nn.layers.convolution.ConvolutionLayer) firstLayer)
            .conf().getLayer();

    if (!(confInput instanceof org.deeplearning4j.nn.conf.layers.ConvolutionLayer)) {
        logger.warn("Incompatible input layer configuration:{} expected:{}", typeNameOf(confInput),
                "ConvolutionLayer");

        return false;
    }

    final int modelIn = ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer) confInput).getNIn();

    if (modelIn != 1) {
        logger.warn("Incompatible features count:{} expected:{}", modelIn, 1);

        return false;
    }

    // Check output numbers for model: output layer must be an OutputLayer with nOut == SHAPE_COUNT
    final Object lastLayer = model.getOutputLayer();

    if (!(lastLayer instanceof org.deeplearning4j.nn.layers.OutputLayer)) {
        logger.warn("Incompatible output layer type:{} expected:{}", typeNameOf(lastLayer),
                "OutputLayer");

        return false;
    }

    final Object confOutput = ((org.deeplearning4j.nn.layers.OutputLayer) lastLayer)
            .conf().getLayer();

    if (!(confOutput instanceof org.deeplearning4j.nn.conf.layers.OutputLayer)) {
        logger.warn("Incompatible output layer configuration:{} expected:{}",
                typeNameOf(confOutput), "OutputLayer");

        return false;
    }

    final int modelOut = ((org.deeplearning4j.nn.conf.layers.OutputLayer) confOutput).getNOut();

    if (modelOut != SHAPE_COUNT) {
        logger.warn("Incompatible shape count model:{} expected:{}", modelOut, SHAPE_COUNT);

        return false;
    }

    return true;
}

/**
 * Returns a short, null-safe type name for use in compatibility warnings.
 *
 * @param obj any object, possibly null
 * @return the simple class name of {@code obj}, or {@code "null"}
 */
private static String typeNameOf(Object obj) {
    return (obj == null) ? "null" : obj.getClass().getSimpleName();
}

From source file: org.ensor.fftmusings.autoencoder.StackTrainer.java

/**
 * Assembles a stacked autoencoder from separately pretrained two-layer networks and fine-tunes it.
 * <p>
 * Pretrained network {@code i} supplies the parameters of its layer 0 to layer {@code i} of the
 * stack, and the parameters of its layer 1 to the mirrored layer {@code nLayers - i - 1}. The stack
 * is then trained for 300 passes over an {@code FFTDataIterator}. After every pass it is evaluated
 * and written to {@code stack.rnn}.
 *
 * @param args unused
 * @throws Exception if loading, training, evaluating or saving a model fails
 */
public static void main(String[] args) throws IOException, Exception {

    // Pretrained networks, ordered from the outermost (1024-1200) to the innermost (200-100) stage.
    final String[] pretrainedPaths = {
        "data/daa/model-1024-1200sparse0.01.nn",
        "data/daa/model-1200-800sparse0.01.nn",
        "data/daa/model-800-400sparse0.01.nn",
        "data/daa/model-400-200sparse0.01.nn",
        "data/daa/model-200-100sparse0.01.nn"
    };

    // Sized from the path list so there are no unused (null) slots.
    final MultiLayerNetwork[] pretrainedLayers = new MultiLayerNetwork[pretrainedPaths.length];
    for (int i = 0; i < pretrainedPaths.length; i++) {
        pretrainedLayers[i] = ModelSerializer.restoreMultiLayerNetwork(pretrainedPaths[i]);
    }

    NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
            .seed(System.currentTimeMillis()).iterations(1)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).weightInit(WeightInit.XAVIER)
            .updater(Updater.NESTEROVS).regularization(false).l1(0.000).learningRate(0.0001);
    int layerNo = 0;

    // Encoder 1024 -> 100 followed by the mirrored decoder 100 -> 1024; the final decoder stage is
    // an identity-activated output layer trained with L2 loss.
    NeuralNetConfiguration.ListBuilder listBuilder = builder.list()
            .layer(layerNo++, new RBM.Builder().nIn(1024).nOut(1200).activation(Activation.SIGMOID).build())
            .layer(layerNo++, new RBM.Builder().nIn(1200).nOut(800).activation(Activation.SIGMOID).build())
            .layer(layerNo++, new RBM.Builder().nIn(800).nOut(400).activation(Activation.SIGMOID).build())
            .layer(layerNo++, new RBM.Builder().nIn(400).nOut(200).activation(Activation.SIGMOID).build())
            .layer(layerNo++, new RBM.Builder().nIn(200).nOut(100).activation(Activation.SIGMOID).build())
            .layer(layerNo++, new RBM.Builder().nIn(100).nOut(200).activation(Activation.SIGMOID).build())
            .layer(layerNo++, new RBM.Builder().nIn(200).nOut(400).activation(Activation.SIGMOID).build())
            .layer(layerNo++, new RBM.Builder().nIn(400).nOut(800).activation(Activation.SIGMOID).build())
            .layer(layerNo++, new RBM.Builder().nIn(800).nOut(1200).activation(Activation.SIGMOID).build())
            .layer(layerNo++,
                    new OutputLayer.Builder().nIn(1200).nOut(1024).activation(Activation.IDENTITY)
                            .lossFunction(LossFunctions.LossFunction.L2).build())
            .pretrain(false).backprop(true);

    MultiLayerConfiguration conf = listBuilder.build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(System.out)));

    // The mirrored-index copy below only lines up when every pretrained network maps onto exactly
    // one encoder layer and one decoder layer of the stack.
    final int nLayers = model.getnLayers();
    if (nLayers != 2 * pretrainedLayers.length) {
        throw new IllegalStateException("Stack has " + nLayers + " layers but "
                + pretrainedLayers.length + " pretrained networks require "
                + (2 * pretrainedLayers.length));
    }

    for (int i = 0; i < pretrainedLayers.length; i++) {
        final MultiLayerNetwork pretrained = pretrainedLayers[i];
        final int mirror = nLayers - i - 1;

        model.getLayer(i).setParam(PretrainParamInitializer.BIAS_KEY,
                pretrained.getLayer(0).getParam(PretrainParamInitializer.BIAS_KEY));
        model.getLayer(i).setParam(PretrainParamInitializer.WEIGHT_KEY,
                pretrained.getLayer(0).getParam(PretrainParamInitializer.WEIGHT_KEY));

        model.getLayer(mirror).setParam(PretrainParamInitializer.BIAS_KEY,
                pretrained.getLayer(1).getParam(PretrainParamInitializer.BIAS_KEY));
        model.getLayer(mirror).setParam(PretrainParamInitializer.WEIGHT_KEY,
                pretrained.getLayer(1).getParam(PretrainParamInitializer.WEIGHT_KEY));
    }

    DataSetIterator iter = new FFTDataIterator(new Random(), 100, 1250, System.out);

    for (int epoch = 0; epoch < 300; epoch++) {
        model.fit(iter);
        iter.reset();
        evaluateModel(model, epoch);
        // Rewritten after every pass, so the most recent model is kept if the run is interrupted.
        ModelSerializer.writeModel(model, "stack.rnn", true);
    }
}