de.tudarmstadt.ukp.dkpro.argumentation.sequence.report.TokenLevelEvaluationReport.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.dkpro.argumentation.sequence.report.TokenLevelEvaluationReport.java

Source

/*
 * Copyright 2015
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.dkpro.argumentation.sequence.report;

import de.tudarmstadt.ukp.dkpro.argumentation.sequence.feature.meta.AbstractSequenceMetaDataFeatureGenerator;
import de.tudarmstadt.ukp.dkpro.argumentation.sequence.feature.meta.OrigBIOTokenSequenceMetaDataFeatureGenerator;
import de.tudarmstadt.ukp.dkpro.argumentation.sequence.feature.meta.OrigTokenSequenceMetaDataFeatureGenerator;
import de.tudarmstadt.ukp.dkpro.lab.engine.TaskContext;
import de.tudarmstadt.ukp.dkpro.lab.storage.StorageService;
import de.tudarmstadt.ukp.dkpro.tc.svmhmm.report.SVMHMMOutcomeIDReport;
import de.tudarmstadt.ukp.dkpro.tc.svmhmm.util.ConfusionMatrix;
import de.tudarmstadt.ukp.dkpro.tc.svmhmm.util.SVMHMMUtils;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.io.IOUtils;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.List;
import java.util.SortedMap;

/**
 * Token-level evaluation report for SVM-HMM sequence labelling.
 * <p>
 * The outcomes loaded by {@code loadGoldAndPredictedLabels()} are one label per sentence. For
 * each sentence this report expands the predicted sentence-level label to per-token labels via
 * {@code ReportTools.recreateTokenLabels}, pairs them with the original gold BIO token labels and
 * tokens stored as meta-data features in the test file, and then:
 * <ul>
 * <li>writes one CSV record per token (gold, predicted, token, sequence ID) to
 * {@value #TOKEN_LEVEL_PREDICTIONS_CSV} in the {@code TEST_TASK_OUTPUT_KEY} storage location;</li>
 * <li>accumulates a token-level {@link ConfusionMatrix} that is passed to
 * {@link #writeResults(TaskContext, ConfusionMatrix)}.</li>
 * </ul>
 *
 * @author Ivan Habernal
 */
public class TokenLevelEvaluationReport extends SVMHMMOutcomeIDReport {

    /** Name of the per-token predictions CSV file written into the test task output storage. */
    public static final String TOKEN_LEVEL_PREDICTIONS_CSV = "tokenLevelPredictions.csv";

    /**
     * Expands the sentence-level predictions to tokens, writes them to
     * {@value #TOKEN_LEVEL_PREDICTIONS_CSV}, and writes the token-level evaluation results.
     *
     * @throws IllegalStateException if the numbers of gold labels, predicted labels, sequence IDs
     *                               and meta-data entries disagree, or if a sentence has a
     *                               different number of tokens than gold token labels
     * @throws Exception             propagated from loading the labels, locating the test file,
     *                               or writing the output
     */
    @Override
    public void execute() throws Exception {
        // load gold and predicted labels
        loadGoldAndPredictedLabels();

        File testFile = locateTestFile();

        // original sequence IDs (indexed by sentence, like the labels)
        List<Integer> sequenceIDs = SVMHMMUtils.extractOriginalSequenceIDs(testFile);

        // original meta data (gold BIO token labels and tokens), one map per sentence
        List<SortedMap<String, String>> metaDataFeatures = SVMHMMUtils.extractMetaDataFeatures(testFile);

        // sanity check: every per-sentence list must be aligned with the gold labels. The
        // predicted labels are checked too; otherwise a short prediction list would fail later
        // with an uninformative index error.
        int sentenceCount = goldLabels.size();
        if (predictedLabels.size() != sentenceCount || sequenceIDs.size() != sentenceCount
                || metaDataFeatures.size() != sentenceCount) {
            throw new IllegalStateException(String.format(
                    "check consistency: gold labels=%d, predicted labels=%d, sequence IDs=%d, meta data=%d",
                    sentenceCount, predictedLabels.size(), sequenceIDs.size(), metaDataFeatures.size()));
        }

        File evaluationFile = new File(
                getContext().getStorageLocation(TEST_TASK_OUTPUT_KEY, StorageService.AccessMode.READWRITE),
                TOKEN_LEVEL_PREDICTIONS_CSV);

        // confusion matrix for evaluation
        ConfusionMatrix confusionMatrix = new ConfusionMatrix();

        // write results into CSV
        // form: gold;predicted;token;seqID
        // try-with-resources closes the file even if a record fails, and (unlike closeQuietly)
        // reports close/flush errors instead of silently leaving a truncated file.
        // NOTE(review): FileWriter uses the platform default charset; tokens may be non-ASCII.
        // Consider UTF-8 once the readers of this file are confirmed to expect it.
        try (FileWriter fileWriter = new FileWriter(evaluationFile);
             CSVPrinter csvPrinter = new CSVPrinter(fileWriter, SVMHMMUtils.CSV_FORMAT)) {
            csvPrinter.printComment(SVMHMMUtils.CSV_COMMENT);

            for (int i = 0; i < sentenceCount; i++) {
                String predictedLabelSentenceLevel = predictedLabels.get(i);
                SortedMap<String, String> sentenceMetaData = metaDataFeatures.get(i);

                // get gold token labels for this sentence
                List<String> goldTokenLabels = AbstractSequenceMetaDataFeatureGenerator.decodeFromString(
                        sentenceMetaData.get(OrigBIOTokenSequenceMetaDataFeatureGenerator.FEATURE_NAME));
                // get tokens for this sentence
                List<String> tokens = AbstractSequenceMetaDataFeatureGenerator.decodeFromString(
                        sentenceMetaData.get(OrigTokenSequenceMetaDataFeatureGenerator.FEATURE_NAME));

                // tokens and gold labels are paired by index; a mismatch would misalign the CSV
                if (tokens.size() != goldTokenLabels.size()) {
                    throw new IllegalStateException(String.format(
                            "check consistency: sentence %d (sequence ID %s) has %d tokens but %d gold token labels",
                            i, sequenceIDs.get(i), tokens.size(), goldTokenLabels.size()));
                }

                // predicted token labels, expanded from the sentence-level prediction
                List<String> recreatedPredictedTokenLabels = ReportTools
                        .recreateTokenLabels(predictedLabelSentenceLevel, goldTokenLabels.size());

                for (int j = 0; j < goldTokenLabels.size(); j++) {
                    String tokenGold = goldTokenLabels.get(j);
                    String tokenPredicted = recreatedPredictedTokenLabels.get(j);

                    // write to csv
                    csvPrinter.printRecord(tokenGold, tokenPredicted, tokens.get(j), sequenceIDs.get(i).toString());

                    // add to matrix
                    confusionMatrix.increaseValue(tokenGold, tokenPredicted);
                }
            }
        }

        // and write to the output (only after the CSV has been fully written and closed)
        writeResults(getContext(), confusionMatrix);
    }

    /**
     * Writes the token-level evaluation results for the given confusion matrix.
     * <p>
     * Delegates to {@code SVMHMMUtils.writeOutputResults} with the prefix {@code "tokenLevel_"}
     * (presumably used to name the output entries; confirm in SVMHMMUtils). Subclasses may
     * override this to write the results differently.
     *
     * @param context         task context whose storage receives the results
     * @param confusionMatrix token-level gold vs. predicted counts
     * @throws IOException if writing the results fails
     */
    protected void writeResults(TaskContext context, ConfusionMatrix confusionMatrix) throws IOException {
        SVMHMMUtils.writeOutputResults(context, confusionMatrix, "tokenLevel_");
    }
}