com.ibm.watson.developer_cloud.professor_languo.pipeline.QuestionSetManager.java Source code

Java tutorial

Introduction

Here is the source code for com.ibm.watson.developer_cloud.professor_languo.pipeline.QuestionSetManager.java

Source

/*
 * Copyright IBM Corp. 2015
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 */

package com.ibm.watson.developer_cloud.professor_languo.pipeline;

import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Properties;
import java.util.Random;

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.csv.CSVRecord;
import org.apache.commons.io.FilenameUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import com.ibm.watson.developer_cloud.professor_languo.configuration.ConfigurationConstants;
import com.ibm.watson.developer_cloud.professor_languo.configuration.Messages;
import com.ibm.watson.developer_cloud.professor_languo.data_model.QuestionAnswerSet;
import com.ibm.watson.developer_cloud.professor_languo.data_model.QuestionAnswerSet.CorrectAnswer;
import com.ibm.watson.developer_cloud.professor_languo.exception.IngestionException;
import com.ibm.watson.developer_cloud.professor_languo.exception.PipelineException;
import com.ibm.watson.developer_cloud.professor_languo.ingestion.indexing.StackExchangeThreadSerializer;
import com.ibm.watson.developer_cloud.professor_languo.model.stack_exchange.CorpusBuilder;
import com.ibm.watson.developer_cloud.professor_languo.model.stack_exchange.StackExchangeConstants;
import com.ibm.watson.developer_cloud.professor_languo.model.stack_exchange.StackExchangeQuestion;
import com.ibm.watson.developer_cloud.professor_languo.model.stack_exchange.StackExchangeThread;

/**
 * A QuestionSetManager is responsible for parsing the tab-separated values (TSV) file containing
 * Stack Exchange threads that have been marked as duplicates, generated by {@link CorpusBuilder},
 * to produce training, test, and validation question sets.
 * <p>
 * Each record of the input TSV is assigned to exactly one of the three sets by drawing a
 * pseudo-random number from a seeded generator and comparing it against a cumulative distribution
 * built from the caller-supplied partition fractions. As a side effect, each record is also
 * written to a training-, test-, or validation-specific TSV file in the same directory as the
 * input file.
 */
public class QuestionSetManager {

    private static final Logger logger = LogManager.getLogger(QuestionSetManager.class.getName());

    /** Seed used when none is supplied, so that partitions are reproducible by default. */
    private static final long DEFAULT_SEED = 525600L;

    private final QuestionAnswerSet trainingSet, testSet, validationSet;

    /** Source of the random draws that decide which partition each record goes to. */
    private final Random rng;

    /**
     * Create a new {@link QuestionSetManager} that will create training, test, and validation sets
     * from a TSV file containing duplicate {@link StackExchangeThread} questions, i.e., questions
     * that have been marked as duplicates of other {@link StackExchangeThread} questions.
     * <p>
     * NOTE: This constructor uses a static seed for the pseudo-random number generator, and thus will
     * always generate the same partitioning for fixed input arguments. To generate variable
     * partitions, provide a seed using the constructor
     * {@link QuestionSetManager#QuestionSetManager(String, long, double[])}
     * 
     * @param duplicateQuestionTsvFilePath - The absolute file path of the duplicate question TSV file
     *        to be parsed
     * @param trainTestValidateFractions - A three-element array representing the relative proportions
     *        of the training, test, and validation subsets. For example, <code>[0.7, 0.2, 0.1]</code>
     *        would indicate that (in expectation) 70% of the duplicate questions should belong to the
     *        training set, 20% should belong to the test set, and 10% should belong to the validation
     *        set. The values are normalized, so they need not sum to 1
     * @throws PipelineException if the TSV file is missing or unreadable, or a record cannot be
     *         parsed or added to its set
     * @throws IllegalArgumentException if the path is <code>null</code> or the fractions are invalid
     */
    public QuestionSetManager(String duplicateQuestionTsvFilePath, double[] trainTestValidateFractions)
            throws PipelineException {
        this(duplicateQuestionTsvFilePath, DEFAULT_SEED, trainTestValidateFractions);
    }

    /**
     * Create a new {@link QuestionSetManager} that will create training, test, and validation sets
     * from a TSV file containing duplicate {@link StackExchangeThread} questions, i.e., questions
     * that have been marked as duplicates of other {@link StackExchangeThread} questions
     * 
     * @param duplicateQuestionTsvFilePath - The absolute file path of the duplicate question TSV file
     *        to be parsed
     * @param seed - A seed for the random number generator used to provide a pseudo-random
     *        partitioning of the data
     * @param trainTestValidateFractions - A three-element array representing the relative proportions
     *        of the training, test, and validation subsets. For example, <code>[0.7, 0.2, 0.1]</code>
     *        would indicate that (in expectation) 70% of the duplicate questions should belong to the
     *        training set, 20% should belong to the test set, and 10% should belong to the validation
     *        set. The values are normalized, so they need not sum to 1
     * @throws PipelineException if the TSV file is missing or unreadable, or a record cannot be
     *         parsed or added to its set
     * @throws IllegalArgumentException if the path is <code>null</code>, or the fractions are not three
     *         non-negative values with a positive, finite sum
     */
    public QuestionSetManager(String duplicateQuestionTsvFilePath, long seed, double[] trainTestValidateFractions)
            throws PipelineException {
        // Fail fast on a missing path, before any state is derived from it
        if (duplicateQuestionTsvFilePath == null)
            throw new IllegalArgumentException(Messages.getString("RetrieveAndRank.DUPLICATE_TSV_PATH")); //$NON-NLS-1$

        // Initialize member variables
        this.rng = new Random(seed);
        this.trainingSet = new QuestionAnswerSet(duplicateQuestionTsvFilePath);
        this.testSet = new QuestionAnswerSet(duplicateQuestionTsvFilePath);
        this.validationSet = new QuestionAnswerSet(duplicateQuestionTsvFilePath);

        // Validate and normalize the train/test/validate fractions into a CDF
        double[] cumulativeProbabilities = checkAndNormalizeFractions(trainTestValidateFractions);

        // Parse the duplicate question TSV file, bucketing each entry into one of the three sets
        File dupQuestionFile = new File(duplicateQuestionTsvFilePath);
        try {
            if (!dupQuestionFile.exists() || !dupQuestionFile.canRead())
                throw new IOException(
                        MessageFormat.format(Messages.getString("RetrieveAndRank.DUPLICATE_TSV_MISSING"), //$NON-NLS-1$
                                duplicateQuestionTsvFilePath));
            parseTsvAndPartitionRecords(dupQuestionFile, cumulativeProbabilities);
        } catch (IOException e) {
            logger.fatal(e.getMessage(), e);
            throw new PipelineException(e);
        } catch (PipelineException e) {
            // Already the declared exception type: log it and rethrow rather than wrapping it again
            logger.fatal(e.getMessage(), e);
            throw e;
        }
    }

    /**
     * Create a new {@link QuestionSetManager} that will create training, test, and validation sets
     * from a TSV file containing duplicate {@link StackExchangeThread} questions, i.e., questions
     * that have been marked as duplicates of other {@link StackExchangeThread} questions
     * 
     * @param properties - A {@link Properties} object with the following properties: <br>
     *        - {@link ConfigurationConstants#DUPLICATE_THREAD_TSV_PATH} [required] : The path to the
     *        duplicate thread TSV file<br>
     *        - {@link ConfigurationConstants#QUESTION_SET_MANAGER_PARTITION_FRACTIONS} [required] : A
     *        3-element comma-separated array representing the relative proportions of the training,
     *        test, and validation subsets. For example, <code>[0.7, 0.2, 0.1]</code> would indicate
     *        that (in expectation) 70% of the duplicate questions should belong to the training set,
     *        20% should belong to the test set, and 10% should belong to the validation set. The
     *        surrounding square brackets are optional<br>
     *        - {@link ConfigurationConstants#QUESTION_SET_MANAGER_RAND_NUM_SEED} [optional] : A seed
     *        for the random number generator used to provide a pseudo-random partitioning of the data
     *        (defaults to the same seed used by the two-argument constructor)
     * 
     * @return A {@link QuestionSetManager} whose sets have been populated from the TSV file
     * @throws PipelineException if a required property is missing, or construction fails
     * @throws NumberFormatException if the seed or a fraction is not a valid number
     */
    public static QuestionSetManager newInstance(Properties properties) throws PipelineException {
        String duplicateQuestionTsvFilePath = properties
                .getProperty(ConfigurationConstants.DUPLICATE_THREAD_TSV_PATH);
        if (duplicateQuestionTsvFilePath == null)
            throw new PipelineException(MessageFormat.format(Messages.getString("RetrieveAndRank.MISSING_PROPERTY"), //$NON-NLS-1$
                    ConfigurationConstants.DUPLICATE_THREAD_TSV_PATH));

        String seedProperty = properties.getProperty(ConfigurationConstants.QUESTION_SET_MANAGER_RAND_NUM_SEED);
        long seed = (seedProperty == null) ? DEFAULT_SEED : Long.parseLong(seedProperty.trim());

        String fractions = properties.getProperty(ConfigurationConstants.QUESTION_SET_MANAGER_PARTITION_FRACTIONS);
        if (fractions == null)
            throw new PipelineException(MessageFormat.format(Messages.getString("RetrieveAndRank.MISSING_PROPERTY"), //$NON-NLS-1$
                    ConfigurationConstants.QUESTION_SET_MANAGER_PARTITION_FRACTIONS));

        // Strip optional brackets, then parse each comma-separated value
        String[] fractionArray = fractions.replaceAll("(\\[|\\])", "").split(",");
        double[] trainTestValidateFractions = new double[fractionArray.length];
        for (int i = 0; i < fractionArray.length; i++)
            trainTestValidateFractions[i] = Double.parseDouble(fractionArray[i].trim());

        return new QuestionSetManager(duplicateQuestionTsvFilePath, seed, trainTestValidateFractions);
    }

    /**
     * Ensures that the input array holds three non-negative values with a positive, finite sum, and
     * returns a three-element cumulative distribution function (CDF) built from them.
     * 
     * @param trainTestValidateFractions - The relative fraction of questions that should be allocated
     *        to train, test, and validation sets
     * 
     * @return A three-element CDF representing the cumulative probability that a given duplicate
     *         thread belongs to the training, test, or validation set respectively. The last entry
     *         is exactly <code>1.0</code>
     * @throws IllegalArgumentException if the array is <code>null</code>, does not have exactly three
     *         elements, contains a negative or NaN value, or sums to zero or infinity
     */
    private static double[] checkAndNormalizeFractions(double[] trainTestValidateFractions) {
        if (trainTestValidateFractions == null || trainTestValidateFractions.length != 3)
            throw new IllegalArgumentException(Messages.getString("RetrieveAndRank.TEST_FRACTIONS_THREE")); //$NON-NLS-1$

        double sum = 0;
        for (double f : trainTestValidateFractions) {
            // NaN compares false against everything, so it has to be rejected explicitly
            if (Double.isNaN(f) || f < 0) {
                throw new IllegalArgumentException(
                        Messages.getString("RetrieveAndRank.TEST_FRACTIONS_NONNEGATIVE")); //$NON-NLS-1$
            }
            sum += f;
        }

        // A zero or infinite total cannot be normalized: dividing by it yields a NaN CDF, which
        // would silently route every record into the validation set
        if (!(sum > 0) || Double.isInfinite(sum))
            throw new IllegalArgumentException(
                    "Train/test/validation fractions must have a positive, finite sum: " //$NON-NLS-1$
                            + Arrays.toString(trainTestValidateFractions));

        double[] cumulativeProbabilities = new double[trainTestValidateFractions.length];
        double runningTotal = 0;
        for (int i = 0; i < trainTestValidateFractions.length; i++) {
            runningTotal += trainTestValidateFractions[i] / sum;
            cumulativeProbabilities[i] = runningTotal;
        }
        // Pin the final entry to exactly 1.0 so floating-point rounding cannot leave a gap below 1
        cumulativeProbabilities[cumulativeProbabilities.length - 1] = 1.0;

        return cumulativeProbabilities;
    }

    /**
     * This function is responsible for parsing a duplicate Stack Exchange thread TSV file produced by
     * {@link StackExchangeThreadSerializer}, and partitioning each such thread into the training set,
     * test set, or validation set. In addition, the corresponding row of the TSV file will be written
     * out to a training-, test-, or validation-set-specific TSV file in the same directory as the
     * input TSV file.
     * 
     * @param dupQuestionFile - A TSV file containing duplicate {@link StackExchangeThread} records
     * @param trainTestValidateCumulativeProbs - A CDF of the desired proportion of training, test,
     *        and validation set records, whose last entry is <code>1.0</code>
     * @throws PipelineException if the file cannot be read or written, a serialized thread cannot be
     *         loaded, or a question cannot be added to its set
     */
    private void parseTsvAndPartitionRecords(File dupQuestionFile, double[] trainTestValidateCumulativeProbs)
            throws PipelineException {
        // Open the TSV file for parsing, and CSVPrinters for outputting the train, test, and
        // validation set TSV files. The parser is a resource too, so it is closed on every path.
        String baseName = FilenameUtils.removeExtension(dupQuestionFile.getAbsolutePath());
        String extension = FilenameUtils.getExtension(dupQuestionFile.getAbsolutePath());
        try (FileReader reader = new FileReader(dupQuestionFile);
                CSVParser parser = CSVFormat.TDF.withHeader().parse(reader);
                CSVPrinter trainSetPrinter = new CSVPrinter(
                        new FileWriter(baseName + StackExchangeConstants.DUP_THREAD_TSV_TRAIN_FILE_SUFFIX
                                + FilenameUtils.EXTENSION_SEPARATOR + extension),
                        CSVFormat.TDF.withHeader(CorpusBuilder.getTsvColumnHeaders()));
                CSVPrinter testSetPrinter = new CSVPrinter(
                        new FileWriter(baseName + StackExchangeConstants.DUP_THREAD_TSV_TEST_FILE_SUFFIX
                                + FilenameUtils.EXTENSION_SEPARATOR + extension),
                        CSVFormat.TDF.withHeader(CorpusBuilder.getTsvColumnHeaders()));
                CSVPrinter validationSetPrinter = new CSVPrinter(
                        new FileWriter(baseName + StackExchangeConstants.DUP_THREAD_TSV_VALIDATE_FILE_SUFFIX
                                + FilenameUtils.EXTENSION_SEPARATOR + extension),
                        CSVFormat.TDF.withHeader(CorpusBuilder.getTsvColumnHeaders()))) {

            // Iterate over each CSV record, and place it into a partition (train, test, or validation)
            Iterator<CSVRecord> recordIterator = parser.iterator();
            while (recordIterator.hasNext()) {
                CSVRecord record = recordIterator.next();

                // Get the StackExchangeThread associated with this record, and create a question from it
                StackExchangeThread duplicateThread = StackExchangeThreadSerializer.deserializeThreadFromBinFile(
                        record.get(CorpusBuilder.TSV_COL_HEADER_SERIALIZED_FILE_PATH));
                StackExchangeQuestion duplicateQuestion = new StackExchangeQuestion(duplicateThread);
                String parentId = record.get(CorpusBuilder.TSV_COL_HEADER_PARENT_ID);

                // nextDouble() is uniform on [0, 1). Strict '<' comparisons make each bucket's
                // probability exactly its normalized fraction, so a zero-weight bucket receives nothing.
                double p = rng.nextDouble();
                if (p < trainTestValidateCumulativeProbs[0]) {
                    // This record goes in the training set
                    if (!addQuestionToSet(duplicateQuestion, parentId, this.trainingSet)) {
                        throw new PipelineException(
                                MessageFormat.format(Messages.getString("RetrieveAndRank.TRAINING_SET_FAILED_Q"), //$NON-NLS-1$
                                        duplicateThread.getId()));
                    }
                    trainSetPrinter.printRecord((Object[]) convertRecordToArray(record));
                } else if (p < trainTestValidateCumulativeProbs[1]) {
                    // This record goes in the test set
                    if (!addQuestionToSet(duplicateQuestion, parentId, this.testSet)) {
                        throw new PipelineException(
                                MessageFormat.format(Messages.getString("RetrieveAndRank.TEST_SET_FAILED_Q"), //$NON-NLS-1$
                                        duplicateThread.getId()));
                    }
                    testSetPrinter.printRecord((Object[]) convertRecordToArray(record));
                } else {
                    // This record goes in the validation set (the CDF's last entry is exactly 1.0)
                    assert (p < trainTestValidateCumulativeProbs[2]);
                    if (!addQuestionToSet(duplicateQuestion, parentId, this.validationSet)) {
                        throw new PipelineException(
                                MessageFormat.format(Messages.getString("RetrieveAndRank.VALIDATION_SET_FAILED_Q"), //$NON-NLS-1$
                                        duplicateThread.getId()));
                    }
                    validationSetPrinter.printRecord((Object[]) convertRecordToArray(record));
                }
            }

            // Flush all the printers prior to closing
            trainSetPrinter.flush();
            testSetPrinter.flush();
            validationSetPrinter.flush();
        } catch (IOException | IngestionException e) {
            throw new PipelineException(e);
        }
    }

    /**
     * Copies the leading columns of a record into an array, one entry per TSV column header.
     *
     * @param record - A single {@link CSVRecord} from the duplicate thread TSV file
     * @return A string array representing the data in each column of the record
     */
    private static String[] convertRecordToArray(CSVRecord record) {
        int columnCount = CorpusBuilder.getTsvColumnHeaders().length;
        String[] recordArray = new String[columnCount];
        for (int i = 0; i < columnCount; i++)
            recordArray[i] = record.get(i);
        return recordArray;
    }

    /**
     * Adds a new {@link StackExchangeQuestion} to a chosen {@link QuestionAnswerSet}
     * 
     * @param duplicateQuestion - A {@link StackExchangeQuestion} for the duplicate question
     * @param parentId - The ID of the thread that this question duplicates
     * @param set - The training, test, or validation {@link QuestionAnswerSet} to which this question
     *        is being added
     * 
     * @return <code>true</code> if this question was added to the set, <code>false</code> otherwise
     */
    private static boolean addQuestionToSet(StackExchangeQuestion duplicateQuestion, String parentId,
            QuestionAnswerSet set) {
        // The parent ID (i.e., the ID of the thread that this thread duplicates) is the answer text
        // used in comparisons to determine whether a candidate answer is correct
        CorrectAnswer c = new CorrectAnswer(parentId);

        // Now add this question, and its answer, to the QA set
        return set.addQuestionAnswers(duplicateQuestion, Arrays.asList(c), null);
    }

    /**
     * @return The {@link QuestionAnswerSet} training data set
     */
    public QuestionAnswerSet getTrainingSet() {
        return this.trainingSet;
    }

    /**
     * @return The {@link QuestionAnswerSet} test data set
     */
    public QuestionAnswerSet getTestSet() {
        return this.testSet;
    }

    /**
     * @return The {@link QuestionAnswerSet} validation data set
     */
    public QuestionAnswerSet getValidationSet() {
        return this.validationSet;
    }

}