org.deeplearning4j.eval.ConfusionMatrix.java Source code

Java tutorial

Introduction

Here is the source code for org.deeplearning4j.eval.ConfusionMatrix.java

Source

/*-
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.eval;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;

import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ConfusionMatrix<T extends Comparable<? super T>> implements Serializable {
    private Map<T, Multiset<T>> matrix;
    private List<T> classes;

    /**
     * Creates an empty confusion Matrix
     */
    public ConfusionMatrix(List<T> classes) {
        this.matrix = new HashMap<>();
        this.classes = classes;
    }

    public ConfusionMatrix() {
    }

    /**
     * Creates a new ConfusionMatrix initialized with the contents of another ConfusionMatrix.
     */
    public ConfusionMatrix(ConfusionMatrix<T> other) {
        this(other.getClasses());
        this.add(other);
    }

    /**
     * Increments the entry specified by actual and predicted by one.
     */
    public void add(T actual, T predicted) {
        add(actual, predicted, 1);
    }

    /**
     * Increments the entry specified by actual and predicted by count.
     */
    public void add(T actual, T predicted, int count) {
        if (matrix.containsKey(actual)) {
            matrix.get(actual).add(predicted, count);
        } else {
            Multiset<T> counts = HashMultiset.create();
            counts.add(predicted, count);
            matrix.put(actual, counts);
        }
    }

    /**
     * Adds the entries from another confusion matrix to this one.
     */
    public void add(ConfusionMatrix<T> other) {
        for (T actual : other.matrix.keySet()) {
            Multiset<T> counts = other.matrix.get(actual);
            for (T predicted : counts.elementSet()) {
                int count = counts.count(predicted);
                this.add(actual, predicted, count);
            }
        }
    }

    /**
     * Gives the applyTransformToDestination of all classes in the confusion matrix.
     */
    public List<T> getClasses() {
        return classes;
    }

    /**
     * Gives the count of the number of times the "predicted" class was predicted for the "actual"
     * class.
     */
    public int getCount(T actual, T predicted) {
        if (!matrix.containsKey(actual)) {
            return 0;
        } else {
            return matrix.get(actual).count(predicted);
        }
    }

    /**
     * Computes the total number of times the class was predicted by the classifier.
     */
    public int getPredictedTotal(T predicted) {
        int total = 0;
        for (T actual : classes) {
            total += getCount(actual, predicted);
        }
        return total;
    }

    /**
     * Computes the total number of times the class actually appeared in the data.
     */
    public int getActualTotal(T actual) {
        if (!matrix.containsKey(actual)) {
            return 0;
        } else {
            int total = 0;
            for (T elem : matrix.get(actual).elementSet()) {
                total += matrix.get(actual).count(elem);
            }
            return total;
        }
    }

    @Override
    public String toString() {
        return matrix.toString();
    }

    /**
     * Outputs the ConfusionMatrix as comma-separated values for easy import into spreadsheets
     */
    public String toCSV() {
        StringBuilder builder = new StringBuilder();

        // Header Row
        builder.append(",,Predicted Class,\n");

        // Predicted Classes Header Row
        builder.append(",,");
        for (T predicted : classes) {
            builder.append(String.format("%s,", predicted));
        }
        builder.append("Total\n");

        // Data Rows
        String firstColumnLabel = "Actual Class,";
        for (T actual : classes) {
            builder.append(firstColumnLabel);
            firstColumnLabel = ",";
            builder.append(String.format("%s,", actual));

            for (T predicted : classes) {
                builder.append(getCount(actual, predicted));
                builder.append(",");
            }
            // Actual Class Totals Column
            builder.append(getActualTotal(actual));
            builder.append("\n");
        }

        // Predicted Class Totals Row
        builder.append(",Total,");
        for (T predicted : classes) {
            builder.append(getPredictedTotal(predicted));
            builder.append(",");
        }
        builder.append("\n");

        return builder.toString();
    }

    /**
     * Outputs Confusion Matrix in an HTML table. Cascading Style Sheets (CSS) can control the table's
     * appearance by defining the empty-space, actual-count-header, predicted-class-header, and
     * count-element classes. For example
     *
     * @return html string
     */
    public String toHTML() {
        StringBuilder builder = new StringBuilder();

        int numClasses = classes.size();
        // Header Row
        builder.append("<table>\n");
        builder.append("<tr><th class=\"empty-space\" colspan=\"2\" rowspan=\"2\">");
        builder.append(String.format(
                "<th class=\"predicted-class-header\" colspan=\"%d\">Predicted Class</th></tr>%n", numClasses + 1));

        // Predicted Classes Header Row
        builder.append("<tr>");
        // builder.append("<th></th><th></th>");
        for (T predicted : classes) {
            builder.append("<th class=\"predicted-class-header\">");
            builder.append(predicted);
            builder.append("</th>");
        }
        builder.append("<th class=\"predicted-class-header\">Total</th>");
        builder.append("</tr>\n");

        // Data Rows
        String firstColumnLabel = String
                .format("<tr><th class=\"actual-class-header\" rowspan=\"%d\">Actual Class</th>", numClasses + 1);
        for (T actual : classes) {
            builder.append(firstColumnLabel);
            firstColumnLabel = "<tr>";
            builder.append(String.format("<th class=\"actual-class-header\" >%s</th>", actual));

            for (T predicted : classes) {
                builder.append("<td class=\"count-element\">");
                builder.append(getCount(actual, predicted));
                builder.append("</td>");
            }

            // Actual Class Totals Column
            builder.append("<td class=\"count-element\">");
            builder.append(getActualTotal(actual));
            builder.append("</td>");
            builder.append("</tr>\n");
        }

        // Predicted Class Totals Row
        builder.append("<tr><th class=\"actual-class-header\">Total</th>");
        for (T predicted : classes) {
            builder.append("<td class=\"count-element\">");
            builder.append(getPredictedTotal(predicted));
            builder.append("</td>");
        }
        builder.append("<td class=\"empty-space\"></td>\n");
        builder.append("</tr>\n");
        builder.append("</table>\n");

        return builder.toString();
    }

    @Override
    public boolean equals(Object o) {
        if (!(o instanceof ConfusionMatrix))
            return false;
        ConfusionMatrix<?> c = (ConfusionMatrix<?>) o;
        return matrix.equals(c.matrix) && classes.equals(c.classes);
    }

    @Override
    public int hashCode() {
        int result = 17;
        result = 31 * result + (matrix == null ? 0 : matrix.hashCode());
        result = 31 * result + (classes == null ? 0 : classes.hashCode());
        return result;
    }

    public static void main(String[] args) {
        ConfusionMatrix<String> confusionMatrix = new ConfusionMatrix<>(Arrays.asList("a", "b", "c"));

        confusionMatrix.add("a", "a", 88);
        confusionMatrix.add("a", "b", 10);
        confusionMatrix.add("b", "a", 14);
        confusionMatrix.add("b", "b", 40);
        confusionMatrix.add("b", "c", 6);
        confusionMatrix.add("c", "a", 18);
        confusionMatrix.add("c", "b", 10);
        confusionMatrix.add("c", "c", 12);

        ConfusionMatrix<String> confusionMatrix2 = new ConfusionMatrix<>(confusionMatrix);
        confusionMatrix2.add(confusionMatrix);
        System.out.println(confusionMatrix2.toHTML());
        System.out.println(confusionMatrix2.toCSV());
    }
}