edu.cmu.cs.lti.ark.fn.identification.training.AlphabetCreationThreaded.java Source code

Java tutorial

Introduction

Here is the source code for edu.cmu.cs.lti.ark.fn.identification.training.AlphabetCreationThreaded.java

Source

/*******************************************************************************
 * Copyright (c) 2011 Dipanjan Das 
 * Language Technologies Institute, 
 * Carnegie Mellon University, 
 * All Rights Reserved.
 *
 * AlphabetCreationThreaded.java is part of SEMAFOR 2.0.
 *
 * SEMAFOR 2.0 is free software: you can redistribute it and/or modify  it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or 
 * (at your option) any later version.
 *
 * SEMAFOR 2.0 is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details. 
 *
 * You should have received a copy of the GNU General Public License along
 * with SEMAFOR 2.0.  If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
package edu.cmu.cs.lti.ark.fn.identification.training;

import com.google.common.base.Charsets;
import com.google.common.collect.*;
import com.google.common.io.Files;
import edu.cmu.cs.lti.ark.fn.data.prep.formats.AllLemmaTags;
import edu.cmu.cs.lti.ark.fn.data.prep.formats.Sentence;
import edu.cmu.cs.lti.ark.fn.identification.IdFeatureExtractor;
import edu.cmu.cs.lti.ark.fn.identification.RequiredDataForFrameIdentification;
import edu.cmu.cs.lti.ark.fn.utils.FNModelOptions;
import edu.cmu.cs.lti.ark.util.FileUtil;
import edu.cmu.cs.lti.ark.util.SerializedObjects;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.concurrent.*;
import java.util.logging.FileHandler;
import java.util.logging.LogManager;
import java.util.logging.Logger;
import java.util.logging.SimpleFormatter;

import static edu.cmu.cs.lti.ark.util.IntRanges.xrange;
import static org.apache.commons.io.IOUtils.closeQuietly;

public class AlphabetCreationThreaded {
    private static final Logger logger = Logger.getLogger(AlphabetCreationThreaded.class.getCanonicalName());
    private static final int DEFAULT_MINIMUM_FEATURE_COUNT = 2;
    public static final String ALPHABET_FILENAME = "alphabet.dat";

    private final Set<String> allFrames;
    private final String parseFile;
    private final String frameElementsFile;
    private final int startIndex;
    private final int endIndex;
    private final int numThreads;
    private final IdFeatureExtractor featureExtractor;

    /**
     * Parses commandline args, then creates a new {@link #AlphabetCreationThreaded} with them
     * and calls {@link #createAlphabet}
     *
     * @param args commandline arguments. see {@link #AlphabetCreationThreaded}
     *             for details.
     */
    public static void main(String[] args)
            throws IOException, ClassNotFoundException, ExecutionException, InterruptedException {
        final FNModelOptions options = new FNModelOptions(args);

        LogManager.getLogManager().reset();
        final FileHandler fileHandler = new FileHandler(options.logOutputFile.get(), true);
        fileHandler.setFormatter(new SimpleFormatter());
        logger.addHandler(fileHandler);

        final int startIndex = options.startIndex.get();
        final int endIndex = options.endIndex.get();
        logger.info("Start:" + startIndex + " end:" + endIndex);
        final RequiredDataForFrameIdentification r = SerializedObjects.readObject(options.fnIdReqDataFile.get());

        final int minimumCount = options.minimumCount.present() ? options.minimumCount.get()
                : DEFAULT_MINIMUM_FEATURE_COUNT;
        final int numThreads = options.numThreads.present() ? options.numThreads.get()
                : Runtime.getRuntime().availableProcessors();
        final File alphabetDir = new File(options.modelFile.get());
        final String featureExtractorType = options.idFeatureExtractorType.present()
                ? options.idFeatureExtractorType.get()
                : "basic";
        final IdFeatureExtractor featureExtractor = IdFeatureExtractor.fromName(featureExtractorType);
        final AlphabetCreationThreaded events = new AlphabetCreationThreaded(options.trainFrameElementFile.get(),
                options.trainParseFile.get(), r.getFrameMap().keySet(), featureExtractor, startIndex, endIndex,
                numThreads);
        final Multiset<String> unconjoinedFeatures = events.createAlphabet();
        final File alphabetFile = new File(alphabetDir, ALPHABET_FILENAME);
        events.conjoinAndWriteAlphabet(unconjoinedFeatures, minimumCount, alphabetFile);
    }

    /**
     * Creates a new AlphabetCreationThreaded with the given arguments
     *
     * @param frameElementsFile   path to file containing gold standard frame element
     *                            annotations
     * @param parseFile           path to file containing dependency parsed sentences (the same
    *                            ones that are in frameElementsFile
     * @param allFrames           set of all frame names
     * @param featureExtractor    feature extractor
     * @param startIndex          the line of the frameElementsFile to start at
     * @param endIndex            the line of frameElementsFile to end at
     * @param numThreads          the number of threads to run
     */
    public AlphabetCreationThreaded(String frameElementsFile, String parseFile, Set<String> allFrames,
            IdFeatureExtractor featureExtractor, int startIndex, int endIndex, int numThreads) {
        this.frameElementsFile = frameElementsFile;
        this.parseFile = parseFile;
        this.allFrames = allFrames;
        this.featureExtractor = featureExtractor;
        this.startIndex = startIndex;
        this.endIndex = endIndex;
        this.numThreads = numThreads;
    }

    /**
     * Splits frameElementLines into numThreads equally-sized batches and creates an alphabet
     * file for each one.
     *
     * @throws IOException
     */
    public Multiset<String> createAlphabet() throws IOException, ExecutionException, InterruptedException {
        final List<String> frameLines = Files.readLines(new File(frameElementsFile), Charsets.UTF_8)
                .subList(startIndex, endIndex);
        final int batchSize = (int) Math.ceil(frameLines.size() / (double) numThreads);
        final List<List<String>> frameLinesPartition = Lists.partition(frameLines, batchSize);
        final List<String> parseLines = Files.readLines(new File(parseFile), Charsets.UTF_8);
        final Multiset<String> alphabet = ConcurrentHashMultiset.create();
        final List<Callable<Integer>> jobs = Lists.newArrayListWithExpectedSize(numThreads);
        for (final int i : xrange(numThreads)) {
            jobs.add(newJob(i, frameLinesPartition.get(i), parseLines, alphabet));
        }
        final ExecutorService threadPool = Executors.newFixedThreadPool(numThreads);
        final List<Future<Integer>> results = threadPool.invokeAll(jobs);
        threadPool.shutdown();
        try {
            for (Integer i : xrange(results.size())) {
                logger.info(String.format("Thread %d successfully processed %d lines", i, results.get(i).get()));
            }
        } finally {
            threadPool.shutdownNow();
        }
        return alphabet;
    }

    private Callable<Integer> newJob(final int threadId, final List<String> frameLineBatch,
            final List<String> parseLines, final Multiset<String> alphabet) {
        return new Callable<Integer>() {
            public Integer call() {
                logger.info("Thread " + threadId + " : start");
                for (int i = 0; i < frameLineBatch.size() && !Thread.currentThread().isInterrupted(); i++) {
                    processLine(frameLineBatch.get(i), parseLines, alphabet);
                    if (i % 50 == 0) {
                        logger.info("Thread " + i + "\n" + "Processed index:" + i + " of " + frameLineBatch.size()
                                + "\n" + "Alphabet size:" + alphabet.elementSet().size());
                    }
                }
                logger.info("Thread " + threadId + " : end");
                return frameLineBatch.size();
            }
        };
    }

    private void processLine(String frameLine, List<String> parseLines, Multiset<String> alphabet) {
        // Parse the frameLine
        final String[] toks = frameLine.split("\t");
        // throw out first two fields
        final List<String> tokens = Arrays.asList(toks).subList(2, toks.length);
        //final String frameName = tokens.get(1);
        final String[] targetIdxsStr = tokens.get(3).split("_");
        final int sentNum = Integer.parseInt(tokens.get(5));

        final int[] targetTokenIdxs = new int[targetIdxsStr.length];
        for (int j = 0; j < targetIdxsStr.length; j++)
            targetTokenIdxs[j] = Integer.parseInt(targetIdxsStr[j]);
        Arrays.sort(targetTokenIdxs);

        // Parse the parse line
        final String parseLine = parseLines.get(sentNum);
        final Sentence sentence = Sentence.fromAllLemmaTagsArray(AllLemmaTags.readLine(parseLine));

        // extract base features (not conjoined with frame names) for every frame
        alphabet.addAll(featureExtractor.getBaseFeatures(targetTokenIdxs, sentence).keySet());
    }

    public static BiMap<String, Integer> readAlphabetFile(File file) throws IOException {
        final BufferedReader bReader = Files.newReader(file, Charsets.UTF_8);
        try {
            final BiMap<String, Integer> alphabet = HashBiMap.create();
            String line;
            int i = 0;
            while ((line = bReader.readLine()) != null) {
                final String[] fields = line.trim().split("\t");
                alphabet.put(fields[0], i);
                i++;
            }
            return alphabet;
        } finally {
            closeQuietly(bReader);
        }
    }

    /** Gets the number of features in the model stored in alphabetFile */
    public static int getAlphabetSize(String alphabetFile) throws IOException {
        return FileUtil.countLines(alphabetFile);
    }

    private void conjoinAndWriteAlphabet(final Multiset<String> unconjoinedFeatures, final int minimumCount,
            File alphabetFile) throws IOException {
        final BufferedWriter output = Files.newWriter(alphabetFile, Charsets.UTF_8);
        final int unconjoinedSize = unconjoinedFeatures.elementSet().size();
        try {
            logger.info("Writing alphabet.");
            int numUnconjoined = 0;
            int numConjoined = 0;
            for (String unconjoinedFeature : unconjoinedFeatures.elementSet()) {
                if (unconjoinedFeatures.count(unconjoinedFeature) >= minimumCount) {
                    final Set<String> conjoinedFeatureNames = featureExtractor.getConjoinedFeatureNames(allFrames,
                            unconjoinedFeature);
                    numConjoined += conjoinedFeatureNames.size();
                    for (String feature : conjoinedFeatureNames) {
                        output.write(String.format("%s\n", feature));
                    }
                }
                numUnconjoined++;
                if (numUnconjoined % 50 == 0) {
                    logger.info("Unconjoined: " + numUnconjoined + " of " + unconjoinedSize);
                    logger.info("Conjoined: " + numConjoined);
                }
            }
            logger.info("Done writing alphabet.");
        } finally {
            closeQuietly(output);
        }
    }
}