Example usage of org.deeplearning4j.nn.api.Model#paramTable

List of usage examples for org.deeplearning4j.nn.api.Model#paramTable

Introduction

On this page you can find example usage of org.deeplearning4j.nn.api.Model#paramTable.

Prototype

Map<String, INDArray> paramTable();

Source Link

Document

Returns the parameter table: a map from each parameter's name to its INDArray value.

Usage

From source file: com.javafxpert.neuralnetviz.model.ModelListener.java

License: Apache License

/**
 * Training-listener callback invoked after each iteration.
 *
 * <p>Every {@code iterations} calls (counted by {@code iterCount}), and only when the reported
 * {@code iteration} is also a multiple of {@code iterations}, this method:
 * <ol>
 *   <li>records the mean magnitude (L1 norm / element count) of every gradient array in
 *       {@code meanMagHistoryUpdates}, grouped by {@code indexFromString(key)};</li>
 *   <li>rebuilds {@code neuralNetGraph} from the {@code <layer>_W} / {@code <layer>_b} entries of
 *       {@link Model#paramTable()} and records each parameter's mean magnitude in
 *       {@code meanMagHistoryParams};</li>
 *   <li>serializes the graph to JSON and sends it on {@code webSocketSession}.</li>
 * </ol>
 *
 * @param model     the model being trained; it is cast to {@code MultiLayerNetworkEnhanced}, so any
 *                  other model type causes a {@link ClassCastException}
 * @param iteration the iteration number reported by the trainer
 */
@Override
public void iterationDone(Model model, int iteration) {
    this.multiLayerNetworkEnhanced = (MultiLayerNetworkEnhanced) model;
    if (iterations <= 0) {
        // Guard the modulo operations below against division by zero.
        iterations = 1;
    }
    if (iterCount % iterations == 0) {
        invoke();
        if (iteration % iterations == 0) {
            System.out.println("In iterationDone(), iteration: " + iteration + ", iterations: " + iterations
                    + ", iterCount: " + iterCount + ", score: " + model.score());

            recordGradientMeanMagnitudes(model);

            // Create instance of NeuralNetGraph to populate in multiple invocations of populateNeuralNetModel()
            neuralNetGraph = new NeuralNetGraph();
            curNodeId = 0; // NeuralNetNode instances have zero based node IDs
            recordParamsAndBuildGraph(model);

            sendNeuralNetGraph();
        }
    }
    iterCount++;
}

/**
 * Appends the mean magnitude of each gradient array to {@code meanMagHistoryUpdates}, initializing
 * the per-layer history lists on first use. Any failure (for example a model that has no gradient
 * yet) is logged and otherwise ignored, so reporting never interrupts training.
 */
private void recordGradientMeanMagnitudes(Model model) {
    try {
        Map<String, INDArray> grad = model.gradient().gradientForVariable();

        if (meanMagHistoryParams.size() == 0) {
            // Initialize one history map per layer, for both params and updates.
            int maxLayerIdx = -1;
            for (String s : grad.keySet()) {
                maxLayerIdx = Math.max(maxLayerIdx, indexFromString(s));
            }
            if (maxLayerIdx == -1) {
                maxLayerIdx = 0;
            }
            for (int i = 0; i <= maxLayerIdx; i++) {
                meanMagHistoryParams.add(new LinkedHashMap<String, List<Double>>());
                meanMagHistoryUpdates.add(new LinkedHashMap<String, List<Double>>());
            }
        }

        for (Map.Entry<String, INDArray> entry : grad.entrySet()) {
            String param = entry.getKey();
            int idx = indexFromString(param);
            // Grow until idx is addressable; a single add() is not enough when indices are skipped.
            while (idx >= meanMagHistoryUpdates.size()) {
                meanMagHistoryUpdates.add(new LinkedHashMap<String, List<Double>>());
            }
            appendMeanMagnitude(meanMagHistoryUpdates.get(idx), cssSafeName(param), entry.getValue());
        }
    } catch (Exception e) {
        log.warn("Failed to record gradient mean magnitudes", e);
    }
}

/**
 * Populates {@code neuralNetGraph} from the weight/bias entries of {@link Model#paramTable()} and
 * appends every parameter's mean magnitude to {@code meanMagHistoryParams}.
 */
private void recordParamsAndBuildGraph(Model model) {
    for (Map.Entry<String, INDArray> entry : model.paramTable().entrySet()) {
        String param = entry.getKey();

        // Keys such as "0_W" / "0_b": zero-based layer index, then W (weights) or b (bias).
        // Assumption (unchanged from the original design): a layer's "W" entry precedes its "b" entry.
        int layerIdx = weightOrBiasLayerIndex(param);
        if (layerIdx >= 0) {
            boolean containsWeights = param.charAt(param.length() - 1) == 'W';
            // dup() because params might be a view
            populateNeuralNetModel(neuralNetGraph, layerIdx + 1, containsWeights, entry.getValue().dup());
        }

        int idx = indexFromString(param);
        // Grow until idx is addressable; a single add() is not enough when indices are skipped.
        while (idx >= meanMagHistoryParams.size()) {
            meanMagHistoryParams.add(new LinkedHashMap<String, List<Double>>());
        }
        appendMeanMagnitude(meanMagHistoryParams.get(idx), cssSafeName(param), entry.getValue());
    }
}

/**
 * Serializes {@code neuralNetGraph} to JSON and sends it on {@code webSocketSession}. Failures are
 * logged; nothing is sent if serialization fails (an empty message would not be valid JSON).
 */
private void sendNeuralNetGraph() {
    String jsonString;
    try {
        jsonString = new ObjectMapper().writeValueAsString(neuralNetGraph);
    } catch (JsonProcessingException jpe) {
        log.warn("Exception serializing neuralNetGraph", jpe);
        return;
    }
    try {
        webSocketSession.sendMessage(new TextMessage(jsonString));
    } catch (IOException ioe) {
        log.warn("Exception sending neuralNetGraph over the web socket", ioe);
    }
}

/**
 * Appends {@code norm1(values) / length(values)} to the list stored under {@code name} in
 * {@code history}, creating that list if it does not exist yet.
 */
private static void appendMeanMagnitude(Map<String, List<Double>> history, String name, INDArray values) {
    List<Double> list = history.get(name);
    if (list == null) {
        list = new ArrayList<>();
        history.put(name, list);
    }
    list.add(values.norm1Number().doubleValue() / values.length());
}

/**
 * Returns {@code param} prefixed with {@code "param_"} when it starts with a digit, because a CSS
 * identifier can't start with a digit
 * (http://www.w3.org/TR/CSS21/syndata.html#value-def-identifier); otherwise returns it unchanged.
 * Empty keys are returned unchanged.
 */
private static String cssSafeName(String param) {
    return !param.isEmpty() && Character.isDigit(param.charAt(0)) ? "param_" + param : param;
}

/**
 * Parses a parameter key of the form {@code <digits>_W} or {@code <digits>_b}.
 *
 * @return the zero-based layer index encoded by the digits, or -1 if the key has any other shape
 */
private static int weightOrBiasLayerIndex(String param) {
    int underscore = param.indexOf('_');
    if (underscore <= 0 || underscore != param.length() - 2) {
        return -1;
    }
    char kind = param.charAt(underscore + 1);
    if (kind != 'W' && kind != 'b') {
        return -1;
    }
    for (int i = 0; i < underscore; i++) {
        if (!Character.isDigit(param.charAt(i))) {
            return -1;
        }
    }
    try {
        return Integer.parseInt(param.substring(0, underscore));
    } catch (NumberFormatException tooLarge) {
        // Only reachable for digit strings that overflow int; treat as "not a layer key".
        return -1;
    }
}