de.tudarmstadt.ukp.dkpro.argumentation.sequence.report.TokenLevelBatchCrossValidationReport.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.dkpro.argumentation.sequence.report.TokenLevelBatchCrossValidationReport.java

Source

/*
 * Copyright 2015
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.dkpro.argumentation.sequence.report;

import de.tudarmstadt.ukp.dkpro.lab.storage.StorageService;
import de.tudarmstadt.ukp.dkpro.lab.task.impl.ExecutableTaskBase;
import de.tudarmstadt.ukp.dkpro.tc.core.Constants;
import de.tudarmstadt.ukp.dkpro.tc.svmhmm.report.SVMHMMBatchCrossValidationReport;
import de.tudarmstadt.ukp.dkpro.tc.svmhmm.task.SVMHMMTestTask;
import de.tudarmstadt.ukp.dkpro.tc.svmhmm.util.ConfusionMatrix;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.apache.log4j.Logger;

import java.io.File;
import java.io.FileReader;
import java.io.IOException;

/**
 * Batch cross-validation report that evaluates SVM-HMM results at the token level.
 * <p>
 * Rather than running the superclass's default reporting, it aggregates the per-fold
 * {@link TokenLevelEvaluationReport#TOKEN_LEVEL_PREDICTIONS_CSV} files, builds a confusion
 * matrix from the aggregated gold/predicted label pairs and writes F-measures computed from
 * that matrix to {@link #RESULT_SUMMARY}.
 *
 * @author Ivan Habernal
 */
public class TokenLevelBatchCrossValidationReport // this is the default result collector
        extends SVMHMMBatchCrossValidationReport {
    static final Logger log = Logger.getLogger(TokenLevelBatchCrossValidationReport.class);

    /** Name of the summary file written into the test task output storage location. */
    public static final String RESULT_SUMMARY = "resultSummary.txt";

    /**
     * Aggregates the token-level prediction CSVs (files prefixed with {@code tokenLevel_})
     * and then writes the F-measure summary.
     * <p>
     * {@code super.execute()} is intentionally not called: only token-level results are
     * reported by this class.
     */
    @Override
    public void execute() throws Exception {
        aggregateResults(TokenLevelEvaluationReport.TOKEN_LEVEL_PREDICTIONS_CSV, "tokenLevel_");
        reportOnlyMacroFM();
    }

    /**
     * Reads the aggregated token-level predictions from the test task output location,
     * builds a confusion matrix from them and writes F-measures to {@link #RESULT_SUMMARY}
     * using {@code ReportTools.printFMeasuresToFile}.
     *
     * @throws IOException if the predictions file cannot be read, contains a record with
     *                     fewer than two columns, or the summary cannot be written
     */
    protected void reportOnlyMacroFM() throws IOException {
        File aggregatedCSVFile = new File(
                getContext().getStorageLocation(Constants.TEST_TASK_OUTPUT_KEY, StorageService.AccessMode.READONLY),
                TokenLevelEvaluationReport.TOKEN_LEVEL_PREDICTIONS_CSV);

        ConfusionMatrix cm = readConfusionMatrix(aggregatedCSVFile);

        File evaluationFile = new File(getContext().getStorageLocation(Constants.TEST_TASK_OUTPUT_KEY,
                StorageService.AccessMode.READWRITE), RESULT_SUMMARY);

        ReportTools.printFMeasuresToFile(cm, evaluationFile);
    }

    /**
     * Builds a confusion matrix from a CSV file in which each record holds the gold label
     * in column 0 and the predicted label in column 1; lines starting with {@code #} are
     * treated as comments.
     *
     * @param csvFile the aggregated predictions file
     * @return a confusion matrix with one increment per (gold, predicted) record
     * @throws IOException if the file cannot be read or a record has fewer than two columns
     */
    private static ConfusionMatrix readConfusionMatrix(File csvFile) throws IOException {
        ConfusionMatrix cm = new ConfusionMatrix();

        // Both resources are declared so the reader is closed even if the parser
        // constructor or the iteration throws.
        try (FileReader reader = new FileReader(csvFile);
             CSVParser csvParser = new CSVParser(reader, CSVFormat.DEFAULT.withCommentMarker('#'))) {
            for (CSVRecord csvRecord : csvParser) {
                if (csvRecord.size() < 2) {
                    throw new IOException("Expected gold and predicted label in record "
                            + csvRecord.getRecordNumber() + " of " + csvFile);
                }
                String gold = csvRecord.get(0);
                String predicted = csvRecord.get(1);

                cm.increaseValue(gold, predicted);
            }
        }

        return cm;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Uses the standard {@link SVMHMMTestTask}; {@code SVMHMMRandomTestTask} is an
     * alternative test task that can be returned here instead.
     */
    @Override
    protected Class<? extends ExecutableTaskBase> getTestTaskClass() {
        return SVMHMMTestTask.class;
    }
}