de.tudarmstadt.ukp.dkpro.argumentation.sequence.report.ConfusionMatrixTools.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.dkpro.argumentation.sequence.report.ConfusionMatrixTools.java

Source

/*
 * Copyright 2015
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.dkpro.argumentation.sequence.report;

import de.tudarmstadt.ukp.dkpro.tc.svmhmm.util.ConfusionMatrix;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.StringUtils;

import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.*;

/**
 * Utilities for turning token-level prediction CSV files into a
 * {@link ConfusionMatrix} and a one-row, tab-separated summary of its metrics.
 *
 * <p>Can be run from the command line, see {@link #main(String[])}.
 *
 * @author Ivan Habernal
 */
public class ConfusionMatrixTools {

    /** Column separator used in the generated summary table. */
    private static final String GLUE = "\t";

    /** Number format applied to every metric in the summary table. */
    private static final String METRIC_FORMAT = "%.3f";

    /**
     * Reads a CSV file of predictions and accumulates it into a confusion matrix.
     *
     * <p>The file is parsed with {@link CSVFormat#DEFAULT}, treating lines that
     * start with {@code #} as comments. For every record, the first two columns
     * are passed to {@link ConfusionMatrix#increaseValue} in that order
     * (presumably gold label, then predicted label; confirm against
     * {@code ConfusionMatrix}). Any further columns are ignored.
     *
     * @param predictionsFile CSV file with at least two columns per record
     * @return confusion matrix with one increment per CSV record
     * @throws IOException if the file cannot be read, or a record has fewer
     *                     than two columns
     */
    public static ConfusionMatrix tokenLevelPredictionsToConfusionMatrix(File predictionsFile) throws IOException {
        ConfusionMatrix cm = new ConfusionMatrix();

        // The reader is declared as its own resource so it is closed even when
        // constructing the parser fails; previously neither was ever closed.
        try (FileReader reader = new FileReader(predictionsFile);
                CSVParser csvParser = new CSVParser(reader, CSVFormat.DEFAULT.withCommentMarker('#'))) {
            for (CSVRecord csvRecord : csvParser) {
                if (csvRecord.size() < 2) {
                    // Fail with context instead of a bare ArrayIndexOutOfBoundsException
                    throw new IOException("Expected at least 2 columns but found " + csvRecord.size()
                            + " in record " + csvRecord.getRecordNumber() + " of " + predictionsFile);
                }

                // update confusion matrix
                cm.increaseValue(csvRecord.get(0), csvRecord.get(1));
            }
        }

        return cm;
    }

    /**
     * Renders the main metrics of a confusion matrix as a two-line table:
     * a header line and a value line, columns separated by a tab.
     *
     * <p>The columns are macro F1, accuracy and the 95% accuracy confidence
     * value, followed by precision, recall and F1 for each label. Labels are
     * taken from the precision map and listed in their natural sort order.
     * All numbers are formatted with three decimals using
     * {@link Locale#ENGLISH}.
     *
     * @param cm confusion matrix to summarize
     * @return header line and value line joined by {@code "\n"}
     */
    public static String prettyPrintConfusionMatrixResults(ConfusionMatrix cm) {
        // Kept from the original implementation; any value it returns is
        // ignored here. NOTE(review): confirm whether this call is still needed.
        cm.printNiceResults();

        List<String> header = new ArrayList<>();
        List<String> row = new ArrayList<>();

        header.add("Macro F1");
        row.add(formatMetric(cm.getMacroFMeasure()));

        header.add("Accuracy");
        row.add(formatMetric(cm.getAccuracy()));

        header.add("Acc CI@95");
        row.add(formatMetric(cm.getConfidence95Accuracy()));

        Map<String, Double> precisionForLabels = cm.getPrecisionForLabels();
        Map<String, Double> recallForLabels = cm.getRecallForLabels();
        Map<String, Double> fMForLabels = cm.getFMeasureForLabels();

        // Sorted copy of the labels keeps the column order stable across runs
        SortedSet<String> labels = new TreeSet<>(precisionForLabels.keySet());

        for (String label : labels) {
            header.add(label + " P");
            row.add(formatMetric(precisionForLabels.get(label)));

            header.add(label + " R");
            row.add(formatMetric(recallForLabels.get(label)));

            header.add(label + " F1");
            row.add(formatMetric(fMForLabels.get(label)));
        }

        return StringUtils.join(header, GLUE) + "\n" + StringUtils.join(row, GLUE);
    }

    /**
     * Formats a single metric with {@link #METRIC_FORMAT} and
     * {@link Locale#ENGLISH}, so the decimal separator is always a dot.
     * A {@code null} value (label missing from a map) is rendered as
     * {@code "null"}, as {@link String#format} does.
     */
    private static String formatMetric(Object value) {
        return String.format(Locale.ENGLISH, METRIC_FORMAT, value);
    }

    /**
     * Builds the confusion matrix for the given predictions file and writes
     * its summary table to {@code niceResults.csv} in the same directory.
     * The output path is printed to standard output.
     *
     * @param predictionsFile CSV file with predictions, see
     *                        {@link #tokenLevelPredictionsToConfusionMatrix(File)}
     * @throws IOException if reading the predictions or writing the table fails
     */
    public static void generateNiceTable(File predictionsFile) throws IOException {
        ConfusionMatrix cm = tokenLevelPredictionsToConfusionMatrix(predictionsFile);

        File outFile = new File(predictionsFile.getParent(), "niceResults.csv");

        FileUtils.writeStringToFile(outFile, prettyPrintConfusionMatrixResults(cm));

        System.out.println("Writing " + outFile);
    }

    /**
     * Recursively searches a directory for {@code *.csv} files whose name
     * starts with {@code tokenLevelPredictions} and generates a summary table
     * next to each of them.
     *
     * @param args {@code args[0]} is the directory to search
     * @throws IOException              if processing any file fails
     * @throws IllegalArgumentException if no directory argument is given
     */
    public static void main(String[] args) throws IOException {
        if (args == null || args.length < 1) {
            // Explicit usage error instead of an ArrayIndexOutOfBoundsException
            throw new IllegalArgumentException(
                    "Usage: " + ConfusionMatrixTools.class.getSimpleName() + " <directory>");
        }

        File root = new File(args[0]);
        for (File file : FileUtils.listFiles(root, new String[] { "csv" }, true)) {
            if (file.getName().startsWith("tokenLevelPredictions")) {
                generateNiceTable(file);
            }
        }
    }
}