Example usage for org.deeplearning4j.eval Evaluation stats

List of usage examples for org.deeplearning4j.eval Evaluation stats

Introduction

On this page you can find example usage of the org.deeplearning4j.eval Evaluation stats method.

Prototype

public String stats(boolean suppressWarnings) 

Source Link

Document

Returns the classification report as a String.

Usage

From source file: org.audiveris.omr.classifier.DeepClassifier.java

License: Open Source License

/**
 * Retrains the neural classifier on the provided samples.
 * <p>
 * The samples are shuffled (on a copy), converted to a raw DataSet, and normalized using a
 * single mean / standard deviation computed over every feature value of every sample.
 * The model is then fitted for {@code getMaxEpochs()} epochs. After each epoch it is
 * evaluated on the training set itself and the classification report is logged.
 * Finally the model is stored to {@code FILE_NAME}.
 *
 * @param samples the training samples; if empty, a warning is logged and nothing is trained
 */
@SuppressWarnings("unchecked")
@Override
public void train(Collection<Sample> samples) {
    if (samples.isEmpty()) {
        logger.warn("No sample to retrain neural classifier");

        return;
    }

    // Shuffle a copy so that the caller's collection is left untouched
    final List<Sample> newSamples = new ArrayList<>(samples);
    Collections.shuffle(newSamples);

    // Build raw dataset
    final DataSet dataSet = getRawDataSet(newSamples);

    // Record mean and standard deviation for *ALL* pixels (one global population)
    final INDArray features = dataSet.getFeatures();
    final int rows = features.rows();
    final int cols = features.columns();
    logger.info("features rows:{} cols:{}", rows, cols);

    final Population pop = new Population();

    for (int r = 0; r < rows; r++) {
        final INDArray row = features.getRow(r);

        for (int c = 0; c < cols; c++) {
            pop.includeValue(row.getDouble(c));
        }
    }

    logger.info("pop: {}", pop);

    // NOTE(review): presumably normalize() divides by std; adding EPS_THRESHOLD keeps it
    // strictly positive in case all pixel values are identical.
    final INDArray mean = Nd4j.create(new double[]{pop.getMeanValue()});
    final INDArray std = Nd4j.create(
            new double[]{pop.getStandardDeviation() + Nd4j.EPS_THRESHOLD});
    norms = new Norms(mean, std);

    logger.info("norms.means: {}", norms.means);
    logger.info("norms.stds: {}", norms.stds);

    // Normalize with the global norms computed above (rather than per-column statistics).
    // The return value is not used, so normalize() is assumed to work in place — TODO confirm.
    normalize(features);

    logger.info("Training network...");

    // Class names do not change between epochs: build the list once
    final List<String> names = Arrays.asList(ShapeSet.getPhysicalShapeNames());
    final int epochs = getMaxEpochs();

    for (int epoch = 1; epoch <= epochs; epoch++) {
        epochStarted(epoch);

        model.fit(dataSet);

        // Evaluate (on the training set itself, there is no separate validation set here)
        logger.info("Epoch:{} evaluating on training set...", epoch);

        final org.deeplearning4j.eval.Evaluation eval = new org.deeplearning4j.eval.Evaluation(
                names);
        final INDArray guesses = model.output(dataSet.getFeatures());
        eval.eval(dataSet.getLabels(), guesses);

        // Pass the report as an argument, not as the format string: it may contain "{}"
        logger.info("{}", eval.stats(true));
    }

    // Store
    store(FILE_NAME);
}