edu.cmu.cs.lti.ark.fn.identification.latentmodel.TrainBatchModelDerThreaded.java Source code

Introduction

Here is the source code for edu.cmu.cs.lti.ark.fn.identification.latentmodel.TrainBatchModelDerThreaded.java, part of SEMAFOR 2.0. The class trains the latent-variable frame identification model in batch mode: it loads one serialized feature file per target, computes the conditional log-likelihood and its gradient in parallel on a thread pool, and passes the merged result to an L-BFGS optimizer, saving the parameters every few iterations and again at the end.
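
A minimal usage sketch (not part of SEMAFOR) is shown below. The file paths, the lambda value, and the wrapper class name are placeholders; the sketch simply mirrors the constructor and trainModel() calls that appear in the listing, the same calls that main builds from FNModelOptions command-line options.

import com.google.common.base.Optional;

// Hypothetical driver class; assumes it lives in (or imports from) the
// edu.cmu.cs.lti.ark.fn.identification.latentmodel package.
public class TrainingSketch {
    public static void main(String[] args) throws Exception {
        // Placeholder inputs: the alphabet file's first line holds the feature count,
        // and the events directory contains one feats_<n>.jobj.gz file per target.
        String alphabetFile = "training/alphabet.dat";
        String eventsDir = "training/events";
        String modelFile = "training/model.dat";    // parameters are written here
        double lambda = 1.0;                        // placeholder regularization strength
        int numThreads = Runtime.getRuntime().availableProcessors();

        TrainBatchModelDerThreaded trainer = new TrainBatchModelDerThreaded(
                alphabetFile, eventsDir, modelFile,
                "reg",                              // any other string disables L2 regularization
                lambda,
                Optional.<String>absent(),          // or Optional.of(path) to restart from a saved model
                numThreads);
        trainer.trainModel();
    }
}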

Source

/*******************************************************************************
 * Copyright (c) 2011 Dipanjan Das 
 * Language Technologies Institute, 
 * Carnegie Mellon University, 
 * All Rights Reserved.
 *
 * TrainBatchModelDerThreaded.java is part of SEMAFOR 2.0.
 *
 * SEMAFOR 2.0 is free software: you can redistribute it and/or modify  it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or 
 * (at your option) any later version.
 *
 * SEMAFOR 2.0 is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details. 
 *
 * You should have received a copy of the GNU General Public License along
 * with SEMAFOR 2.0.  If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
package edu.cmu.cs.lti.ark.fn.identification.latentmodel;

import com.google.common.base.Charsets;
import com.google.common.base.Optional;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.Lists;
import com.google.common.io.Files;
import com.google.common.primitives.Ints;
import edu.cmu.cs.lti.ark.fn.optimization.Lbfgs;
import edu.cmu.cs.lti.ark.fn.utils.FNModelOptions;
import edu.cmu.cs.lti.ark.fn.utils.ThreadPool;
import edu.cmu.cs.lti.ark.util.SerializedObjects;

import java.io.*;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.logging.FileHandler;
import java.util.logging.LogManager;
import java.util.logging.Logger;
import java.util.logging.SimpleFormatter;

import static org.apache.commons.io.IOUtils.closeQuietly;

public class TrainBatchModelDerThreaded {
    private static final Logger logger = Logger.getLogger(TrainBatchModelDerThreaded.class.getCanonicalName());

    private static final int BATCH_SIZE = 10;
    private static final boolean CACHE_FEATURES = false;
    private static final String FEATURE_FILENAME_PREFIX = "feats_";
    private static final String FEATURE_FILENAME_SUFFIX = ".jobj.gz";
    private static final FilenameFilter featureFilenameFilter = new FilenameFilter() {
        public boolean accept(File dir, String name) {
            return name.startsWith(FEATURE_FILENAME_PREFIX) && name.endsWith(FEATURE_FILENAME_SUFFIX);
        }
    };
    private static final Comparator<String> featureFilenameComparator = new Comparator<String>() {
        private final int PREFIX_LEN = FEATURE_FILENAME_PREFIX.length();
        private final int SUFFIX_LEN = FEATURE_FILENAME_SUFFIX.length();

        public int compare(String a, String b) {
            final int aNum = Integer.parseInt(a.substring(PREFIX_LEN, a.length() - SUFFIX_LEN));
            final int bNum = Integer.parseInt(b.substring(PREFIX_LEN, b.length() - SUFFIX_LEN));
            return Ints.compare(aNum, bNum);
        }
    };
    private final double[] params;
    private final List<String> eventFiles;
    private final String modelFile;
    private final int modelSize;
    private final boolean useL2Regularization;
    private final int numThreads;
    private final double lambda;
    private double[] gradients;
    private double[][] tGradients;
    private double[] tValues;
    // only used if CACHE_FEATURES is true
    private final LoadingCache<Integer, int[][][]> featuresByTargetIdx = CacheBuilder.newBuilder()
            .initialCapacity(16000) // > number of targets in training data
            .build(new CacheLoader<Integer, int[][][]>() {
                @Override
                public int[][][] load(Integer key) throws IOException, ClassNotFoundException {
                    return loadFeaturesForTarget(key);
                }
            });

    public static void main(String[] args) throws Exception {
        final FNModelOptions options = new FNModelOptions(args);
        LogManager.getLogManager().reset();
        final FileHandler fh = new FileHandler(options.logOutputFile.get(), true);
        fh.setFormatter(new SimpleFormatter());
        logger.addHandler(fh);

        final String restartFile = options.restartFile.get();
        final int numThreads = options.numThreads.present() ? options.numThreads.get()
                : Runtime.getRuntime().availableProcessors();
        final TrainBatchModelDerThreaded tbm = new TrainBatchModelDerThreaded(options.alphabetFile.get(),
                options.eventsFile.get(), options.modelFile.get(), options.reg.get(), options.lambda.get(),
                restartFile.equals("null") ? Optional.<String>absent() : Optional.of(restartFile), numThreads);
        tbm.trainModel();
    }

    public TrainBatchModelDerThreaded(String alphabetFile, String eventsDir, String modelFile, String reg,
            double lambda, Optional<String> restartFile, int numThreads) throws IOException {
        this.modelSize = getNumFeatures(alphabetFile);
        this.modelFile = modelFile;
        this.eventFiles = getEventFiles(new File(eventsDir));
        this.useL2Regularization = reg.equals("reg");
        this.lambda = lambda / (double) eventFiles.size();
        this.numThreads = numThreads;
        this.params = restartFile.isPresent() ? loadModel(restartFile.get()) : getInitialParams();
    }

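    /**
     * Runs batch training: repeatedly computes the objective value and gradient over
     * all feature files, passes them to L-BFGS, and saves the parameters every
     * SAVE_EVERY_K iterations and again after the loop ends.
     */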
    public void trainModel() throws Exception {
        // minimum verbosity
        int[] iprint = new int[] { Lbfgs.DEBUG ? 1 : -1, 0 };
        // lbfgs sets this flag to zero when it has converged
        int[] iflag = new int[] { 0 };
        tGradients = new double[numThreads][modelSize]; // gradients for each batch/thread
        gradients = new double[modelSize]; // gradients of all threads added together
        tValues = new double[numThreads];
        int iteration = 0;
        do {
            logger.info("Starting iteration:" + iteration);
            Arrays.fill(gradients, 0.0);
            double m_value = getValuesAndGradients();
            logger.info("Function value:" + m_value);
            riso.numerical.LBFGS.lbfgs(modelSize, Lbfgs.NUM_CORRECTIONS, params, m_value, gradients, false,
                    new double[modelSize], iprint, Lbfgs.STOPPING_THRESHOLD, Lbfgs.XTOL, iflag);
            logger.info("Finished iteration:" + iteration);
            iteration++;
            if (iteration % Lbfgs.SAVE_EVERY_K == 0) {
                saveModel(params, modelFile + "_" + iteration);
            }
        } while (iteration <= Lbfgs.MAX_ITERATIONS && iflag[0] != 0);
        saveModel(params, modelFile);
    }

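    /**
     * Splits the feature files into batches of BATCH_SIZE targets, computes each
     * batch's value and gradient on a thread pool, then sums the per-thread results.
     */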
    private double getValuesAndGradients() {
        double value = 0.0;
        for (double[] tGradient : tGradients)
            Arrays.fill(tGradient, 0.0);
        Arrays.fill(tValues, 0.0);

        final ThreadPool threadPool = new ThreadPool(numThreads);
        int task = 0;
        for (int i = 0; i < eventFiles.size(); i += BATCH_SIZE) {
            final int taskId = task;
            final int start = i;
            final int end = Math.min(i + BATCH_SIZE, eventFiles.size());
            threadPool.runTask(new Runnable() {
                public void run() {
                    logger.info("Task " + taskId + " : start");
                    try {
                        processBatch(taskId, start, end);
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                    logger.info("Task " + taskId + " : end");
                }
            });
            task++;
        }
        threadPool.join();
        // merge the calculated tValues and gradients together
        Arrays.fill(gradients, 0.0);
        for (int i = 0; i < numThreads; i++) {
            value += tValues[i];
            for (int j = 0; j < modelSize; j++) {
                gradients[j] += tGradients[i][j];
            }
        }
        System.out.println("Finished value and gradient computation.");
        return value;
    }

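    /**
     * Accumulates, for the targets in [start, end), the negative conditional
     * log-likelihood of the gold frame (index 0, marginalized over latent variables)
     * and its gradient into this task's per-thread slots; when regularization is on,
     * the L2 penalty is added once per batch.
     */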
    public void processBatch(int taskId, int start, int end)
            throws ExecutionException, IOException, ClassNotFoundException {
        int threadId = taskId % numThreads;
        logger.info("Processing batch:" + taskId + " thread ID:" + threadId);
        for (int targetIdx = start; targetIdx < end; targetIdx++) {
            int[][][] featureArray = getFeaturesForTarget(targetIdx);
            int featArrLen = featureArray.length;
            double exp[][] = new double[featArrLen][];
            double sumExp[] = new double[featArrLen];
            double totalExp = 0.0;
            for (int i = 0; i < featArrLen; i++) {
                exp[i] = new double[featureArray[i].length];
                sumExp[i] = 0.0; // the partition function
                for (int j = 0; j < exp[i].length; j++) {
                    double weiFeatSum = 0.0;
                    int[] feats = featureArray[i][j];
                    for (int feat : feats) {
                        weiFeatSum += params[feat];
                    }
                    exp[i][j] = Math.exp(weiFeatSum);
                    sumExp[i] += exp[i][j];
                }
                totalExp += sumExp[i];
            }
            tValues[threadId] -= Math.log(sumExp[0]) - Math.log(totalExp);
            for (int i = 0; i < featArrLen; i++) {
                for (int j = 0; j < exp[i].length; j++) {
                    double Y = 0.0;
                    if (i == 0) {
                        // the correct frame is always first
                        Y = exp[i][j] / sumExp[i];
                    }
                    double YMinusP = Y - (exp[i][j] / totalExp);
                    int[] feats = featureArray[i][j];
                    for (int feat : feats) {
                        tGradients[threadId][feat] -= YMinusP;
                    }
                }
            }
            if (targetIdx % 100 == 0) {
                System.out.print(".");
                logger.info("" + targetIdx);
            }
        }
        if (useL2Regularization) {
            for (int i = 0; i < params.length; ++i) {
                final double weight = params[i]; // penalize the weight itself, matching the gradient term below
                tValues[threadId] += lambda * (weight * weight);
                tGradients[threadId][i] += 2 * lambda * weight;
            }
        }
    }

    private int[][][] getFeaturesForTarget(int targetIdx)
            throws ExecutionException, IOException, ClassNotFoundException {
        if (CACHE_FEATURES) {
            return featuresByTargetIdx.get(targetIdx);
        } else {
            return loadFeaturesForTarget(targetIdx);
        }
    }

    private int[][][] loadFeaturesForTarget(Integer key) throws IOException, ClassNotFoundException {
        return SerializedObjects.readObject(eventFiles.get(key));
    }

    private double[] getInitialParams() {
        final double[] params = new double[modelSize];
        Arrays.fill(params, 1.0);
        return params;
    }

    /**
     * Reads parameters from modelFile, one param per line.
     *
     * @param modelFile the filename to read from
     */
    private double[] loadModel(String modelFile) throws IOException {
        final List<String> lines = Files.readLines(new File(modelFile), Charsets.UTF_8);
        final double[] params = new double[lines.size()];
        int i = 0;
        for (String line : lines) {
            params[i] = Double.parseDouble(line.trim());
            i++;
        }
        return params;
    }

    /**
     * Writes the current parameters to modelFile, one param per line.
     *
     * @param modelFile the filename to write to
     */
    private static void saveModel(double[] params, String modelFile) throws IOException {
        final BufferedWriter bWriter = new BufferedWriter(new FileWriter(modelFile));
        try {
            for (double param : params)
                bWriter.write(param + "\n");
        } finally {
            closeQuietly(bWriter);
        }
    }

    /** Reads the number of features in the model from the first line of alphabetFile */
    private static int getNumFeatures(String alphabetFile) throws IOException {
        final String firstLine = Files.readFirstLine(new File(alphabetFile), Charsets.UTF_8);
        final int numFeatures = Integer.parseInt(firstLine.trim()) + 1;
        logger.info("Number of features: " + numFeatures);
        return numFeatures;
    }

    /** Gets the list of all feature files */
    private static List<String> getEventFiles(File eventsDir) {
        final String[] files = eventsDir.list(featureFilenameFilter);
        Arrays.sort(files, featureFilenameComparator);
        final List<String> eventFiles = Lists.newArrayListWithExpectedSize(files.length);
        for (String file : files)
            eventFiles.add(eventsDir.getAbsolutePath() + "/" + file);
        logger.info("Total number of datapoints:" + eventFiles.size());
        return eventFiles;
    }
}
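
The trainer reads one feature file per target from the events directory, named feats_<n>.jobj.gz, each of which deserializes to an int[][][]: the outer index ranges over candidate frames (the gold frame first), the middle index over latent variables, and the inner arrays hold feature indices into the alphabet. The snippet below is only a sketch of how such a file could be written; it assumes SerializedObjects.readObject accepts gzip-compressed standard Java serialization, which the listing itself does not show.

import java.io.FileOutputStream;
import java.io.ObjectOutputStream;
import java.util.zip.GZIPOutputStream;

public class FeatureFileSketch {
    public static void main(String[] args) throws Exception {
        // Toy example: two candidate frames; the gold frame (index 0) has two latent
        // variables, the competitor has one. Inner arrays are alphabet feature indices.
        int[][][] features = {
                { {3, 17, 42}, {3, 99} },
                { {5, 17} }
        };
        // Assumption: plain Java serialization wrapped in gzip matches what the
        // trainer's SerializedObjects.readObject expects for .jobj.gz files.
        ObjectOutputStream out = new ObjectOutputStream(
                new GZIPOutputStream(new FileOutputStream("events/feats_0.jobj.gz")));
        try {
            out.writeObject(features);
        } finally {
            out.close();
        }
    }
}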