org.apache.ctakes.assertion.medfacts.cleartk.TrainAssertionModel.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.ctakes.assertion.medfacts.cleartk.TrainAssertionModel.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.ctakes.assertion.medfacts.cleartk;

import java.io.File;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.ctakes.assertion.eval.AssertionEvaluation;
import org.apache.ctakes.assertion.eval.AssertionEvaluation.ReferenceAnnotationsSystemAssertionClearer;
import org.apache.ctakes.assertion.eval.AssertionEvaluation.ReferenceIdentifiedAnnotationsSystemToGoldCopier;
import org.apache.ctakes.core.cc.XmiWriterCasConsumerCtakes;
import org.apache.log4j.Logger;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.fit.factory.AggregateBuilder;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.CollectionReaderFactory;
import org.apache.uima.fit.factory.ConfigurationParameterFactory;
import org.apache.uima.fit.pipeline.SimplePipeline;
import org.apache.uima.fit.testing.util.HideOutput;
import org.cleartk.ml.CleartkAnnotator;
import org.cleartk.ml.DataWriterFactory;
import org.cleartk.ml.jar.DefaultDataWriterFactory;
import org.cleartk.ml.jar.DirectoryDataWriterFactory;
import org.cleartk.ml.jar.GenericJarClassifierFactory;
import org.cleartk.ml.opennlp.maxent.MaxentStringOutcomeDataWriter;
import org.cleartk.util.cr.XReader;
//import org.junit.Test;
//import edu.mayo.bmi.uima.core.type.textsem.EntityMention;

public class TrainAssertionModel {

    /** Long command-line option naming the directory that holds the model. */
    public static final String PARAM_NAME_MODEL_DIRECTORY = "model-directory";

    /** Long command-line option naming the directory of training input files. */
    public static final String PARAM_NAME_TRAINING_INPUT_DIRECTORY = "training-input-directory";

    /** Long command-line option naming the directory of decoding input files. */
    public static final String PARAM_NAME_DECODING_INPUT_DIRECTORY = "decoding-input-directory";

    /** Long command-line option naming the directory that receives decoding output. */
    public static final String PARAM_NAME_DECODING_OUTPUT_DIRECTORY = "decoding-output-directory";

    /** Class-wide log4j logger, registered under this class's fully qualified name. */
    protected static final Logger logger = Logger.getLogger(TrainAssertionModel.class);

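    /*
     * Legacy command-line entry point, kept commented out for reference only;
     * the active entry point is main(String[]) further down in this class.
     */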
    /*
     public static void main(String[] args) {
        // TODO Auto-generated method stub
        String trainDir = args[0];
        String outputDir = args[1];
            
        try {
        CollectionReader reader = FilesCollectionReader.getCollectionReader(trainDir);
        AggregateBuilder builder = new AggregateBuilder();
        //builder.add(AnalysisEngineFactory.createEngineDescription("desc/AssertionMiniPipelineAnalysisEngine.xml", null));
        //builder.add(AnalysisEngineFactory.createEngineDescription(IdentifiedAnnotation.class));
        //builder.add(AnalysisEngineFactory.createEngineDescription("edu.mayo.bmi.uima.core.type.textsem.IdentifiedAnnotation"));
         builder.add(AssertionCleartkAnalysisEngine.getWriterDescription(outputDir));
         SimplePipeline.runPipeline(reader, builder.createEngineDescription());
         org.cleartk.classifier.jar.Train.main(outputDir);
        } catch (Exception e) {
     System.err.println("Exception: " + e);
     e.printStackTrace();
     throw new RuntimeException(e);
        }
            
             
     }
     */

    /**
     * Root directory under which {@link #testMaxent()} places its model output
     * (it appends {@code "/maxent"} to this value). Hard-coded absolute path;
     * not used by {@code main}, which takes its model directory from the command line.
     */
    protected String modelOutputDirectory = "/work/medfacts/cleartk/data/train.model";

    /**
     * Runs the full train-then-decode cycle with the maxent data writer against
     * the fixed {@code /work/medfacts/cleartk/data} locations.
     * <p>
     * The model is written below {@code modelOutputDirectory + "/maxent"}; all
     * work is delegated to
     * {@link #testClassifier(AnalysisEngineDescription, String, String, String, String, String...)}.
     *
     * @throws Exception if building the engine description or running the
     *         delegated train/decode cycle fails
     */
    //@Test
    public void testMaxent() throws Exception {
        final String maxentDir = modelOutputDirectory + "/maxent";

        // Description of the assertion annotator configured as a maxent data writer.
        final AnalysisEngineDescription maxentWriter = AnalysisEngineFactory.createEngineDescription(
                AssertionCleartkAnalysisEngine.class,
                AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
                DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
                MaxentStringOutcomeDataWriter.class.getName(),
                DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
                maxentDir);

        testClassifier(
                maxentWriter,
                maxentDir,
                "/work/medfacts/cleartk/data/train",
                "/work/medfacts/cleartk/data/eval2.input",
                "/work/medfacts/cleartk/data/eval2.output");
    }

    /**
     * Command-line entry point: trains a maxent assertion model and then
     * decodes with it.
     * <p>
     * Four long options are required, each taking a directory argument:
     * {@link #PARAM_NAME_MODEL_DIRECTORY}, {@link #PARAM_NAME_TRAINING_INPUT_DIRECTORY},
     * {@link #PARAM_NAME_DECODING_INPUT_DIRECTORY} and
     * {@link #PARAM_NAME_DECODING_OUTPUT_DIRECTORY}. If the arguments cannot be
     * parsed, or any value is empty, the problem is logged, usage help is
     * printed, and the method returns without doing any work.
     * <p>
     * The model itself is written to (and later read from) the
     * {@code "maxent"} sub-directory of the model directory.
     * <p>
     * NOTE(review): every failure path logs and returns normally, so the JVM
     * exits with status 0 even when training or decoding failed. Confirm
     * whether scripted callers need a non-zero exit status.
     *
     * @param args raw command-line arguments
     */
    public static void main(String[] args) {

        Options options = new Options();
        options.addOption(requiredDirectoryOption(TrainAssertionModel.PARAM_NAME_MODEL_DIRECTORY,
                "the directory where the model is written to for training, or read from for decoding"));
        options.addOption(requiredDirectoryOption(TrainAssertionModel.PARAM_NAME_TRAINING_INPUT_DIRECTORY,
                "directory where input training xmi files are located"));
        options.addOption(requiredDirectoryOption(TrainAssertionModel.PARAM_NAME_DECODING_INPUT_DIRECTORY,
                "directory where input xmi files are located for decoding"));
        options.addOption(requiredDirectoryOption(TrainAssertionModel.PARAM_NAME_DECODING_OUTPUT_DIRECTORY,
                "directory where output xmi files that are generated in decoding are placed"));

        CommandLineParser parser = new GnuParser();
        CommandLine commandLine;
        try {
            commandLine = parser.parse(options, args);
        } catch (ParseException e) {
            // Stop here: falling through would only add a second, misleading
            // "required parameters not supplied" error for the same failure.
            logger.error("unable to parse command-line arguments", e);
            printUsage(options);
            return;
        }

        String modelDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_MODEL_DIRECTORY);
        String trainingInputDirectory = commandLine
                .getOptionValue(TrainAssertionModel.PARAM_NAME_TRAINING_INPUT_DIRECTORY);
        String decodingInputDirectory = commandLine
                .getOptionValue(TrainAssertionModel.PARAM_NAME_DECODING_INPUT_DIRECTORY);
        String decodingOutputDirectory = commandLine
                .getOptionValue(TrainAssertionModel.PARAM_NAME_DECODING_OUTPUT_DIRECTORY);

        // The options are declared required, so this guards against values
        // that were supplied but are empty.
        if (isNullOrEmpty(modelDirectory) || isNullOrEmpty(trainingInputDirectory)
                || isNullOrEmpty(decodingInputDirectory) || isNullOrEmpty(decodingOutputDirectory)) {
            logger.error("required parameters not supplied");
            printUsage(options);
            return;
        }

        logger.info(String.format(
                "%n" + "model dir:           \"%s\"%n" + "training input dir:  \"%s\"%n"
                        + "decoding input dir:  \"%s\"%n" + "decoding output dir: \"%s\"%n",
                modelDirectory, trainingInputDirectory, decodingInputDirectory, decodingOutputDirectory));

        String maxentModelOutputDirectory = modelDirectory + "/maxent";
        try {
            AnalysisEngineDescription dataWriter = AnalysisEngineFactory.createEngineDescription(
                    AssertionCleartkAnalysisEngine.class, AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
                    DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
                    MaxentStringOutcomeDataWriter.class.getName(),
                    DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY, maxentModelOutputDirectory);
            testClassifier(dataWriter, maxentModelOutputDirectory, trainingInputDirectory, decodingInputDirectory,
                    decodingOutputDirectory);
        } catch (Exception e) {
            // Top-level boundary: report the failure instead of letting it escape main.
            logger.error("Some exception happened while training or decoding...", e);
        }
    }

    /** Builds a required long-only option that takes a single {@code DIR} argument. */
    private static Option requiredDirectoryOption(String longOpt, String description) {
        return OptionBuilder.withLongOpt(longOpt).withArgName("DIR").hasArg().isRequired()
                .withDescription(description).create();
    }

    /** Prints auto-generated usage help for this class's command line. */
    private static void printUsage(Options options) {
        HelpFormatter formatter = new HelpFormatter();
        formatter.printHelp(TrainAssertionModel.class.getName(), options, true);
    }

    /** Returns {@code true} when the given option value is {@code null} or the empty string. */
    private static boolean isNullOrEmpty(String value) {
        return value == null || value.isEmpty();
    }

    /**
     * Runs one complete cycle: feature generation over the training data,
     * model training, and decoding of a second data set with the trained model.
     * <ol>
     * <li>Reads XMI files from {@code trainingDataInputDirectory} and runs them
     * through the gold-copier, the assertion-clearer and an
     * {@link AssertionCleartkAnalysisEngine} whose output directory is
     * {@code modelOutputDirectory}.</li>
     * <li>Invokes {@code org.cleartk.ml.jar.Train.main} with
     * {@code modelOutputDirectory} followed by any {@code trainingArgs}.
     * Standard output and error are muted for the duration of that call and
     * are restored afterwards, even when training throws.</li>
     * <li>Reads XMI files from {@code decodingInputDirectory}, runs the same two
     * preparatory annotators plus an {@link AssertionCleartkAnalysisEngine}
     * pointed at {@code model.jar} inside {@code modelOutputDirectory}, and
     * hands the result to {@link XmiWriterCasConsumerCtakes} configured with
     * {@code decodingOutputDirectory}.</li>
     * </ol>
     * NOTE(review): {@code dataWriter} is accepted but never added to either
     * pipeline, and the training annotator built here is not given a data-writer
     * class name. Confirm whether that is intentional before relying on the
     * writer configured by callers.
     *
     * @param dataWriter engine description supplied by the caller; currently unused
     * @param modelOutputDirectory directory receiving training output and later
     *        read for {@code model.jar}
     * @param trainingDataInputDirectory root location of the training XMI files
     * @param decodingInputDirectory root location of the XMI files to decode
     * @param decodingOutputDirectory output directory passed to the XMI writer
     * @param trainingArgs optional extra arguments appended after the model
     *        directory when calling the trainer; may be {@code null} or empty
     * @throws Exception if reader or engine creation, either pipeline run, or
     *         training fails
     */
    public static void testClassifier(AnalysisEngineDescription dataWriter, String modelOutputDirectory,
            String trainingDataInputDirectory, String decodingInputDirectory, String decodingOutputDirectory,
            String... trainingArgs) throws Exception {

        // XMI readers for the training and the decoding inputs.
        CollectionReader trainingCollectionReader = CollectionReaderFactory.createReader(XReader.class,
                XReader.PARAM_ROOT_FILE, trainingDataInputDirectory, XReader.PARAM_XML_SCHEME, XReader.XMI);
        CollectionReader evaluationCollectionReader = CollectionReaderFactory.createReader(XReader.class,
                XReader.PARAM_ROOT_FILE, decodingInputDirectory, XReader.PARAM_XML_SCHEME, XReader.XMI);

        // Preparatory annotators shared by the training and decoding pipelines.
        // (Presumably they populate the gold view and reset system assertion
        // attributes; see AssertionEvaluation for the actual behavior.)
        AnalysisEngineDescription goldCopierAnnotator = AnalysisEngineFactory
                .createEngineDescription(ReferenceIdentifiedAnnotationsSystemToGoldCopier.class);
        AnalysisEngineDescription assertionAttributeClearerAnnotator = AnalysisEngineFactory
                .createEngineDescription(ReferenceAnnotationsSystemAssertionClearer.class);

        // ---- feature generation ----
        AnalysisEngineDescription trainingAssertionAnnotator = AnalysisEngineFactory.createEngineDescription(
                AssertionCleartkAnalysisEngine.class, AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION);
        ConfigurationParameterFactory.addConfigurationParameters(trainingAssertionAnnotator,
                AssertionCleartkAnalysisEngine.PARAM_GOLD_VIEW_NAME, AssertionEvaluation.GOLD_VIEW_NAME,
                DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY, modelOutputDirectory);

        AggregateBuilder trainingBuilder = new AggregateBuilder();
        trainingBuilder.add(goldCopierAnnotator);
        trainingBuilder.add(assertionAttributeClearerAnnotator);
        trainingBuilder.add(trainingAssertionAnnotator);

        logger.info("starting feature generation...");
        SimplePipeline.runPipeline(trainingCollectionReader, trainingBuilder.createAggregateDescription());
        logger.info("finished feature generation.");

        // ---- training ----
        // Trainer arguments: the model directory first, then any caller-supplied extras.
        int extraArgCount = (trainingArgs == null) ? 0 : trainingArgs.length;
        String[] args = new String[extraArgCount + 1];
        args[0] = modelOutputDirectory;
        if (extraArgCount > 0) {
            System.arraycopy(trainingArgs, 0, args, 1, extraArgCount);
        }

        logger.info("starting training...");
        HideOutput hider = new HideOutput();
        try {
            org.cleartk.ml.jar.Train.main(args);
        } finally {
            // Always restore the streams; otherwise a training failure would
            // leave System.out/System.err muted for the rest of the JVM.
            hider.restoreOutput();
        }
        logger.info("finished training.");

        // ---- decoding ----
        AnalysisEngineDescription decodingAssertionAnnotator = AnalysisEngineFactory.createEngineDescription(
                AssertionCleartkAnalysisEngine.class, AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION);
        ConfigurationParameterFactory.addConfigurationParameters(decodingAssertionAnnotator,
                AssertionCleartkAnalysisEngine.PARAM_GOLD_VIEW_NAME, AssertionEvaluation.GOLD_VIEW_NAME,
                GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
                new File(modelOutputDirectory, "model.jar").getPath());

        AggregateBuilder decodingBuilder = new AggregateBuilder();
        decodingBuilder.add(goldCopierAnnotator);
        decodingBuilder.add(assertionAttributeClearerAnnotator);
        decodingBuilder.add(decodingAssertionAnnotator);
        AnalysisEngineDescription decodingAggregateDescription = decodingBuilder.createAggregateDescription();

        AnalysisEngineDescription xmiWriter = AnalysisEngineFactory.createEngineDescription(
                XmiWriterCasConsumerCtakes.class, AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
                XmiWriterCasConsumerCtakes.PARAM_OUTPUTDIR, decodingOutputDirectory);

        logger.info("starting decoding...");
        SimplePipeline.runPipeline(evaluationCollectionReader, decodingAggregateDescription, xmiWriter);
        logger.info("finished decoding.");

    }

}