Java tutorial
/* * Copyright IBM Corp. 2015 * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ package com.ibm.watson.developer_cloud.professor_languo.pipeline; import java.io.File; import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; import java.text.MessageFormat; import java.util.Arrays; import java.util.Iterator; import java.util.Properties; import java.util.Random; import org.apache.commons.csv.CSVFormat; import org.apache.commons.csv.CSVParser; import org.apache.commons.csv.CSVPrinter; import org.apache.commons.csv.CSVRecord; import org.apache.commons.io.FilenameUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import com.ibm.watson.developer_cloud.professor_languo.configuration.ConfigurationConstants; import com.ibm.watson.developer_cloud.professor_languo.configuration.Messages; import com.ibm.watson.developer_cloud.professor_languo.data_model.QuestionAnswerSet; import com.ibm.watson.developer_cloud.professor_languo.data_model.QuestionAnswerSet.CorrectAnswer; import com.ibm.watson.developer_cloud.professor_languo.exception.IngestionException; import com.ibm.watson.developer_cloud.professor_languo.exception.PipelineException; import com.ibm.watson.developer_cloud.professor_languo.ingestion.indexing.StackExchangeThreadSerializer; import com.ibm.watson.developer_cloud.professor_languo.model.stack_exchange.CorpusBuilder; import com.ibm.watson.developer_cloud.professor_languo.model.stack_exchange.StackExchangeConstants; import 
com.ibm.watson.developer_cloud.professor_languo.model.stack_exchange.StackExchangeQuestion;
import com.ibm.watson.developer_cloud.professor_languo.model.stack_exchange.StackExchangeThread;

/**
 * A QuestionSetManager is responsible for parsing the tab-separated values (TSV) file containing
 * Stack Exchange threads that have been marked as duplicates, generated by {@link CorpusBuilder},
 * to produce training, test, and validation question sets.
 * <p>
 * As a side effect of construction, every parsed TSV row is also written to a training-, test-, or
 * validation-specific TSV file located next to the input TSV file.
 */
public class QuestionSetManager {
  private final static Logger logger = LogManager.getLogger(QuestionSetManager.class.getName());

  /** Seed used when the caller supplies none, so that the default partitioning is reproducible. */
  private static final long DEFAULT_SEED = 525600L;

  private final QuestionAnswerSet trainingSet, testSet, validationSet;

  /** Pseudo-random source that decides which partition each TSV record is assigned to. */
  private final Random rng;

  /**
   * Create a new {@link QuestionSetManager} that will create training, test, and validation sets
   * from a TSV file containing duplicate {@link StackExchangeThread} questions, i.e., questions
   * that have been marked as duplicates of other {@link StackExchangeThread} questions.
   * <p>
   * NOTE: This constructor uses a static seed for the pseudo-random number generator, and thus will
   * always generate the same partitioning for fixed input arguments. To generate variable
   * partitions, provide a seed using the constructor
   * {@link QuestionSetManager#QuestionSetManager(String, long, double[])}
   *
   * @param duplicateQuestionTsvFilePath - The path of the duplicate question TSV file to be parsed
   * @param trainTestValidateFractions - A three-element array representing the relative proportions
   *        of the training, test, and validation subsets. For example, <code>[0.7, 0.2, 0.1]</code>
   *        would indicate that (in expectation) 70% of the duplicate questions should belong to the
   *        training set, 20% should belong to the test set, and 10% should belong to the validation
   *        set
   * @throws PipelineException if the TSV file is missing/unreadable or a record cannot be
   *         processed
   * @throws IllegalArgumentException if the path is null or the fractions are invalid
   */
  public QuestionSetManager(String duplicateQuestionTsvFilePath,
      double[] trainTestValidateFractions) throws PipelineException {
    this(duplicateQuestionTsvFilePath, DEFAULT_SEED, trainTestValidateFractions);
  }

  /**
   * Create a new {@link QuestionSetManager} that will create training, test, and validation sets
   * from a TSV file containing duplicate {@link StackExchangeThread} questions, i.e., questions
   * that have been marked as duplicates of other {@link StackExchangeThread} questions
   *
   * @param duplicateQuestionTsvFilePath - The path of the duplicate question TSV file to be parsed
   * @param seed - A seed for the random number generator used to provide a pseudo-random
   *        partitioning of the data
   * @param trainTestValidateFractions - A three-element array representing the relative proportions
   *        of the training, test, and validation subsets. For example, <code>[0.7, 0.2, 0.1]</code>
   *        would indicate that (in expectation) 70% of the duplicate questions should belong to the
   *        training set, 20% should belong to the test set, and 10% should belong to the validation
   *        set
   * @throws PipelineException if the TSV file is missing/unreadable or a record cannot be
   *         processed
   * @throws IllegalArgumentException if the path is null or the fractions are invalid
   */
  public QuestionSetManager(String duplicateQuestionTsvFilePath, long seed,
      double[] trainTestValidateFractions) throws PipelineException {
    // Begin by initializing member variables
    this.rng = new Random(seed);
    this.trainingSet = new QuestionAnswerSet(duplicateQuestionTsvFilePath);
    this.testSet = new QuestionAnswerSet(duplicateQuestionTsvFilePath);
    this.validationSet = new QuestionAnswerSet(duplicateQuestionTsvFilePath);

    // Next, validate the train/test/validate fractions and turn them into a CDF
    double[] cumulativeProbabilities = checkAndNormalizeFractions(trainTestValidateFractions);

    // Finally, parse the duplicate question TSV file, bucketing each entry into one of the three
    // sets
    try {
      if (duplicateQuestionTsvFilePath == null) {
        throw new IllegalArgumentException(Messages.getString("RetrieveAndRank.DUPLICATE_TSV_PATH")); //$NON-NLS-1$
      }
      File dupQuestionFile = new File(duplicateQuestionTsvFilePath);
      if (!dupQuestionFile.exists() || !dupQuestionFile.canRead()) {
        throw new IOException(
            MessageFormat.format(Messages.getString("RetrieveAndRank.DUPLICATE_TSV_MISSING"), //$NON-NLS-1$
                duplicateQuestionTsvFilePath));
      }
      parseTsvAndPartitionRecords(dupQuestionFile, cumulativeProbabilities);
    } catch (IOException | PipelineException e) {
      // Log the throwable itself so the stack trace is not lost
      logger.fatal(e.getMessage(), e);
      throw new PipelineException(e);
    }
  }

  /**
   * Create a new {@link QuestionSetManager} that will create training, test, and validation sets
   * from a TSV file containing duplicate {@link StackExchangeThread} questions, i.e., questions
   * that have been marked as duplicates of other {@link StackExchangeThread} questions
   *
   * @param properties - A {@link Properties} object with the following properties: <br>
   *        - {@link ConfigurationConstants#DUPLICATE_THREAD_TSV_PATH} [required] : The path to the
   *        duplicate thread TSV file<br>
   *        - {@link ConfigurationConstants#QUESTION_SET_MANAGER_PARTITION_FRACTIONS} [required] : A
   *        3-element comma-separated array representing the relative proportions of the training,
   *        test, and validation subsets. For example, <code>[0.7, 0.2, 0.1]</code> would indicate
   *        that (in expectation) 70% of the duplicate questions should belong to the training set,
   *        20% should belong to the test set, and 10% should belong to the validation set<br>
   *        - {@link ConfigurationConstants#QUESTION_SET_MANAGER_RAND_NUM_SEED} [optional] : A seed
   *        for the random number generator used to provide a pseudo-random partitioning of the data
   *
   * @return A {@link QuestionSetManager} whose sets have already been populated from the TSV file
   * @throws PipelineException if a required property is missing, or if construction fails
   * @throws NumberFormatException if the seed or one of the fractions cannot be parsed
   */
  public static QuestionSetManager newInstance(Properties properties) throws PipelineException {
    String duplicateQuestionTsvFilePath =
        properties.getProperty(ConfigurationConstants.DUPLICATE_THREAD_TSV_PATH);
    if (duplicateQuestionTsvFilePath == null) {
      throw new PipelineException(
          MessageFormat.format(Messages.getString("RetrieveAndRank.MISSING_PROPERTY"), //$NON-NLS-1$
              ConfigurationConstants.DUPLICATE_THREAD_TSV_PATH));
    }

    long seed = Long.parseLong(properties.getProperty(
        ConfigurationConstants.QUESTION_SET_MANAGER_RAND_NUM_SEED, Long.toString(DEFAULT_SEED)));

    String fractions =
        properties.getProperty(ConfigurationConstants.QUESTION_SET_MANAGER_PARTITION_FRACTIONS);
    if (fractions == null) {
      throw new PipelineException(
          MessageFormat.format(Messages.getString("RetrieveAndRank.MISSING_PROPERTY"), //$NON-NLS-1$
              ConfigurationConstants.QUESTION_SET_MANAGER_PARTITION_FRACTIONS));
    }

    // Accept both "[0.7, 0.2, 0.1]" and "0.7, 0.2, 0.1" by stripping any square brackets; the
    // element count is validated later by checkAndNormalizeFractions
    String[] fractionArray = fractions.replaceAll("(\\[|\\])", "").split(",");
    double[] trainTestValidateFractions = new double[fractionArray.length];
    for (int i = 0; i < fractionArray.length; i++) {
      trainTestValidateFractions[i] = Double.parseDouble(fractionArray[i].trim());
    }

    return new QuestionSetManager(duplicateQuestionTsvFilePath, seed, trainTestValidateFractions);
  }

  /**
   * Ensures that the input array is a three-element array of non-negative values with a positive,
   * finite total, and converts it into a three-element cumulative distribution function (CDF).
   *
   * @param trainTestValidateFractions - The relative fraction of questions that should be allocated
   *        to train, test, and validation sets
   *
   * @return A three-element CDF representing the cumulative probability that a given duplicate
   *         thread belongs to the training, test, or validation set respectively. The last entry
   *         is exactly 1.0.
   * @throws IllegalArgumentException if the array is null, does not have three elements, contains
   *         a negative or NaN value, or has a total that is zero or infinite
   */
  private double[] checkAndNormalizeFractions(double[] trainTestValidateFractions) {
    if (trainTestValidateFractions == null || trainTestValidateFractions.length != 3) {
      throw new IllegalArgumentException(Messages.getString("RetrieveAndRank.TEST_FRACTIONS_THREE")); //$NON-NLS-1$
    }

    double sum = 0;
    for (double f : trainTestValidateFractions) {
      // "!(f >= 0)" rejects NaN as well as negative values ("f < 0" is false for NaN)
      if (!(f >= 0)) {
        throw new IllegalArgumentException(
            Messages.getString("RetrieveAndRank.TEST_FRACTIONS_NONNEGATIVE")); //$NON-NLS-1$
      }
      sum += f;
    }

    // A zero or infinite total cannot be normalized (the division would yield NaN), which would
    // otherwise silently send every record to the validation set
    if (sum <= 0 || Double.isInfinite(sum)) {
      throw new IllegalArgumentException(
          "Train/test/validate fractions must be finite and must not all be zero: " //$NON-NLS-1$
              + Arrays.toString(trainTestValidateFractions));
    }

    // Divide the un-normalized running totals by the grand total. The final running total is
    // accumulated in the same order as "sum", so it equals it bit-for-bit and the last CDF entry is
    // exactly 1.0 (trailing zero-weight partitions likewise reach exactly 1.0).
    double[] cumulativeProbabilities = new double[trainTestValidateFractions.length];
    double runningTotal = 0;
    for (int i = 0; i < trainTestValidateFractions.length; i++) {
      runningTotal += trainTestValidateFractions[i];
      cumulativeProbabilities[i] = runningTotal / sum;
    }
    return cumulativeProbabilities;
  }

  /**
   * This function is responsible for parsing a duplicate Stack Exchange thread TSV file produced by
   * {@link StackExchangeThreadSerializer}, and partitioning each such thread into the training set,
   * test set, or validation set. In addition, the corresponding row of the TSV file will be written
   * out to a training-, test-, or validation-set-specific TSV file in the same directory as the
   * input TSV file.
   *
   * @param dupQuestionFile - A TSV file containing duplicate {@link StackExchangeThread} records
   * @param trainTestValidateCumulativeProbs - A CDF of the desired proportion of training, test,
   *        and validation set records
   * @throws PipelineException if reading/writing fails, a thread cannot be deserialized, or a
   *         question cannot be added to its set
   */
  private void parseTsvAndPartitionRecords(File dupQuestionFile,
      double[] trainTestValidateCumulativeProbs) throws PipelineException {
    // Output files share the input's base name and extension, with a partition-specific suffix
    String baseName = FilenameUtils.removeExtension(dupQuestionFile.getAbsolutePath());
    String extension = FilenameUtils.getExtension(dupQuestionFile.getAbsolutePath());

    try (FileReader reader = new FileReader(dupQuestionFile);
        CSVPrinter trainSetPrinter = openPartitionPrinter(baseName
            + StackExchangeConstants.DUP_THREAD_TSV_TRAIN_FILE_SUFFIX
            + FilenameUtils.EXTENSION_SEPARATOR + extension);
        CSVPrinter testSetPrinter = openPartitionPrinter(baseName
            + StackExchangeConstants.DUP_THREAD_TSV_TEST_FILE_SUFFIX
            + FilenameUtils.EXTENSION_SEPARATOR + extension);
        CSVPrinter validationSetPrinter = openPartitionPrinter(baseName
            + StackExchangeConstants.DUP_THREAD_TSV_VALIDATE_FILE_SUFFIX
            + FilenameUtils.EXTENSION_SEPARATOR + extension);
        // The parser is a resource too, so it is closed along with the reader it wraps
        CSVParser parser = CSVFormat.TDF.withHeader().parse(reader)) {

      Iterator<CSVRecord> recordIterator = parser.iterator();
      while (recordIterator.hasNext()) {
        CSVRecord record = recordIterator.next();

        // Deserialize the thread whose binary file path is stored in this record, and create a
        // question from it
        StackExchangeThread duplicateThread = StackExchangeThreadSerializer.deserializeThreadFromBinFile(
            record.get(CorpusBuilder.TSV_COL_HEADER_SERIALIZED_FILE_PATH));
        StackExchangeQuestion duplicateQuestion = new StackExchangeQuestion(duplicateThread);
        String parentId = record.get(CorpusBuilder.TSV_COL_HEADER_PARENT_ID);

        // p is uniform on [0, 1), so P(p < cdf[i]) == cdf[i]. Strict "<" guarantees that a
        // partition with zero weight never receives a record (not even when p == 0.0).
        double p = rng.nextDouble();
        QuestionAnswerSet targetSet;
        CSVPrinter targetPrinter;
        String failureMessageKey;
        if (p < trainTestValidateCumulativeProbs[0]) {
          targetSet = this.trainingSet;
          targetPrinter = trainSetPrinter;
          failureMessageKey = "RetrieveAndRank.TRAINING_SET_FAILED_Q"; //$NON-NLS-1$
        } else if (p < trainTestValidateCumulativeProbs[1]) {
          targetSet = this.testSet;
          targetPrinter = testSetPrinter;
          failureMessageKey = "RetrieveAndRank.TEST_SET_FAILED_Q"; //$NON-NLS-1$
        } else {
          // The last CDF entry is exactly 1.0 and p < 1.0, so everything else is validation
          targetSet = this.validationSet;
          targetPrinter = validationSetPrinter;
          failureMessageKey = "RetrieveAndRank.VALIDATION_SET_FAILED_Q"; //$NON-NLS-1$
        }

        if (!addQuestionToSet(duplicateQuestion, parentId, targetSet)) {
          throw new PipelineException(MessageFormat.format(Messages.getString(failureMessageKey),
              duplicateThread.getId()));
        }
        targetPrinter.printRecord((Object[]) convertRecordToArray(record));
      }

      // Flush all the printers prior to closing
      trainSetPrinter.flush();
      testSetPrinter.flush();
      validationSetPrinter.flush();
    } catch (IOException | IngestionException e) {
      throw new PipelineException(e);
    }
  }

  /**
   * Opens a tab-delimited printer for one partition's output file and writes the corpus column
   * headers to it.
   *
   * @param outputPath - Path of the partition-specific TSV file to create
   * @return A {@link CSVPrinter} that owns (and will close) the underlying writer
   * @throws IOException if the file cannot be opened or the header cannot be written
   */
  private static CSVPrinter openPartitionPrinter(String outputPath) throws IOException {
    FileWriter writer = new FileWriter(outputPath);
    try {
      return new CSVPrinter(writer, CSVFormat.TDF.withHeader(CorpusBuilder.getTsvColumnHeaders()));
    } catch (IOException | RuntimeException e) {
      // The printer never took ownership of the writer, so release the file handle here
      writer.close();
      throw e;
    }
  }

  /**
   * @param record - A single {@link CSVRecord} from the duplicate thread TSV file
   * @return A string array holding, in order, the first N columns of the record, where N is the
   *         number of corpus TSV column headers
   */
  private String[] convertRecordToArray(CSVRecord record) {
    int columnCount = CorpusBuilder.getTsvColumnHeaders().length;
    String[] recordArray = new String[columnCount];
    for (int i = 0; i < columnCount; i++) {
      recordArray[i] = record.get(i);
    }
    return recordArray;
  }

  /**
   * Adds a new {@link StackExchangeQuestion} to a chosen {@link QuestionAnswerSet}
   *
   * @param duplicateQuestion - A {@link StackExchangeQuestion} for the duplicate question
   * @param parentId - The ID of the thread that this question duplicates
   * @param set - The training, test, or validation {@link QuestionAnswerSet} to which this question
   *        is being added
   *
   * @return <code>true</code> if this question was added to the set, <code>false</code> otherwise
   */
  private boolean addQuestionToSet(StackExchangeQuestion duplicateQuestion, String parentId,
      QuestionAnswerSet set) {
    // The parent ID (i.e., the ID of the thread that this thread duplicates) is the answer text
    // that will be used in comparisons to determine whether a candidate answer is correct or not
    CorrectAnswer c = new CorrectAnswer(parentId);

    // Now add this question, and its single correct answer, to the QA set
    return set.addQuestionAnswers(duplicateQuestion, Arrays.asList(c), null);
  }

  /**
   * @return The {@link QuestionAnswerSet} training data set
   */
  public QuestionAnswerSet getTrainingSet() {
    return this.trainingSet;
  }

  /**
   * @return The {@link QuestionAnswerSet} test data set
   */
  public QuestionAnswerSet getTestSet() {
    return this.testSet;
  }

  /**
   * @return The {@link QuestionAnswerSet} validation data set
   */
  public QuestionAnswerSet getValidationSet() {
    return this.validationSet;
  }
}