Example usage for org.deeplearning4j.nn.api Model score

List of usage examples for org.deeplearning4j.nn.api Model score

Introduction

In this page you can find the example usage for org.deeplearning4j.nn.api Model score.

Prototype

double score();

Source Link

Document

The score for the model

Usage

From source file:com.javafxpert.neuralnetviz.model.ModelListener.java

License:Apache License

/**
 * DL4J iteration listener callback, fired after each training iteration.
 * Every {@code iterations} iterations it: logs the current score, records
 * mean gradient and parameter magnitudes per layer, rebuilds a
 * {@code NeuralNetGraph} visualization model from the parameter table, and
 * pushes its JSON serialization to the attached WebSocket session.
 *
 * @param model     the model being trained
 *                  (NOTE(review): unconditionally cast to MultiLayerNetworkEnhanced — confirm all callers pass that type)
 * @param iteration the current iteration number as reported by DL4J
 */
@Override
public void iterationDone(Model model, int iteration) {
    this.multiLayerNetworkEnhanced = (MultiLayerNetworkEnhanced) model;
    // Normalize a non-positive reporting frequency to "report every iteration".
    if (iterations <= 0)
        iterations = 1;
    if (iterCount % iterations == 0) {
        invoke();
        if (iteration % iterations == 0) {
            System.out.println("In iterationDone(), iteration: " + iteration + ", iterations: " + iterations
                    + ", iterCount: " + iterCount + ", score: " + model.score());
            // NOTE(review): newGrad is populated nowhere below (histogram code is commented out) — dead local, candidate for removal.
            Map<String, Map> newGrad = new LinkedHashMap<>();
            try {
                // Per-variable gradients for the last iteration, keyed by parameter name (e.g. "0_W", "0_b").
                Map<String, INDArray> grad = model.gradient().gradientForVariable();

                //log.warn("Starting report building...");

                if (meanMagHistoryParams.size() == 0) {
                    // First report: size the per-layer history lists to the highest layer
                    // index found among the gradient keys.
                    int maxLayerIdx = -1;
                    for (String s : grad.keySet()) {
                        maxLayerIdx = Math.max(maxLayerIdx, indexFromString(s));
                    }
                    if (maxLayerIdx == -1)
                        maxLayerIdx = 0;
                    for (int i = 0; i <= maxLayerIdx; i++) {
                        meanMagHistoryParams.add(new LinkedHashMap<String, List<Double>>());
                        meanMagHistoryUpdates.add(new LinkedHashMap<String, List<Double>>());
                    }
                }

                for (Map.Entry<String, INDArray> entry : grad.entrySet()) {
                    String param = entry.getKey();
                    String newName;
                    if (Character.isDigit(param.charAt(0)))
                        newName = "param_" + param;
                    else
                        newName = param;

                    //log.warn("params newName: " + newName + " \n" + entry.getValue().dup());

                    /*
                    HistogramBin histogram = new HistogramBin.Builder(entry.getValue().dup())
                    .setBinCount(20)
                    .setRounding(6)
                    .build();
                    */

                    //newGrad.put(newName, histogram.getData());
                    //CSS identifier can't start with digit http://www.w3.org/TR/CSS21/syndata.html#value-def-identifier

                    // Grow the updates-history list lazily if this layer index was not pre-sized above.
                    int idx = indexFromString(param);
                    if (idx >= meanMagHistoryUpdates.size()) {
                        //log.info("Can't find idx for update ["+newName+"]");
                        meanMagHistoryUpdates.add(new LinkedHashMap<String, List<Double>>());
                    }

                    //Work out layer index:
                    Map<String, List<Double>> map = meanMagHistoryUpdates.get(idx);
                    List<Double> list = map.get(newName);
                    if (list == null) {
                        list = new ArrayList<>();
                        map.put(newName, list);
                    }
                    // Mean absolute magnitude of this gradient tensor (L1 norm / element count).
                    double meanMag = entry.getValue().norm1Number().doubleValue() / entry.getValue().length();
                    list.add(meanMag);

                }
            } catch (Exception e) {
                // Best-effort reporting: a failure here must not abort training.
                log.warn("Exception: " + e);
            }

            // Create instance of NeuralNetGraph to populate in multiple invocations of populateNeuralNetModel()
            neuralNetGraph = new NeuralNetGraph();
            curNodeId = 0; // NeuralNetNode instances have zero based node IDs

            //Process parameters: duplicate + calculate and store mean magnitudes
            Map<String, INDArray> params = model.paramTable();
            Map<String, Map> newParams = new LinkedHashMap<>();
            for (Map.Entry<String, INDArray> entry : params.entrySet()) {
                String param = entry.getKey();
                // dup() so later mutation of the (possibly view-backed) parameter array cannot affect us.
                INDArray value = entry.getValue().dup();

                String newName; // TODO perhaps remove

                char firstChar = param.charAt(0);
                if (Character.isDigit(firstChar)) {
                    newName = "param_" + param;
                    //System.out.println("updates newName: " + newName + " \n" + entry.getValue().dup());

                    int layerNum = Character.getNumericValue(firstChar) + 1;
                    boolean containsWeights = false;

                    // param should take the form of 0_W or 0_b where first digit is layer number - 1
                    // Assumption: "W" entry appears before "b" entry for a given layer
                    if (param.length() == 3 && param.charAt(1) == '_'
                            && (param.charAt(2) == 'W' || param.charAt(2) == 'b')) {
                        containsWeights = param.charAt(2) == 'W';

                        // Populate NeuralNet* model classes
                        populateNeuralNetModel(neuralNetGraph, layerNum, containsWeights, value);
                    }
                } else {
                    newName = param;
                }

                /*
                HistogramBin histogram = new HistogramBin.Builder(entry.getValue().dup())
                .setBinCount(20)
                .setRounding(6)
                .build();
                newParams.put(newName, histogram.getData());
                //dup() because params might be a view
                */

                // Grow the params-history list lazily, mirroring the gradient loop above.
                int idx = indexFromString(param);
                if (idx >= meanMagHistoryParams.size()) {
                    //log.info("Can't find idx for param ["+newName+"]");
                    meanMagHistoryParams.add(new LinkedHashMap<String, List<Double>>());
                }

                Map<String, List<Double>> map = meanMagHistoryParams.get(idx);
                List<Double> list = map.get(newName);
                if (list == null) {
                    list = new ArrayList<>();
                    map.put(newName, list);
                }
                // Mean absolute magnitude of this parameter tensor (L1 norm / element count).
                double meanMag = entry.getValue().norm1Number().doubleValue() / entry.getValue().length();
                list.add(meanMag);
            }

            // Serialize the freshly built graph model to JSON for the browser-side view.
            ObjectMapper mapper = new ObjectMapper();
            String jsonString = "";
            try {
                jsonString = mapper.writeValueAsString(neuralNetGraph);
            } catch (JsonProcessingException jpe) {
                System.out.println("Exception serializing neuralNetGraph: " + jpe);
            }
            //System.out.println("neuralNetGraph: \n\n" + jsonString);

            // Push the JSON to the connected client; on failure we only log, training continues.
            try {
                webSocketSession.sendMessage(new TextMessage(jsonString));
            } catch (IOException ioe) {
                ioe.printStackTrace();
            }
        }
    }
    iterCount++;
}

From source file:org.audiveris.omr.classifier.ui.TrainingPanel.java

License:Open Source License

/**
 * Advances the iteration counter and, once per listener period,
 * logs the current score and forwards it to the display.
 *
 * @param model     the model being trained
 * @param iteration the current iteration number (unused; the listener keeps its own counter)
 */
@Override
public void iterationDone(Model model, int iteration) {
    iterCount++;

    // Only report on every listenerPeriod-th call.
    if ((iterCount % constants.listenerPeriod.getValue()) != 0) {
        return;
    }

    invoke();
    final double score = model.score();
    final int count = (int) iterCount;
    logger.info(String.format("Score at iteration %d is %.5f", count, score));
    display(epoch, count, score);
}

From source file:org.ensor.fftmusings.atrain.ScoreIterationListener.java

/**
 * Logs the model score, together with the wall-clock time elapsed since the
 * previous report, on every 32nd sample (when iterCount is 1 modulo 32).
 *
 * @param model     the model being trained
 * @param iteration the current iteration number (unused; the listener keeps its own counter)
 */
@Override
public void iterationDone(Model model, int iteration) {
    invoke();
    // NOTE(review): iterCount is not advanced in this method — presumably invoke() updates it; confirm.
    if (iterCount % 32 == 1) {
        final double score = model.score();
        final long now = System.currentTimeMillis();
        logWriter.println("Score at iteration " + iterCount + " is " + score + " took "
                + (now - lastSampleTime) + " ms");
        lastSampleTime = now;
    }
}

From source file:org.ensor.fftmusings.autoencoder.ScoreIterationListener.java

/**
 * Logs the current score and the wall-clock time elapsed since the previous
 * call, on every iteration.
 *
 * @param model     the model being trained
 * @param iteration the current iteration number (unused; the listener keeps its own counter)
 */
@Override
public void iterationDone(Model model, int iteration) {
    invoke();
    final double score = model.score();
    final long now = System.currentTimeMillis();
    final long elapsedMs = now - lastSampleTime;
    logWriter.println("Score at iteration " + iterCount + " is " + score + " took " + elapsedMs + " ms");
    lastSampleTime = now;
}

From source file:org.ensor.fftmusings.rnn.qft.ScoreIterationListener.java

/**
 * Prints the current model score to stdout, tagged with this listener's own
 * iteration counter, then advances the counter.
 *
 * @param model     the model being trained
 * @param iteration the current iteration number (unused; the listener keeps its own counter)
 */
@Override
public void iterationDone(Model model, int iteration) {
    invoke();
    final String message = "Score at iteration " + iterCount + " is " + model.score();
    iterCount++;
    System.out.println(message);
}

From source file:org.knime.ext.dl4j.base.nodes.learn.view.UpdateLearnerViewIterationListener.java

License:Open Source License

/**
 * Forwards the model's current score to the learner view and stores it in the
 * node model, on every iteration.
 *
 * @param model     the model being trained
 * @param iteration the current iteration number (unused)
 */
@Override
public void iterationDone(final Model model, final int iteration) {
    invoke();
    final double result = model.score();
    // Pass the current score to the view.
    // Double.valueOf replaces the Double(double) constructor, which is deprecated
    // since Java 9 in favor of the valueOf factory.
    m_nodeModel.passObjToView(Double.valueOf(result));
    // Update the score in the node model.
    m_nodeModel.setScore(result);
}

From source file:regression.reinforce.MemoIterationListener.java

License:Apache License

/**
 * Every printIterations iterations, logs the model score and appends it to the
 * in-memory buffer.
 *
 * @param model     the model being trained
 * @param iteration the current iteration number
 */
@Override
public void iterationDone(Model model, int iteration) {
    // A non-positive period would never trigger; normalize it to "every iteration".
    if (printIterations <= 0) {
        printIterations = 1;
    }
    if (iteration % printIterations != 0) {
        return;
    }
    invoke();
    final double score = model.score();
    log.info("Score at iteration " + iteration + " is " + score);
    buffer.add(score);
    // TODO: make absolutely sure something closes the writer.
}

From source file:weka.dl4j.FileIterationListener.java

License:Open Source License

/**
 * Method that gets called when an iteration is done.
 *
 * @param model the model to operate with
 * @param epoch the epoch number/*from   w w  w .  j av a 2s  .com*/
 */
@Override
public void iterationDone(Model model, int epoch) {

    lossesPerEpoch.add(model.score());
    if (lossesPerEpoch.size() == m_numMiniBatches) {
        // calculate mean
        double mean = 0;
        for (double val : lossesPerEpoch) {
            mean += val;
        }
        mean = mean / lossesPerEpoch.size();
        m_pw.write(mean + "\n");
        m_pw.flush();
        lossesPerEpoch.clear();
    }
}