game.plugins.metrics.StandardClassificationMetrics.java Source code

Java tutorial

Introduction

Here is the source code for game.plugins.metrics.StandardClassificationMetrics.java

Source

/*******************************************************************************
 * Copyright (c) 2012 Emanuele Tamponi.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the GNU Public License v3.0
 * which accompanies this distribution, and is available at
 * http://www.gnu.org/licenses/gpl.html
 * 
 * Contributors:
 *     Emanuele Tamponi - initial API and implementation
 ******************************************************************************/
package game.plugins.metrics;

import game.core.Dataset;
import game.core.Dataset.SampleIterator;
import game.core.LabeledMatrix;
import game.core.Metrics;
import game.core.Result;
import game.core.Sample;
import game.core.blocks.Decoder;
import game.core.experiments.ClassificationResult;
import game.plugins.valuetemplates.LabelTemplate;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math3.linear.RealVector;

import com.esotericsoftware.reflectasm.MethodAccess;
import com.ios.IList;

public class StandardClassificationMetrics extends Metrics<ClassificationResult> {

    private IList<String> labels;
    private List<Double> truePositives, falseNegatives;
    private List<Double> falsePositives, trueNegatives;

    @Override
    protected void prepareForEvaluation(ClassificationResult result) {
        truePositives = new ArrayList<>();
        falsePositives = new ArrayList<>();
        falseNegatives = new ArrayList<>();
        trueNegatives = new ArrayList<>();

        Dataset dataset = result.classifiedDataset;
        Decoder decoder = result.trainedClassifier.decoder;

        labels = dataset.getTemplate().targetTemplate.getSingleton(LabelTemplate.class).labels.copy();
        for (int k = 0; k < labels.size(); k++) {
            truePositives.add(0.0);
            falsePositives.add(0.0);
            falseNegatives.add(0.0);
            trueNegatives.add(0.0);
        }

        SampleIterator it = dataset.sampleIterator(null, null, decoder);
        while (it.hasNext()) {
            Sample sample = it.next();
            String trueLabel = (String) sample.getTarget().get();
            String predLabel = (String) sample.getDecodedTarget().get();
            int trueIndex = labels.indexOf(trueLabel);
            int predIndex = labels.indexOf(predLabel);

            if (predLabel.equals(trueLabel)) {
                // It is a "true positive" for the trueIndex and a "true negative" for ALL other indices
                for (int index = 0; index < labels.size(); index++) {
                    if (index == trueIndex)
                        truePositives.set(index, truePositives.get(index) + 1);
                    else
                        trueNegatives.set(index, trueNegatives.get(index) + 1);
                }
            } else {
                // It is a "false positive" for the predIndex and a "false negative" for the true index
                falsePositives.set(predIndex, falsePositives.get(predIndex) + 1);
                falseNegatives.set(trueIndex, falseNegatives.get(trueIndex) + 1);
            }
        }
    }

    private double evaluateWeightedAverage(RealVector row) {
        double ret = 0;
        double sum = 0;

        for (int i = 0; i < row.getDimension() - 1; i++) {
            double w = truePositives.get(i) + falseNegatives.get(i);
            ret += row.getEntry(i) * w;
            sum += w;
        }
        ret = sum == 0 ? 0 : ret / sum;

        return ret;
    }

    private static final String[] rows = { "Precision", "Recall", "Accuracy", "FScore", "Matthews" };

    protected LabeledMatrix evaluateMetrics(ClassificationResult result) {
        IList<String> labels = this.labels.copy();
        LabeledMatrix matrix = new LabeledMatrix(rows.length, labels.size() + 1);

        MethodAccess access = MethodAccess.get(StandardClassificationMetrics.class);
        for (int row = 0; row < rows.length; row++) {
            matrix.getRowLabels().add(rows[row]);
            access.invoke(this, "evaluate" + rows[row], matrix, row, result);
            matrix.setEntry(row, labels.size(), evaluateWeightedAverage(matrix.getRowVector(row)));
        }

        labels.add("w. avg");
        matrix.setColumnLabels(labels.toArray(new String[labels.size()]));

        return matrix;
    }

    protected void evaluatePrecision(LabeledMatrix matrix, int row, ClassificationResult result) {
        for (int i = 0; i < labels.size(); i++) {
            double TP = truePositives.get(i);
            double FP = falsePositives.get(i);
            if (TP + FP != 0) {
                matrix.setEntry(row, i, TP / (FP + TP));
            } else {
                matrix.setEntry(row, i, 0);
            }
        }
    }

    protected void evaluateRecall(LabeledMatrix matrix, int row, ClassificationResult result) {
        for (int i = 0; i < labels.size(); i++) {
            double TP = truePositives.get(i);
            double FN = falseNegatives.get(i);
            if (TP + FN == 0) {
                matrix.setEntry(row, i, 0);
            } else {
                matrix.setEntry(row, i, TP / (FN + TP));
            }
        }
    }

    protected void evaluateAccuracy(LabeledMatrix matrix, int row, ClassificationResult result) {
        for (int i = 0; i < labels.size(); i++) {
            double TP = truePositives.get(i);
            double TN = trueNegatives.get(i);
            double FN = falseNegatives.get(i);
            double FP = falsePositives.get(i);
            double den = TP + TN + FP + FN;
            matrix.setEntry(row, i, den == 0 ? 0 : (TP + TN) / den);
        }
    }

    protected void evaluateFScore(LabeledMatrix matrix, int row, ClassificationResult result) {
        for (int i = 0; i < labels.size(); i++) {
            double precision = matrix.getEntry(0, i);
            double recall = matrix.getEntry(1, i);
            if (precision + recall == 0) {
                matrix.setEntry(row, i, 0);
            } else {
                matrix.setEntry(row, i, 2 * (precision * recall) / (precision + recall));
            }
        }
    }

    protected void evaluateMatthews(LabeledMatrix matrix, int row, ClassificationResult result) {
        for (int i = 0; i < labels.size(); i++) {
            double TP = truePositives.get(i);
            double FN = falseNegatives.get(i);
            double FP = falsePositives.get(i);
            double TN = trueNegatives.get(i);

            double P = TP + FP;
            double N = FN + TN;
            double T = TP + FN;
            double F = FP + TN;
            double den = Math.sqrt(P * N * T * F);

            if (den == 0)
                matrix.setEntry(row, i, 0);
            else
                matrix.setEntry(row, i, (TP * TN - FP * FN) / den);
        }
    }

    public double getAccuracy() {
        return this.statistics.getMatrix()[2][labels.size()].getMean();
    }

    public String compatibilityError(Result result) {
        return result instanceof ClassificationResult ? null : "compatible with ClassificationResult only";
    }

}