com.joliciel.talismane.stats.FScoreCalculator.java Source code

Java tutorial

Introduction

Here is the source code for com.joliciel.talismane.stats.FScoreCalculator.java

Source

///////////////////////////////////////////////////////////////////////////////
//Copyright (C) 2012 Assaf Urieli
//
//This file is part of Talismane.
//
//Talismane is free software: you can redistribute it and/or modify
//it under the terms of the GNU Affero General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//Talismane is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU Affero General Public License for more details.
//
//You should have received a copy of the GNU Affero General Public License
//along with Talismane.  If not, see <http://www.gnu.org/licenses/>.
//////////////////////////////////////////////////////////////////////////////
package com.joliciel.talismane.stats;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import java.util.Set;
import java.util.TreeSet;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math.stat.descriptive.DescriptiveStatistics;

import com.joliciel.talismane.utils.CSVFormatter;
import com.joliciel.talismane.utils.JolicielException;

/**
 * Calculates the f-score for a given classification task.
 * <p>
 * Each call to {@link #increment(Object, Object)} records one test, i.e. one
 * (expected, guessed) pair. Per-outcome and total precision, recall and f-score
 * are computed lazily: every getter that needs them calls {@link #evaluate()},
 * which only recomputes when new tests were recorded since the last evaluation.
 * Whenever a ratio's denominator is zero, the ratio is reported as 0.
 * <p>
 * This class performs no synchronization.
 *
 * @author Assaf Urieli
 *
 * @param <E> the outcome type. Outcomes are stored in a {@link TreeSet} without a
 *            comparator, so they must be mutually comparable.
 */
public class FScoreCalculator<E> {
    private static final Log LOG = LogFactory.getLog(FScoreCalculator.class);
    private static final CSVFormatter CSV = new CSVFormatter();

    /** guessed outcome -> (expected outcome -> count), recorded for every test including correct ones. */
    Map<E, Map<E, Integer>> falsePositives = new HashMap<E, Map<E, Integer>>();
    /** expected outcome -> (guessed outcome -> count), recorded for every test including correct ones. */
    Map<E, Map<E, Integer>> falseNegatives = new HashMap<E, Map<E, Integer>>();
    /** outcome -> number of tests where it was both expected and guessed. */
    Map<E, Integer> truePositiveCounts = new HashMap<E, Integer>();
    /** outcome -> number of tests where it was guessed but another outcome was expected. */
    Map<E, Integer> falsePositiveCounts = new HashMap<E, Integer>();
    /** outcome -> number of tests where it was expected but another outcome was guessed. */
    Map<E, Integer> falseNegativeCounts = new HashMap<E, Integer>();

    /** Every outcome seen so far, either as expected or as guessed. */
    Set<E> outcomeSet = new TreeSet<E>();

    // Per-outcome scores, rebuilt by evaluate().
    Map<E, Double> precisions = new HashMap<E, Double>();
    Map<E, Double> recalls = new HashMap<E, Double>();
    Map<E, Double> fScores = new HashMap<E, Double>();

    int testCount = 0;

    // Totals computed from the counts summed over all outcomes, rebuilt by evaluate().
    double totalPrecision = 0.0;
    double totalRecall = 0.0;
    double totalFScore = 0.0;

    int totalTruePositiveCount = 0;
    int totalFalsePositiveCount = 0;
    int totalFalseNegativeCount = 0;

    /** True when increment() has been called since the last evaluate(). */
    boolean updatedSinceLastEval = false;

    /** Optional label, only used in log messages. */
    Object label = null;

    /**
     * Creates a calculator whose log output is tagged with the given label.
     * @param label a label identifying this calculator in log messages (may be null)
     */
    public FScoreCalculator(Object label) {
        this.label = label;
    }

    /**
     * Creates an unlabelled calculator.
     */
    public FScoreCalculator() {

    }

    /**
     * Increment this f-score by a given expected value and guessed value.
     * <p>
     * The test counts as a true positive for {@code expected} when both values are equal.
     * Otherwise it counts as a false negative for {@code expected} and a false positive
     * for {@code guessed}.
     *
     * @param expected the correct outcome, must not be null
     * @param guessed the outcome that was guessed, must not be null
     * @throws NullPointerException if either argument is null (checked before any state is modified)
     */
    public void increment(E expected, E guessed) {
        if (expected == null)
            throw new NullPointerException("expected");
        if (guessed == null)
            throw new NullPointerException("guessed");

        // Done first so that a non-comparable outcome fails before any count is touched.
        outcomeSet.add(guessed);
        outcomeSet.add(expected);

        incrementCount(getOrCreateInnerMap(falsePositives, guessed), expected);
        incrementCount(getOrCreateInnerMap(falseNegatives, expected), guessed);

        if (expected.equals(guessed)) {
            incrementCount(truePositiveCounts, expected);
        } else {
            // we didn't guess correctly that this was an X
            incrementCount(falseNegativeCounts, expected);
            // we guessed that this was a Y, when it wasn't
            incrementCount(falsePositiveCounts, guessed);
        }
        testCount++;
        updatedSinceLastEval = true;
    }

    /**
     * Recomputes per-outcome and total precision, recall and f-score from the current counts,
     * logging the results. Does nothing if no test was recorded since the previous call.
     */
    void evaluate() {
        if (!updatedSinceLastEval)
            return;

        LOG.info("###F-score calculations " + (label == null ? "" : " for " + label.toString()));
        precisions = new HashMap<E, Double>();
        recalls = new HashMap<E, Double>();
        fScores = new HashMap<E, Double>();

        // Totals are rebuilt from scratch: accumulating onto the previous evaluation's values
        // would count earlier tests again each time evaluate() runs after further increments.
        totalTruePositiveCount = 0;
        totalFalsePositiveCount = 0;
        totalFalseNegativeCount = 0;

        for (E outcome : outcomeSet) {
            LOG.debug("Outcome: " + outcome);
            int truePositiveCount = countOrZero(truePositiveCounts, outcome);
            int falsePositiveCount = countOrZero(falsePositiveCounts, outcome);
            int falseNegativeCount = countOrZero(falseNegativeCounts, outcome);
            LOG.debug("truePositiveCount: " + truePositiveCount);
            LOG.debug("falsePositiveCount: " + falsePositiveCount);
            if (LOG.isTraceEnabled()) {
                // falsePositives is keyed guessed -> expected, so the inner keys are expected outcomes.
                LOG.trace("False positives: ");
                traceConfusionRow(outcome, falsePositives.get(outcome));
            }

            LOG.debug("falseNegativeCount " + falseNegativeCount);
            if (LOG.isTraceEnabled()) {
                // falseNegatives is keyed expected -> guessed, so the inner keys are guessed outcomes.
                LOG.trace("False negatives: ");
                traceConfusionRow(outcome, falseNegatives.get(outcome));
            }

            double precision = ratio(truePositiveCount, truePositiveCount + falsePositiveCount);
            double recall = ratio(truePositiveCount, truePositiveCount + falseNegativeCount);
            double fScore = harmonicMean(precision, recall);
            LOG.debug("Precision: " + precision);
            LOG.debug("Recall: " + recall);
            LOG.debug("F-score " + fScore);

            precisions.put(outcome, precision);
            recalls.put(outcome, recall);
            fScores.put(outcome, fScore);
            totalTruePositiveCount += truePositiveCount;
            totalFalsePositiveCount += falsePositiveCount;
            totalFalseNegativeCount += falseNegativeCount;
        }
        totalPrecision = ratio(totalTruePositiveCount, totalTruePositiveCount + totalFalsePositiveCount);
        totalRecall = ratio(totalTruePositiveCount, totalTruePositiveCount + totalFalseNegativeCount);
        totalFScore = harmonicMean(totalPrecision, totalRecall);
        LOG.info("Total tests: " + testCount);
        LOG.info("Total true positives: " + totalTruePositiveCount);
        LOG.info("Total false positives: " + totalFalsePositiveCount);
        LOG.info("Total false negatives: " + totalFalseNegativeCount);
        LOG.info("Total precision: " + totalPrecision);
        LOG.info("Total recall: " + totalRecall);
        LOG.info("Total f-score: " + totalFScore);

        updatedSinceLastEval = false;
    }

    /**
     * Logs, at trace level, each (outcome, other outcome, count) entry of one row of the confusion matrix.
     * @param outcome the row's outcome
     * @param pairCounts the row's counts, keyed by the other outcome (may be null)
     */
    private void traceConfusionRow(E outcome, Map<E, Integer> pairCounts) {
        if (pairCounts == null)
            return;
        for (Map.Entry<E, Integer> entry : pairCounts.entrySet()) {
            LOG.trace(outcome.toString() + " , " + entry.getKey().toString() + ": " + entry.getValue());
        }
    }

    /**
     * Get the count of false positives for a given outcome.
     * @param outcome the outcome
     * @return the number of tests where this outcome was guessed but another was expected (0 if none)
     */
    public int getFalsePositiveCount(E outcome) {
        return countOrZero(falsePositiveCounts, outcome);
    }

    /**
     * Get the false positives for a given outcome.
     * @param outcome the guessed outcome
     * @return the internal map of expected outcome -> count for tests where {@code outcome} was guessed.
     *         It also holds the correctly guessed tests under {@code outcome} itself. The result is null if
     *         {@code outcome} was never guessed.
     */
    public Map<E, Integer> getFalsePositives(E outcome) {
        return falsePositives.get(outcome);
    }

    /**
     * False positives for all outcomes.
     * @return the internal map of outcome -> false positive count; outcomes with none are absent
     */
    public Map<E, Integer> getFalsePositiveCounts() {
        return falsePositiveCounts;
    }

    /**
     * Get the count of false negatives for a given outcome.
     * @param outcome the outcome
     * @return the number of tests where this outcome was expected but another was guessed (0 if none)
     */
    public int getFalseNegativeCount(E outcome) {
        return countOrZero(falseNegativeCounts, outcome);
    }

    /**
     * Get the false negatives for a given outcome.
     * @param outcome the expected outcome
     * @return the internal map of guessed outcome -> count for tests where {@code outcome} was expected.
     *         It also holds the correctly guessed tests under {@code outcome} itself. The result is null if
     *         {@code outcome} was never expected.
     */
    public Map<E, Integer> getFalseNegatives(E outcome) {
        return falseNegatives.get(outcome);
    }

    /**
     * False negatives for all outcomes.
     * @return the internal map of outcome -> false negative count; outcomes with none are absent
     */
    public Map<E, Integer> getFalseNegativeCounts() {
        return falseNegativeCounts;
    }

    /**
     * Get the count of true positives for a given outcome.
     * @param outcome the outcome
     * @return the number of tests where this outcome was both expected and guessed (0 if none)
     */
    public int getTruePositiveCount(E outcome) {
        return countOrZero(truePositiveCounts, outcome);
    }

    /**
     * True positive counts for all outcomes.
     * @return the internal map of outcome -> true positive count; outcomes with none are absent
     */
    public Map<E, Integer> getTruePositiveCounts() {
        return truePositiveCounts;
    }

    /**
     * The set of outcomes.
     * @return the internal, sorted set of every outcome seen as expected or guessed
     */
    public Set<E> getOutcomeSet() {
        return outcomeSet;
    }

    /**
     * Get the precision for a particular outcome.
     * @param outcome the outcome
     * @return precision as a fraction in [0, 1], or 0 if the outcome has never been seen
     */
    public double getPrecision(E outcome) {
        this.evaluate();
        return valueOrZero(precisions, outcome);
    }

    /**
     * Precisions for all outcomes.
     * @return the map of outcome -> precision (fraction in [0, 1])
     */
    public Map<E, Double> getPrecisions() {
        this.evaluate();
        return precisions;
    }

    /**
     * Get the recall for a particular outcome.
     * @param outcome the outcome
     * @return recall as a fraction in [0, 1], or 0 if the outcome has never been seen
     */
    public double getRecall(E outcome) {
        this.evaluate();
        return valueOrZero(recalls, outcome);
    }

    /**
     * Recalls for all outcomes.
     * @return the map of outcome -> recall (fraction in [0, 1])
     */
    public Map<E, Double> getRecalls() {
        this.evaluate();
        return recalls;
    }

    /**
     * Get the f-score for a particular outcome.
     * @param outcome the outcome
     * @return f-score as a fraction in [0, 1], or 0 if the outcome has never been seen
     */
    public double getFScore(E outcome) {
        this.evaluate();
        return valueOrZero(fScores, outcome);
    }

    /**
     * F-scores for all outcomes.
     * @return the map of outcome -> f-score (fraction in [0, 1])
     */
    public Map<E, Double> getFScores() {
        this.evaluate();
        return fScores;
    }

    /**
     * Total number of tests run.
     * @return the number of calls to {@link #increment(Object, Object)}
     */
    public int getTestCount() {
        return testCount;
    }

    /**
     * Total precision.
     * @return total true positives / (total true positives + total false positives), or 0 with no data
     */
    public double getTotalPrecision() {
        this.evaluate();
        return totalPrecision;
    }

    /**
     * Total recall.
     * @return total true positives / (total true positives + total false negatives), or 0 with no data
     */
    public double getTotalRecall() {
        this.evaluate();
        return totalRecall;
    }

    /**
     * Total f-score.
     * @return the harmonic mean of the total precision and total recall, or 0 if both are 0
     */
    public double getTotalFScore() {
        this.evaluate();
        return totalFScore;
    }

    /**
     * @return the true positive counts summed over all outcomes
     */
    public int getTotalTruePositiveCount() {
        this.evaluate();
        return totalTruePositiveCount;
    }

    /**
     * @return the false positive counts summed over all outcomes
     */
    public int getTotalFalsePositiveCount() {
        this.evaluate();
        return totalFalsePositiveCount;
    }

    /**
     * @return the false negative counts summed over all outcomes
     */
    public int getTotalFalseNegativeCount() {
        this.evaluate();
        return totalFalseNegativeCount;
    }

    /**
     * Writes the scores as UTF-8 CSV to the given file, replacing any existing file.
     * See {@link #writeScoresToCSV(Writer)} for the layout.
     * @param fscoreFile the destination file
     * @throws RuntimeException wrapping any IOException
     */
    public void writeScoresToCSVFile(File fscoreFile) {
        try {
            fscoreFile.delete();
            fscoreFile.createNewFile();
            Writer fscoreFileWriter = new BufferedWriter(
                    new OutputStreamWriter(new FileOutputStream(fscoreFile, false), "UTF8"));
            try {
                this.writeScoresToCSV(fscoreFileWriter);
                fscoreFileWriter.flush();
            } finally {
                // close() is always attempted, even if writing or flushing failed.
                fscoreFileWriter.close();
            }
        } catch (IOException ioe) {
            throw new RuntimeException(ioe);
        }
    }

    /**
     * The total accuracy for this confusion matrix.
     * @return the fraction of tests whose guessed outcome equals the expected one, or 0 if no tests were run
     */
    public double getAccuracy() {
        this.evaluate();
        if (testCount == 0)
            return 0.0;
        return (double) totalTruePositiveCount / (double) testCount;
    }

    /**
     * Return Cohen's kappa for this confusion matrix.
     * <p>
     * kappa = (accuracy - randomAccuracy) / (1 - randomAccuracy). Here randomAccuracy is the sum, over
     * outcomes, of (fraction of tests guessing the outcome) * (fraction of tests expecting the outcome).
     * The result is NaN when randomAccuracy is 1 (e.g. every test expected and guessed the same single
     * outcome).
     *
     * @return Cohen's kappa, or 0 if no tests were run
     */
    public double getKappa() {
        if (testCount == 0)
            return 0.0;
        double totalCount = testCount;
        double totalAccuracy = this.getAccuracy();

        // random accuracy is the sum of products for marginal accuracies for each label
        double randomAccuracy = 0.0;
        for (E outcome : outcomeSet) {
            int truePositiveCount = countOrZero(truePositiveCounts, outcome);
            double guessedProportion = (truePositiveCount + countOrZero(falsePositiveCounts, outcome)) / totalCount;
            double expectedProportion = (truePositiveCount + countOrZero(falseNegativeCounts, outcome)) / totalCount;
            randomAccuracy += guessedProportion * expectedProportion;
        }

        return (totalAccuracy - randomAccuracy) / (1 - randomAccuracy);
    }

    /**
     * Writes the confusion matrix and scores as CSV.
     * <p>
     * The first row is a header. Next comes one row per outcome: the false-negative counts against every
     * outcome, then true+, false+, false-, precision, recall and f-score. The TOTAL, AVERAGE, ACCURACY and
     * KAPPA rows follow. Precision, recall, f-score, accuracy and kappa are written as percentages. The
     * writer is neither flushed nor closed.
     *
     * @param fscoreFileWriter the destination
     * @throws RuntimeException wrapping any IOException raised while writing
     */
    public void writeScoresToCSV(Writer fscoreFileWriter) {
        try {
            Set<E> sortedOutcomes = new TreeSet<E>();
            sortedOutcomes.addAll(this.getOutcomeSet());
            int outcomeCount = sortedOutcomes.size();

            fscoreFileWriter.write(CSV.format("outcome"));
            for (E outcome : sortedOutcomes) {
                fscoreFileWriter.write(CSV.format(outcome.toString()));
            }
            fscoreFileWriter.write(CSV.format("true+") + CSV.format("false+") + CSV.format("false-")
                    + CSV.format("precision") + CSV.format("recall") + CSV.format("f-score"));
            fscoreFileWriter.write("\n");

            double precisionSum = 0;
            double recallSum = 0;
            double fScoreSum = 0;
            for (E outcome : sortedOutcomes) {
                fscoreFileWriter.write(CSV.format(outcome.toString()));
                Map<E, Integer> guessesForOutcome = this.getFalseNegatives(outcome);
                for (E guessed : sortedOutcomes) {
                    int count = guessesForOutcome == null ? 0 : countOrZero(guessesForOutcome, guessed);
                    fscoreFileWriter.write(CSV.format(count));
                }
                double precision = this.getPrecision(outcome);
                double recall = this.getRecall(outcome);
                double fScore = this.getFScore(outcome);
                fscoreFileWriter.write(CSV.format(this.getTruePositiveCount(outcome)));
                fscoreFileWriter.write(CSV.format(this.getFalsePositiveCount(outcome)));
                fscoreFileWriter.write(CSV.format(this.getFalseNegativeCount(outcome)));
                fscoreFileWriter.write(CSV.format(precision * 100));
                fscoreFileWriter.write(CSV.format(recall * 100));
                fscoreFileWriter.write(CSV.format(fScore * 100));
                fscoreFileWriter.write("\n");

                precisionSum += precision;
                recallSum += recall;
                fScoreSum += fScore;
            }

            fscoreFileWriter.write(CSV.format("TOTAL"));
            writeEmptyCells(fscoreFileWriter, outcomeCount);
            fscoreFileWriter.write(CSV.format(this.getTotalTruePositiveCount()));
            fscoreFileWriter.write(CSV.format(this.getTotalFalsePositiveCount()));
            fscoreFileWriter.write(CSV.format(this.getTotalFalseNegativeCount()));
            fscoreFileWriter.write(CSV.format(this.getTotalPrecision() * 100));
            fscoreFileWriter.write(CSV.format(this.getTotalRecall() * 100));
            fscoreFileWriter.write(CSV.format(this.getTotalFScore() * 100));
            fscoreFileWriter.write("\n");

            // The matrix columns and the true+/false+/false- columns are left empty on the summary rows below.
            fscoreFileWriter.write(CSV.format("AVERAGE"));
            writeEmptyCells(fscoreFileWriter, outcomeCount + 3);
            fscoreFileWriter.write(CSV.format(ratio(precisionSum, outcomeCount) * 100));
            fscoreFileWriter.write(CSV.format(ratio(recallSum, outcomeCount) * 100));
            fscoreFileWriter.write(CSV.format(ratio(fScoreSum, outcomeCount) * 100));
            fscoreFileWriter.write("\n");

            fscoreFileWriter.write(CSV.format("ACCURACY"));
            writeEmptyCells(fscoreFileWriter, outcomeCount + 3);
            fscoreFileWriter.write(CSV.format(this.getAccuracy() * 100));
            fscoreFileWriter.write("\n");

            fscoreFileWriter.write(CSV.format("KAPPA"));
            writeEmptyCells(fscoreFileWriter, outcomeCount + 3);
            fscoreFileWriter.write(CSV.format(this.getKappa() * 100));
            fscoreFileWriter.write("\n");

        } catch (IOException ioe) {
            throw new RuntimeException(ioe);
        }
    }

    /**
     * Writes {@code count} empty CSV cells.
     */
    private static void writeEmptyCells(Writer writer, int count) throws IOException {
        for (int i = 0; i < count; i++) {
            writer.write(CSV.format(""));
        }
    }

    /**
     * Command-line entry point: combines cross-validation f-score files.
     * The arguments are the directory, the file name prefix and the file name suffix. The combined
     * result is written as UTF-8 to {@code <directory>/<prefix>A<suffix>.csv}.
     * @param args directory, prefix, suffix
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 3) {
            throw new IllegalArgumentException("Usage: FScoreCalculator <directory> <prefix> <suffix>");
        }
        File directory = new File(args[0]);
        String prefix = args[1];
        String suffix = args[2];
        Writer csvFileWriter = new BufferedWriter(new OutputStreamWriter(
                new FileOutputStream(new File(directory, prefix + "A" + suffix + ".csv"), false), "UTF8"));
        try {
            combineCrossValidationResults(directory, prefix, suffix, csvFileWriter);
        } finally {
            csvFileWriter.close();
        }
    }

    /**
     * Combine the results of n cross validation results into a single f-score file.
     * <p>
     * Reads every file in {@code directory} whose name starts with {@code prefix}, continues with a fold
     * number and ends with {@code suffix}. Matching files without a fold number right after the prefix are
     * skipped. Each file is read in the layout written by {@link #writeScoresToCSV(Writer)}. Then, for each
     * outcome, one row is written with:
     * <ul>
     * <li>the counts summed over the folds,</li>
     * <li>precision, recall and f-score recomputed from those summed counts ("tot"),</li>
     * <li>the mean ("avg") and standard deviation ("dev") of the per-fold values.</li>
     * </ul>
     * The writer is flushed but not closed.
     *
     * @param directory the directory holding the per-fold files
     * @param prefix the file name prefix, immediately followed by the fold number
     * @param suffix the file name suffix
     * @param csvFileWriter the destination
     * @throws JolicielException if the directory cannot be listed, a file lacks a "true+" header,
     *         or no file matches
     * @throws RuntimeException wrapping any IOException
     */
    static void combineCrossValidationResults(File directory, String prefix, String suffix, Writer csvFileWriter) {
        try {
            File[] files = directory.listFiles();
            if (files == null) {
                throw new JolicielException("Cannot list files in directory: " + directory.getPath());
            }
            Map<Integer, Map<String, FScoreStats>> fileStatsMap = new HashMap<Integer, Map<String, FScoreStats>>();
            for (File file : files) {
                String fileName = file.getName();
                if (!fileName.startsWith(prefix) || !fileName.endsWith(suffix))
                    continue;
                int index = parseFoldIndex(fileName, prefix);
                if (index < 0) {
                    // e.g. a combined file such as the "<prefix>A..." one written by main()
                    LOG.debug("Skipping file without a fold number: " + fileName);
                    continue;
                }
                fileStatsMap.put(index, readFScoreFile(file));
            }

            int numFiles = fileStatsMap.size();
            if (numFiles == 0) {
                throw new JolicielException("No files found matching prefix and suffix provided");
            }

            // One map per metric, keyed by outcome, so that outcome names cannot collide with metric names.
            Map<String, DescriptiveStatistics> fScoreStatsMap = new HashMap<String, DescriptiveStatistics>();
            Map<String, DescriptiveStatistics> precisionStatsMap = new HashMap<String, DescriptiveStatistics>();
            Map<String, DescriptiveStatistics> recallStatsMap = new HashMap<String, DescriptiveStatistics>();
            Map<String, FScoreStats> outcomeStats = new HashMap<String, FScoreStats>();
            Set<String> outcomes = new TreeSet<String>();
            for (Map<String, FScoreStats> statsMap : fileStatsMap.values()) {
                for (FScoreStats stats : statsMap.values()) {
                    getOrCreateStatistics(fScoreStatsMap, stats.outcome).addValue(stats.fScore);
                    getOrCreateStatistics(precisionStatsMap, stats.outcome).addValue(stats.precision);
                    getOrCreateStatistics(recallStatsMap, stats.outcome).addValue(stats.recall);

                    FScoreStats outcomeStat = outcomeStats.get(stats.outcome);
                    if (outcomeStat == null) {
                        outcomeStat = new FScoreStats();
                        outcomeStat.outcome = stats.outcome;
                        outcomeStats.put(stats.outcome, outcomeStat);
                    }
                    outcomeStat.truePos += stats.truePos;
                    outcomeStat.falsePos += stats.falsePos;
                    outcomeStat.falseNeg += stats.falseNeg;

                    outcomes.add(stats.outcome);
                }
            }

            csvFileWriter.write(CSV.format(prefix + suffix));
            csvFileWriter.write("\n");
            csvFileWriter.write(CSV.format("outcome"));
            csvFileWriter.write(CSV.format("true+") + CSV.format("false+") + CSV.format("false-")
                    + CSV.format("tot precision") + CSV.format("avg precision") + CSV.format("dev precision")
                    + CSV.format("tot recall") + CSV.format("avg recall") + CSV.format("dev recall")
                    + CSV.format("tot f-score") + CSV.format("avg f-score") + CSV.format("dev f-score") + "\n");

            for (String outcome : outcomes) {
                csvFileWriter.write(CSV.format(outcome));
                FScoreStats outcomeStat = outcomeStats.get(outcome);
                DescriptiveStatistics fScoreStats = fScoreStatsMap.get(outcome);
                DescriptiveStatistics precisionStats = precisionStatsMap.get(outcome);
                DescriptiveStatistics recallStats = recallStatsMap.get(outcome);
                outcomeStat.calculate();
                csvFileWriter.write(CSV.format(outcomeStat.truePos));
                csvFileWriter.write(CSV.format(outcomeStat.falsePos));
                csvFileWriter.write(CSV.format(outcomeStat.falseNeg));
                // calculate() yields fractions; the per-fold values were already written as percentages.
                csvFileWriter.write(CSV.format(outcomeStat.precision * 100));
                csvFileWriter.write(CSV.format(precisionStats.getMean()));
                csvFileWriter.write(CSV.format(precisionStats.getStandardDeviation()));
                csvFileWriter.write(CSV.format(outcomeStat.recall * 100));
                csvFileWriter.write(CSV.format(recallStats.getMean()));
                csvFileWriter.write(CSV.format(recallStats.getStandardDeviation()));
                csvFileWriter.write(CSV.format(outcomeStat.fScore * 100));
                csvFileWriter.write(CSV.format(fScoreStats.getMean()));
                csvFileWriter.write(CSV.format(fScoreStats.getStandardDeviation()));
                csvFileWriter.write("\n");
            }
            csvFileWriter.flush();
        } catch (IOException ioe) {
            throw new RuntimeException(ioe);
        }
    }

    /**
     * Parses the fold number formed by the ASCII digits immediately following {@code prefix} in
     * {@code fileName}. The whole digit run is read, so fold 10 is not confused with fold 1.
     * @return the fold number, or -1 if no digit follows the prefix
     */
    private static int parseFoldIndex(String fileName, String prefix) {
        int start = prefix.length();
        int end = start;
        while (end < fileName.length() && fileName.charAt(end) >= '0' && fileName.charAt(end) <= '9') {
            end++;
        }
        if (end == start)
            return -1;
        return Integer.parseInt(fileName.substring(start, end));
    }

    /**
     * Reads one per-fold f-score CSV file (UTF-8, as written by {@link #writeScoresToCSV(Writer)}) into a
     * map of outcome -> stats. Rows are read up to, but not including, the "AVERAGE" row, so the "TOTAL"
     * row is included under the outcome name "TOTAL". Blank lines are ignored.
     * @throws JolicielException if the header row has no "true+" column
     * @throws IOException if the file cannot be opened or read
     */
    private static Map<String, FScoreStats> readFScoreFile(File file) throws IOException {
        Map<String, FScoreStats> statsMap = new HashMap<String, FScoreStats>();
        Scanner scanner = new Scanner(new BufferedReader(new InputStreamReader(new FileInputStream(file), "UTF-8")));
        try {
            boolean firstLine = true;
            int truePositivePos = -1;
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine();
                if (line.trim().length() == 0)
                    continue;
                List<String> cells = CSV.getCSVCells(line);
                if (firstLine) {
                    truePositivePos = cells.indexOf("true+");
                    if (truePositivePos < 0) {
                        throw new JolicielException("Couldn't find true+ on first line");
                    }
                    firstLine = false;
                    continue;
                }
                String outcome = cells.get(0);
                if (outcome.equals("AVERAGE"))
                    break;
                FScoreStats stats = new FScoreStats();
                stats.outcome = outcome;
                stats.truePos = Integer.parseInt(cells.get(truePositivePos));
                stats.falsePos = Integer.parseInt(cells.get(truePositivePos + 1));
                stats.falseNeg = Integer.parseInt(cells.get(truePositivePos + 2));
                stats.precision = Double.parseDouble(cells.get(truePositivePos + 3));
                stats.recall = Double.parseDouble(cells.get(truePositivePos + 4));
                stats.fScore = Double.parseDouble(cells.get(truePositivePos + 5));
                statsMap.put(outcome, stats);
            }
            // Scanner swallows read errors and simply reports end of input; surface them instead.
            IOException readError = scanner.ioException();
            if (readError != null)
                throw readError;
        } finally {
            scanner.close();
        }
        return statsMap;
    }

    /**
     * Returns the statistics stored for {@code outcome}, creating and storing empty ones if absent.
     */
    private static DescriptiveStatistics getOrCreateStatistics(Map<String, DescriptiveStatistics> statsByOutcome,
            String outcome) {
        DescriptiveStatistics stats = statsByOutcome.get(outcome);
        if (stats == null) {
            stats = new DescriptiveStatistics();
            statsByOutcome.put(outcome, stats);
        }
        return stats;
    }

    /**
     * Returns the inner map stored under {@code key}, creating and storing an empty one if absent.
     */
    private static <K> Map<K, Integer> getOrCreateInnerMap(Map<K, Map<K, Integer>> outer, K key) {
        Map<K, Integer> inner = outer.get(key);
        if (inner == null) {
            inner = new HashMap<K, Integer>();
            outer.put(key, inner);
        }
        return inner;
    }

    /**
     * Adds one to the count stored under {@code key}, treating a missing entry as 0.
     */
    private static <K> void incrementCount(Map<K, Integer> counts, K key) {
        Integer count = counts.get(key);
        counts.put(key, count == null ? 1 : count.intValue() + 1);
    }

    /**
     * Returns the count stored under {@code key}, or 0 if absent.
     */
    private static <K> int countOrZero(Map<K, Integer> counts, K key) {
        Integer count = counts.get(key);
        return count == null ? 0 : count.intValue();
    }

    /**
     * Returns the value stored under {@code key}, or 0.0 if absent.
     */
    private static <K> double valueOrZero(Map<K, Double> values, K key) {
        Double value = values.get(key);
        return value == null ? 0.0 : value.doubleValue();
    }

    /**
     * Returns numerator / denominator, or 0 when the denominator is not positive.
     */
    private static double ratio(double numerator, double denominator) {
        return denominator > 0 ? numerator / denominator : 0.0;
    }

    /**
     * Returns 2pr / (p + r), or 0 when both precision and recall are 0.
     */
    private static double harmonicMean(double precision, double recall) {
        return precision + recall > 0 ? (2 * precision * recall) / (precision + recall) : 0.0;
    }

    /**
     * Counts and scores for one outcome, either read from a per-fold CSV file or aggregated across files.
     * When read from a file, precision, recall and fScore hold the percentages written there;
     * {@link #calculate()} replaces them with fractions computed from the counts.
     */
    private static final class FScoreStats {
        String outcome;
        int truePos;
        int falsePos;
        int falseNeg;
        double precision;
        double recall;
        double fScore;

        /**
         * Recomputes precision, recall and f-score (as fractions) from the counts; 0 where undefined.
         */
        public void calculate() {
            precision = ratio(truePos, truePos + falsePos);
            recall = ratio(truePos, truePos + falseNeg);
            fScore = harmonicMean(precision, recall);
        }
    }
}