List of usage examples for the constructor of org.deeplearning4j.ui.stats.StatsListener
public StatsListener(StatsStorageRouter router, int listenerFrequency)
From source file: aiLogicImplementation.RNNBasic.java
License: Apache License
public void create(StatsStorage statsStorage) { NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder(); builder.iterations(1000);/*from ww w .jav a2 s. c o m*/ builder.learningRate(0.01); builder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT); builder.seed(123); builder.biasInit(0.2); builder.miniBatch(true); // builder.updater(Updater.NESTEROVS); builder.updater(new Nesterovs(0.01)); builder.weightInit(WeightInit.XAVIER); builder.regularization(true); builder.l2(0.001); // builder.gradientNormalization(GradientNormalization.ClipL2PerParamType); // builder.gradientNormalizationThreshold(0.5); ListBuilder listBuilder = builder.list(); for (int i = 0; i < HIDDEN_LAYER_CONT; i++) { GravesLSTM.Builder hiddenLayerBuilder = new GravesLSTM.Builder(); hiddenLayerBuilder.nIn(i == 0 ? NUMBER_OF_FEATURE_INPUT : HIDDEN_LAYER_WIDTH); hiddenLayerBuilder.nOut(HIDDEN_LAYER_WIDTH); if (i == 0) { hiddenLayerBuilder.activation(Activation.TANH); } else { hiddenLayerBuilder.activation(Activation.TANH); } listBuilder.layer(i, hiddenLayerBuilder.build()); } // RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.XENT); RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.L2); // RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.MCXENT); // outputLayerBuilder.activation(Activation.SIGMOID); outputLayerBuilder.activation(Activation.RELU); // outputLayerBuilder.activation(Activation.SOFTMAX); outputLayerBuilder.nIn(HIDDEN_LAYER_WIDTH); outputLayerBuilder.nOut(NUMBER_OF_FEATURE_OUTPUT); outputLayerBuilder.weightInit(WeightInit.XAVIER); // outputLayerBuilder.dist(new UniformDistribution(0, 1)); listBuilder.layer(HIDDEN_LAYER_CONT, outputLayerBuilder.build()); // listBuilder.backpropType(BackpropType.TruncatedBPTT); // listBuilder.tBPTTForwardLength(tbpttLength); // listBuilder.tBPTTBackwardLength(tbpttLength); listBuilder.pretrain(false); 
listBuilder.backprop(true); // create network MultiLayerConfiguration conf = listBuilder.build(); net = new MultiLayerNetwork(conf); net.init(); // net.setListeners(new ScoreIterationListener(configuration.getListenerFrequency())); //Then add the StatsListener to collect this information from the network, as it trains List list = new ArrayList(); list.add(new StatsListener(statsStorage, configuration.getListenerFrequency())); list.add(new ScoreIterationListener(configuration.getListenerFrequency())); net.setListeners(list); /* * CREATE OUR TRAINING DATA */ input = Nd4j.zeros(maze.length * maze.length * hibertMaze.getHibertMazeGUI().getAlPlayers().size() * hibertMaze.getHibertMazeGUI().getAlFoods().size(), NUMBER_OF_FEATURE_INPUT); labels = Nd4j.zeros(maze.length * maze.length * hibertMaze.getHibertMazeGUI().getAlPlayers().size() * hibertMaze.getHibertMazeGUI().getAlFoods().size(), NUMBER_OF_FEATURE_OUTPUT); positionMoving = new PositionMoving(maze); int counter = 0; for (int i = 0; i < maze.length; i++) { for (int j = 0; j < maze[i].length; j++) { int mazeFieldValue = maze[i][j]; // input neuron for current-char is 1 at "samplePos" for (Player player : hibertMaze.getHibertMazeGUI().getAlPlayers()) { for (Food food : hibertMaze.getHibertMazeGUI().getAlFoods()) { Position position = new Position(i, j); input.putScalar(counter, 0, i); input.putScalar(counter, 1, j); input.putScalar(counter, 2, mazeFieldValue); input.putScalar(counter, 3, player.getPlayerNumber()); input.putScalar(counter, 4, i); input.putScalar(counter, 5, j); input.putScalar(counter, 6, food.getFoodNumber()); input.putScalar(counter, 7, food.getPosition().getX()); input.putScalar(counter, 8, food.getPosition().getY()); input.putScalar(counter, 9, positionMoving.moveNorth(position) ? 1.00d : 0.00d); input.putScalar(counter, 10, positionMoving.moveSouth(position) ? 1.00d : 0.00d); input.putScalar(counter, 11, positionMoving.moveEast(position) ? 
1.00d : 0.00d); input.putScalar(counter, 12, positionMoving.moveWest(position) ? 1.00d : 0.00d); input.putScalar(counter, 13, player.getPosition().getX() == food.getPosition().getX() && player.getPosition().getY() == food.getPosition().getY() ? 1.00d : 0.00d); labels.putScalar(counter, 0, positionMoving.moveNorth(position) ? 1.00d : 0.00d); labels.putScalar(counter, 1, positionMoving.moveSouth(position) ? 1.00d : 0.00d); labels.putScalar(counter, 2, positionMoving.moveEast(position) ? 1.00d : 0.00d); labels.putScalar(counter, 3, positionMoving.moveWest(position) ? 1.00d : 0.00d); counter = counter + 1; } } } } trainingData = new DataSet(input, labels); }