org.apache.ctakes.relationextractor.eval.SHARPXMI.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.ctakes.relationextractor.eval.SHARPXMI.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.ctakes.relationextractor.eval;

import com.google.common.base.Function;
import com.google.common.base.Functions;
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import com.lexicalscope.jewel.cli.Option;
import org.apache.ctakes.core.ae.SHARPKnowtatorXMLReader;
import org.apache.ctakes.core.util.DocumentIDAnnotationUtil;
import org.apache.ctakes.typesystem.type.structured.DocumentID;
import org.apache.uima.UIMAFramework;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CAS;
import org.apache.uima.cas.CASException;
import org.apache.uima.cas.impl.XmiCasSerializer;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.component.ViewCreatorAnnotator;
import org.apache.uima.fit.factory.AggregateBuilder;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.CollectionReaderFactory;
import org.apache.uima.fit.factory.TypeSystemDescriptionFactory;
import org.apache.uima.fit.pipeline.JCasIterator;
import org.apache.uima.jcas.JCas;
import org.apache.uima.util.XMLInputSource;
import org.apache.uima.util.XMLParser;
import org.apache.uima.util.XMLSerializer;
import org.cleartk.eval.AnnotationStatistics;
import org.cleartk.util.ViewUriUtil;
import org.cleartk.util.ae.UriToDocumentTextAnnotator;
import org.cleartk.util.cr.UriCollectionReader;
import org.xml.sax.ContentHandler;

import java.io.File;
import java.io.FileOutputStream;
import java.util.*;
import java.util.regex.Pattern;

public class SHARPXMI {

    public static List<File> getTrainTextFiles(File batchesDirectory) {
        // seed_set1: batches 2, 3, 4, 5, 6, 7, 8, 9, 13, 14, 15, 16, 18, 19
        // seed_set2: batches 1, 2, 3, 4, 5, 6, 7, 8, 9, 13, 14, 15, 16, 18, 19
        // seed_set3: batches 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 15, 16, 18, 19
        // seed_set4: batches 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 15, 16, 18, 19
        return getTextFilesFor(batchesDirectory, Pattern.compile("^(ss[1234]_batch0[2-9]|ss[1234]_batch1[56]"
                + "|ss[1234]_batch1[89]|ss[123]_batch01" + "|ss[12]_batch1[34]|ss[34]_batch1[12])$"));
    }

    public static List<File> getDevTextFiles(File batchesDirectory) {
        // seed_set1: batches 10, 17
        // seed_set2: batches 10, 17
        // seed_set3: batches 10, 17
        // seed_set4: batches 10, 17
        return getTextFilesFor(batchesDirectory, Pattern.compile("^(ss[1234]_batch1[07])$"));
    }

    public static List<File> getTestTextFiles(File batchesDirectory) {
        // seed_set1: batches 11, 12
        // seed_set2: batches 11, 12
        // seed_set3: batches 13, 14
        // seed_set4: batches 13, 14
        return getTextFilesFor(batchesDirectory, Pattern.compile("^(ss[12]_batch1[12]|ss[34]_batch1[34])$"));
    }

    public static List<File> getAllTextFiles(File batchesDirectory) {
        return getTextFilesFor(batchesDirectory, Pattern.compile(""));
    }

    private static List<File> getTextFilesFor(File batchesDirectory, Pattern pattern) {
        List<File> files = Lists.newArrayList();
        for (File batchDir : batchesDirectory.listFiles()) {
            if (batchDir.isDirectory() && !batchDir.isHidden()) {
                if (pattern.matcher(batchDir.getName()).find()) {
                    File textDirectory = new File(batchDir, "Knowtator/text");
                    for (File textFile : textDirectory.listFiles()) {
                        if (textFile.isFile() && !textFile.isHidden()) {
                            files.add(textFile);
                        }
                    }
                }
            }
        }
        return files;
    }

    public static List<File> toXMIFiles(Options options, List<File> textFiles) {
        List<File> xmiFiles = Lists.newArrayList();
        for (File textFile : textFiles) {
            xmiFiles.add(toXMIFile(options, textFile));
        }
        return xmiFiles;
    }

    private static File toXMIFile(Options options, File textFile) {
        return new File(options.getXMIDirectory(), textFile.getName() + ".xmi");
    }

    public static interface Options {
        @Option(longName = "batches-dir", description = "directory containing ssN_batchNN directories, each of which should contain "
                + "a Knowtator directory and a Knowtator_XML directory")
        public File getBatchesDirectory();

        @Option(longName = "xmi-dir", defaultValue = "target/xmi", description = "directory to store and load XMI serialization of annotations")
        public File getXMIDirectory();

        @Option(longName = "generate-xmi", description = "read in the gold annotations and serialize them as XMI")
        public boolean getGenerateXMI();
    }

    public static final String GOLD_VIEW_NAME = "GoldView";

    public static void generateXMI(Options options) throws Exception {
        // if necessary, write the XMIs first
        if (options.getGenerateXMI()) {
            if (!options.getXMIDirectory().exists()) {
                options.getXMIDirectory().mkdirs();
            }

            // create a collection reader that loads URIs for all Knowtator text files
            List<File> files = Lists.newArrayList();
            files.addAll(getTrainTextFiles(options.getBatchesDirectory()));
            files.addAll(getDevTextFiles(options.getBatchesDirectory()));
            files.addAll(getTestTextFiles(options.getBatchesDirectory()));
            CollectionReader reader = UriCollectionReader.getCollectionReaderFromFiles(files);

            // load the text from the URI, run the preprocessor, then run the
            // Knowtator XML reader
            AggregateBuilder builder = new AggregateBuilder();
            builder.add(UriToDocumentTextAnnotator.getDescription());
            File preprocessDescFile = new File("desc/analysis_engine/RelationExtractorPreprocessor.xml");
            XMLParser parser = UIMAFramework.getXMLParser();
            XMLInputSource source = new XMLInputSource(preprocessDescFile);
            builder.add(parser.parseAnalysisEngineDescription(source));
            builder.add(AnalysisEngineFactory.createEngineDescription(ViewCreatorAnnotator.class,
                    ViewCreatorAnnotator.PARAM_VIEW_NAME, GOLD_VIEW_NAME));
            builder.add(AnalysisEngineFactory.createEngineDescription(CopyDocumentTextToGoldView.class));
            builder.add(AnalysisEngineFactory.createEngineDescription(DocumentIDAnnotator.class),
                    CAS.NAME_DEFAULT_SOFA, GOLD_VIEW_NAME);
            builder.add(
                    AnalysisEngineFactory.createEngineDescription(SHARPKnowtatorXMLReader.class,
                            SHARPKnowtatorXMLReader.PARAM_SET_DEFAULTS, true),
                    CAS.NAME_DEFAULT_SOFA, GOLD_VIEW_NAME);

            // write out an XMI for each file
            for (Iterator<JCas> casIter = new JCasIterator(reader, builder.createAggregate()); casIter.hasNext();) {
                JCas jCas = casIter.next();
                JCas goldView = jCas.getView(GOLD_VIEW_NAME);
                String documentID = DocumentIDAnnotationUtil.getDocumentID(goldView);
                if (documentID == null) {//|| documentID.equals( DocumentIDAnnotationUtil.NO_DOCUMENT_ID ) ) {
                    throw new IllegalArgumentException("No documentID for CAS:\n" + jCas);
                }
                File outFile = toXMIFile(options, new File(documentID));
                FileOutputStream stream = new FileOutputStream(outFile);
                ContentHandler handler = new XMLSerializer(stream).getContentHandler();
                new XmiCasSerializer(jCas.getTypeSystem()).serialize(jCas.getCas(), handler);
                stream.close();
            }
        }
    }

    public enum EvaluateOn {
        TRAIN, DEV, TEST, OTHER
    }

    public static interface EvaluationOptions extends Options {
        @Option(longName = "evaluate-on", defaultValue = "DEV", description = "perform evaluation using the training (TRAIN), development (DEV) or test "
                + "(TEST) data.")
        public EvaluateOn getEvaluteOn();

        @Option(longName = "grid-search", description = "run a grid search to select the best parameters")
        public boolean getGridSearch();

        @Option(defaultToNull = true, longName = "train-xmi-dir", description = "use these XMI files for training; they must contain the necessary preprocessing "
                + "in system view and gold annotation in gold view")
        public File getTrainXmiDir();

        @Option(longName = "test-xmi-dir", defaultValue = "", description = "evaluate on these XMI files; they must contain the necessary preprocessing "
                + "in system view and gold annotation in gold view")
        public File getTestXmiDir();
    }

    public static abstract class Evaluation_ImplBase
            extends org.cleartk.eval.Evaluation_ImplBase<File, AnnotationStatistics<String>> {

        public Evaluation_ImplBase(File baseDirectory) {
            super(baseDirectory);
        }

        @Override
        public CollectionReader getCollectionReader(List<File> items) throws Exception {
            return CollectionReaderFactory.createReader(XMIReader.class,
                    TypeSystemDescriptionFactory.createTypeSystemDescription(), XMIReader.PARAM_FILES, items);
        }
    }

    public static void validate(EvaluationOptions options) throws Exception {
        // error on invalid option combinations
        if (options.getEvaluteOn().equals(EvaluateOn.TEST) && options.getGridSearch()) {
            throw new IllegalArgumentException("grid search can only be run on the train or dev sets");
        }
    }

    public static <T extends Evaluation_ImplBase> void evaluate(EvaluationOptions options,
            ParameterSettings bestSettings, List<ParameterSettings> gridOfSettings,
            Function<ParameterSettings, T> getEvaluation) throws Exception {
        // define the set of possible training parameters
        List<ParameterSettings> possibleParams;
        if (options.getGridSearch()) {
            possibleParams = gridOfSettings;
        } else {
            possibleParams = Lists.newArrayList(bestSettings);
        }

        // run an evaluation for each set of parameters
        Map<ParameterSettings, Double> scoredParams = new HashMap<>();
        for (ParameterSettings params : possibleParams) {
            Evaluation_ImplBase evaluation = getEvaluation.apply(params);

            List<File> trainFiles, devFiles, testFiles;
            switch (options.getEvaluteOn()) {
            case TRAIN:
                // run n-fold cross-validation on the training set
                trainFiles = getTrainTextFiles(options.getBatchesDirectory());
                trainFiles = toXMIFiles(options, trainFiles);
                List<AnnotationStatistics<String>> foldStats = evaluation.crossValidation(trainFiles, 2);
                params.stats = AnnotationStatistics.addAll(foldStats);
                break;
            case DEV:
                // train on the training set and evaluate on the dev set
                trainFiles = getTrainTextFiles(options.getBatchesDirectory());
                trainFiles = toXMIFiles(options, trainFiles);
                devFiles = getDevTextFiles(options.getBatchesDirectory());
                devFiles = toXMIFiles(options, devFiles);
                params.stats = evaluation.trainAndTest(trainFiles, devFiles);
                break;
            case TEST:
                // train on the training set + dev set and evaluate on the test set
                List<File> allTrainFiles = new ArrayList<>();
                allTrainFiles.addAll(getTrainTextFiles(options.getBatchesDirectory()));
                allTrainFiles.addAll(getDevTextFiles(options.getBatchesDirectory()));
                allTrainFiles = toXMIFiles(options, allTrainFiles);
                testFiles = getTestTextFiles(options.getBatchesDirectory());
                testFiles = toXMIFiles(options, testFiles);
                params.stats = evaluation.trainAndTest(allTrainFiles, testFiles);
                break;
            case OTHER:
                // train on train + dev + specified train xmis and test on specified test xmi files
                // these files should have the necessary preprocessing in the initial view
                // and gold standard relation annotations in the gold view
                // the path to the xmi files must be specified from command line 
                List<File> trainAndDevFiles = new ArrayList<>();
                trainAndDevFiles.addAll(getTrainTextFiles(options.getBatchesDirectory()));
                trainAndDevFiles.addAll(getDevTextFiles(options.getBatchesDirectory()));
                trainAndDevFiles = toXMIFiles(options, trainAndDevFiles);

                // if path to additional train xmis is specified, add them to the training set
                if (options.getTrainXmiDir() != null) {
                    for (File trainXmiFile : options.getTrainXmiDir().listFiles()) {
                        trainAndDevFiles.add(trainXmiFile);
                    }
                }

                // now read the xmis we will use for evaluation
                List<File> testXmiFiles = new ArrayList<>();
                for (File testXmiFile : options.getTestXmiDir().listFiles()) {
                    testXmiFiles.add(testXmiFile);
                }
                params.stats = evaluation.trainAndTest(trainAndDevFiles, testXmiFiles);
                break;
            default:
                throw new IllegalArgumentException("Invalid EvaluateOn: " + options.getEvaluteOn());
            }
            scoredParams.put(params, params.stats.f1());
        }

        // print parameters sorted by F1
        List<ParameterSettings> list = new ArrayList<>(scoredParams.keySet());
        Function<ParameterSettings, Double> getCount = Functions.forMap(scoredParams);
        Collections.sort(list, Ordering.natural().onResultOf(getCount));

        // print performance of each set of parameters
        if (list.size() > 1) {
            System.err.println("Summary");
            for (ParameterSettings params : list) {
                System.err.printf("F1=%.3f P=%.3f R=%.3f %s\n", params.stats.f1(), params.stats.precision(),
                        params.stats.recall(), params);
            }
            System.err.println();
        }

        // print overall best model
        if (!list.isEmpty()) {
            ParameterSettings lastParams = list.get(list.size() - 1);
            System.err.println("Best model:");
            System.err.print(lastParams.stats);
            System.err.println(lastParams);
            System.err.println(lastParams.stats.confusions());
            System.err.println();
        }
    }

    public static class DocumentIDAnnotator extends JCasAnnotator_ImplBase {

        @Override
        public void process(JCas jCas) throws AnalysisEngineProcessException {
            String documentID = new File(ViewUriUtil.getURI(jCas)).getPath();
            DocumentID documentIDAnnotation = new DocumentID(jCas);
            documentIDAnnotation.setDocumentID(documentID);
            documentIDAnnotation.addToIndexes();
        }
    }

    public static class CopyDocumentTextToGoldView extends JCasAnnotator_ImplBase {
        @Override
        public void process(JCas jCas) throws AnalysisEngineProcessException {
            try {
                JCas goldView = jCas.getView(GOLD_VIEW_NAME);
                goldView.setDocumentText(jCas.getDocumentText());
            } catch (CASException e) {
                throw new AnalysisEngineProcessException(e);
            }
        }
    }
}