Example usage for org.deeplearning4j.nn.api.Model#gradient()

List of usage examples for org.deeplearning4j.nn.api.Model#gradient()

Introduction

This page shows example usage of org.deeplearning4j.nn.api.Model#gradient().

Prototype

Gradient gradient();

Source Link

Document

Returns the gradient.

Usage

From source file: com.javafxpert.neuralnetviz.model.ModelListener.java

License: Apache License

/**
 * Training callback invoked after each iteration.
 *
 * <p>On every call whose {@code iterCount} is a multiple of {@code iterations}, this calls
 * {@code invoke()}. When the reported {@code iteration} is also a multiple of {@code iterations},
 * it additionally:
 * <ol>
 *   <li>appends the mean absolute value of every gradient array to {@code meanMagHistoryUpdates},</li>
 *   <li>rebuilds {@code neuralNetGraph} from the {@code <layer>_W} / {@code <layer>_b} parameters and
 *       appends the mean absolute value of every parameter array to {@code meanMagHistoryParams},</li>
 *   <li>serializes the graph to JSON and sends it over {@code webSocketSession}.</li>
 * </ol>
 * {@code iterCount} is incremented on every call.
 *
 * @param model     the model being trained; it is cast to {@code MultiLayerNetworkEnhanced}, so any
 *                  other implementation causes a {@code ClassCastException}
 * @param iteration the iteration number reported by the training loop
 */
@Override
public void iterationDone(Model model, int iteration) {
    this.multiLayerNetworkEnhanced = (MultiLayerNetworkEnhanced) model;
    if (iterations <= 0) {
        iterations = 1; // guard against division by zero in the modulo checks below
    }
    if (iterCount % iterations == 0) {
        invoke();
        if (iteration % iterations == 0) {
            System.out.println("In iterationDone(), iteration: " + iteration + ", iterations: " + iterations
                    + ", iterCount: " + iterCount + ", score: " + model.score());
            recordGradientMagnitudes(model);
            rebuildGraphAndRecordParamMagnitudes(model);
            sendGraphToClient();
        }
    }
    iterCount++;
}

/**
 * Appends the mean magnitude of each gradient array to {@code meanMagHistoryUpdates}.
 * Failures are logged rather than thrown so a reporting problem cannot abort training.
 */
private void recordGradientMagnitudes(Model model) {
    try {
        Map<String, INDArray> grad = model.gradient().gradientForVariable();

        if (meanMagHistoryParams.isEmpty()) {
            // First report: create one history map per layer index seen in the gradient keys.
            int maxLayerIdx = -1;
            for (String s : grad.keySet()) {
                maxLayerIdx = Math.max(maxLayerIdx, indexFromString(s));
            }
            if (maxLayerIdx == -1) {
                maxLayerIdx = 0;
            }
            for (int i = 0; i <= maxLayerIdx; i++) {
                meanMagHistoryParams.add(new LinkedHashMap<String, List<Double>>());
                meanMagHistoryUpdates.add(new LinkedHashMap<String, List<Double>>());
            }
        }

        for (Map.Entry<String, INDArray> entry : grad.entrySet()) {
            String param = entry.getKey();
            // CSS identifiers can't start with a digit:
            // http://www.w3.org/TR/CSS21/syndata.html#value-def-identifier
            String newName = Character.isDigit(param.charAt(0)) ? "param_" + param : param;

            int idx = indexFromString(param);
            // Grow until idx is valid; a single add() is not enough when idx skips ahead.
            while (idx >= meanMagHistoryUpdates.size()) {
                meanMagHistoryUpdates.add(new LinkedHashMap<String, List<Double>>());
            }
            appendMeanMagnitude(meanMagHistoryUpdates.get(idx), newName, entry.getValue());
        }
    } catch (Exception e) {
        log.warn("Exception: " + e, e);
    }
}

/**
 * Replaces {@code neuralNetGraph} with a fresh graph built from the model's parameters.
 * Also appends each parameter array's mean magnitude to {@code meanMagHistoryParams}.
 */
private void rebuildGraphAndRecordParamMagnitudes(Model model) {
    // Fresh graph, populated by repeated populateNeuralNetModel() calls below.
    neuralNetGraph = new NeuralNetGraph();
    curNodeId = 0; // NeuralNetNode instances have zero based node IDs

    for (Map.Entry<String, INDArray> entry : model.paramTable().entrySet()) {
        String param = entry.getKey();
        char firstChar = param.charAt(0);
        String newName;
        if (Character.isDigit(firstChar)) {
            newName = "param_" + param;

            // param should take the form of 0_W or 0_b where first digit is layer number - 1.
            // Assumption: "W" entry appears before "b" entry for a given layer.
            // NOTE(review): only single-digit layer prefixes ("0_W" .. "9_b") match here, so
            // layers 10+ are not added to the graph.
            if (param.length() == 3 && param.charAt(1) == '_'
                    && (param.charAt(2) == 'W' || param.charAt(2) == 'b')) {
                int layerNum = Character.getNumericValue(firstChar) + 1;
                boolean containsWeights = param.charAt(2) == 'W';
                // dup() because params might be a view
                populateNeuralNetModel(neuralNetGraph, layerNum, containsWeights, entry.getValue().dup());
            }
        } else {
            newName = param;
        }

        int idx = indexFromString(param);
        // Grow until idx is valid; a single add() is not enough when idx skips ahead.
        while (idx >= meanMagHistoryParams.size()) {
            meanMagHistoryParams.add(new LinkedHashMap<String, List<Double>>());
        }
        appendMeanMagnitude(meanMagHistoryParams.get(idx), newName, entry.getValue());
    }
}

/**
 * Serializes {@code neuralNetGraph} to JSON and sends it over {@code webSocketSession}.
 * Nothing is sent if serialization fails.
 */
private void sendGraphToClient() {
    String jsonString;
    try {
        jsonString = new ObjectMapper().writeValueAsString(neuralNetGraph);
    } catch (JsonProcessingException jpe) {
        log.warn("Exception serializing neuralNetGraph: " + jpe, jpe);
        return; // don't push an empty payload to the client
    }

    try {
        webSocketSession.sendMessage(new TextMessage(jsonString));
    } catch (IOException ioe) {
        log.warn("Exception sending neuralNetGraph: " + ioe, ioe);
    }
}

/**
 * Appends {@code norm1(values) / length(values)} (the mean absolute value) to the list stored
 * under {@code name} in {@code history}, creating that list on first use.
 */
private static void appendMeanMagnitude(Map<String, List<Double>> history, String name, INDArray values) {
    List<Double> list = history.get(name);
    if (list == null) {
        list = new ArrayList<>();
        history.put(name, list);
    }
    list.add(values.norm1Number().doubleValue() / values.length());
}