com.tamingtext.classifier.mlt.TrainMoreLikeThis.java Source code

Java tutorial

Introduction

Here is the source code for com.tamingtext.classifier.mlt.TrainMoreLikeThis.java

Source

/*
 * Copyright 2008-2011 Grant Ingersoll, Thomas Morton and Drew Farris
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 * -------------------
 * To purchase or learn more about Taming Text, by Grant Ingersoll, Thomas Morton and Drew Farris, visit
 * http://www.manning.com/ingersoll
 */

package com.tamingtext.classifier.mlt;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.en.EnglishAnalyzer;
import org.apache.lucene.analysis.shingle.ShingleAnalyzerWrapper;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexWriterConfig.OpenMode;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.Version;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.tamingtext.util.FileUtil;

public class TrainMoreLikeThis {

    /** Class-wide logger used for progress and timing messages. */
    private static final Logger log = LoggerFactory.getLogger(TrainMoreLikeThis.class);

    /**
     * Key under which {@link #generateUserData(Collection)} stores the
     * '|'-separated list of category names in the commit user data.
     */
    public static final String CATEGORY_KEY = "categories";

    /**
     * Selects the index layout built by {@link #train(String, String, MatchMode)}:
     * {@code KNN} dispatches to {@code buildKnnIndex} and {@code TFIDF} to
     * {@code buildTfidfIndex}.
     */
    public static enum MatchMode {
        KNN, TFIDF
    }

    /** Writer for the index being built; assigned by {@code openIndexWriter} and cleared by {@code closeIndexWriter}. */
    private IndexWriter writer;

    /**
     * Shingle (word n-gram) size. {@code openIndexWriter} only wraps the analyzer
     * in a shingle filter when this value is greater than 1.
     */
    private int nGramSize = 1;

    /** Creates a trainer with the default n-gram size of 1. */
    public TrainMoreLikeThis() {
    }

    /**
     * Sets the shingle size used when the index writer is opened.
     * No validation is performed here; values of 1 or less leave the analyzer un-shingled.
     *
     * @param nGramSize the n-gram size to use for subsequently opened index writers
     */
    public void setNGramSize(int nGramSize) {
        this.nGramSize = nGramSize;
    }

    /**
     * Builds a classification index from the training files found under {@code source}.
     *
     * <p>The index writer is opened in CREATE mode, so an index that already exists at
     * {@code destination} is replaced. If building the index fails, the writer is rolled
     * back and released so that the index directory is not left locked.
     *
     * @param source path handed to {@code FileUtil.buildFileList}; it must yield at least two files
     * @param destination file system path of the index to create
     * @param mode index layout to build
     * @throws IllegalArgumentException if {@code mode} is {@code null}
     * @throws IllegalStateException if fewer than two training files are found
     * @throws Exception if reading the training data or writing the index fails
     */
    public void train(String source, String destination, MatchMode mode) throws Exception {
        // Reject a missing mode up front, before the destination index is overwritten.
        if (mode == null) {
            throw new IllegalArgumentException("Match mode must not be null");
        }

        File[] inputFiles = FileUtil.buildFileList(new File(source));

        if (inputFiles.length < 2) {
            throw new IllegalStateException("There must be more than one training file in " + source);
        }

        openIndexWriter(destination);

        boolean built = false;
        try {
            switch (mode) {
            case TFIDF:
                this.buildTfidfIndex(inputFiles);
                break;
            case KNN:
                this.buildKnnIndex(inputFiles);
                break;
            default:
                throw new IllegalStateException("Unknown match mode: " + mode.toString());
            }
            built = true;
        } finally {
            if (!built) {
                // Do not optimize a half-built index; just release the writer and its lock.
                discardIndexWriter();
            }
        }

        closeIndexWriter();
    }

    /** Rolls back and releases the current index writer after a failed build, without masking the original error. */
    private void discardIndexWriter() {
        IndexWriter failed = writer;
        writer = null;
        if (failed == null) {
            return;
        }
        try {
            failed.rollback();
        } catch (IOException e) {
            log.warn("Unable to roll back index writer after failed training run", e);
        }
    }

    /**
     * Builds a Lucene index suitable for knn based classification. Every well-formed
     * training line is indexed as its own document, so each category is represented by
     * many documents; the category that has the highest count in the top N hits is the
     * category that is assigned.
     *
     * <p>Each input line must have the form {@code category<TAB>content}; lines that do
     * not split into exactly two tab-separated parts are skipped. Files are decoded as
     * UTF-8, the same encoding {@code buildTfidfIndex} uses.
     *
     * @param inputFiles training files to index
     * @throws Exception if a training file cannot be read or the index cannot be written
     */
    protected void buildKnnIndex(File[] inputFiles) throws Exception {
        int lineCount = 0;
        int fileCount = 0;
        String line = null;
        String category = null;
        Set<String> categories = new HashSet<String>();
        long start = System.currentTimeMillis();

        // reuse these fields
        //<start id="lucene.examples.fields"/>
        Field id = new Field("id", "", Field.Store.YES, Field.Index.NOT_ANALYZED, Field.TermVector.NO);
        Field categoryField = new Field("category", "", Field.Store.YES, Field.Index.NOT_ANALYZED,
                Field.TermVector.NO);
        Field contentField = new Field("content", "", Field.Store.NO, Field.Index.ANALYZED,
                Field.TermVector.WITH_POSITIONS_OFFSETS);
        //<end id="lucene.examples.fields"/>

        for (File ff : inputFiles) {
            fileCount++;
            lineCount = 0;
            category = null;

            // Decode explicitly as UTF-8 instead of relying on the platform default charset.
            BufferedReader in = new BufferedReader(new InputStreamReader(new FileInputStream(ff), "UTF-8"));
            try {
                //<start id="lucene.examples.knn.train"/>
                while ((line = in.readLine()) != null) {
                    String[] parts = line.split("\t"); //<co id="luc.knn.content"/>
                    if (parts.length != 2) {
                        continue;
                    }
                    category = parts[0];
                    categories.add(category);

                    Document d = new Document(); //<co id="luc.knn.document"/>
                    id.setValue(category + "-" + lineCount++);
                    categoryField.setValue(category);
                    contentField.setValue(parts[1]);
                    d.add(id);
                    d.add(categoryField);
                    d.add(contentField);

                    writer.addDocument(d); //<co id="luc.knn.index"/>
                }
                /*<calloutlist>
                <callout arearefs="luc.knn.content">Collect Content</callout>
                <callout arearefs="luc.knn.document">Build Document</callout>
                <callout arearefs="luc.knn.index">Index Document</callout>
                </calloutlist>*/
                //<end id="lucene.examples.knn.train"/>
            } finally {
                // Release the file handle even when reading or indexing fails.
                in.close();
            }

            log.info("Knn: Added document for category " + category + " with " + lineCount + " lines");
        }

        // Record the set of categories alongside the commit so it can be read back with the index.
        writer.commit(generateUserData(categories));

        log.info("Knn: Added " + fileCount + " categories in " + (System.currentTimeMillis() - start) + " msec.");
    }

    /**
     * Builds a Lucene index suitable for tfidf based classification. The content of each
     * training file is concatenated and indexed as a single document, and the best match
     * for a MoreLikeThis query is the category that is assigned.
     *
     * <p>Each input line must have the form {@code category<TAB>content}; lines that do
     * not split into exactly two tab-separated parts are skipped. The category recorded
     * for a file's document is the one on its last well-formed line. A file without any
     * well-formed line is skipped with a warning.
     *
     * @param inputFiles UTF-8 encoded training files to index
     * @throws Exception if a training file cannot be read or the index cannot be written
     */
    protected void buildTfidfIndex(File[] inputFiles) throws Exception {

        int lineCount = 0;
        int fileCount = 0;
        String line = null;
        Set<String> categories = new HashSet<String>();
        long start = System.currentTimeMillis();

        // reuse these fields
        Field id = new Field("id", "", Field.Store.YES, Field.Index.NOT_ANALYZED, Field.TermVector.NO);
        Field categoryField = new Field("category", "", Field.Store.YES, Field.Index.NOT_ANALYZED,
                Field.TermVector.NO);
        Field contentField = new Field("content", "", Field.Store.NO, Field.Index.ANALYZED,
                Field.TermVector.WITH_POSITIONS_OFFSETS);

        // read data from input files.

        for (File ff : inputFiles) {
            lineCount = 0;

            // read all training documents into a string
            BufferedReader in = new BufferedReader(new InputStreamReader(new FileInputStream(ff), "UTF-8"));

            //<start id="lucene.examples.tfidf.train"/>
            StringBuilder content = new StringBuilder();
            String category = null;
            try {
                while ((line = in.readLine()) != null) {
                    String[] parts = line.split("\t"); //<co id="luc.tf.content"/>
                    if (parts.length != 2) {
                        continue;
                    }
                    category = parts[0];
                    categories.add(category);
                    content.append(parts[1]).append(" ");
                    lineCount++;
                }
            } finally {
                // Release the file handle even when reading fails.
                in.close();
            }

            if (category == null) {
                // No usable line means no category; skip the file rather than
                // handing a null category value to the index.
                log.warn("TfIdf: Skipping " + ff + ", it contains no lines of the form category<TAB>content");
                continue;
            }

            // Only files that actually contribute a document are counted.
            fileCount++;

            Document d = new Document(); //<co id="luc.tf.document"/>
            id.setValue(category + "-" + lineCount);
            categoryField.setValue(category);
            contentField.setValue(content.toString());
            d.add(id);
            d.add(categoryField);
            d.add(contentField);

            writer.addDocument(d); //<co id="luc.tf.index"/>
            /*<calloutlist>
              <callout arearefs="luc.tf.content">Collect Content</callout>
              <callout arearefs="luc.tf.document">Build Document</callout>
              <callout arearefs="luc.tf.index">Index Document</callout>
             </calloutlist>*/
            //<end id="lucene.examples.tfidf.train"/>

            log.info("TfIdf: Added document for category " + category + " with " + lineCount + " lines");
        }

        // Record the set of categories alongside the commit so it can be read back with the index.
        writer.commit(generateUserData(categories));

        log.info("TfIdf: Added " + fileCount + " categories in " + (System.currentTimeMillis() - start) + " msec.");
    }

    /**
     * Opens a fresh index writer on the given directory and stores it in {@code writer}.
     *
     * <p>Text is analyzed with an {@link EnglishAnalyzer}; when {@code nGramSize} is greater
     * than 1 the analyzer is wrapped so that shingles of exactly that size are emitted in
     * addition to unigrams. The writer is opened in CREATE mode.
     *
     * @param pathname file system path of the index directory
     * @throws IOException if the directory or the index writer cannot be opened
     */
    protected void openIndexWriter(String pathname) throws IOException {
        //<start id="lucene.examples.index.setup"/>
        final File indexLocation = new File(pathname);
        final Directory indexDirectory = FSDirectory.open(indexLocation); //<co id="luc.index.dir"/>

        final Analyzer english = new EnglishAnalyzer(Version.LUCENE_36); //<co id="luc.index.analyzer"/>
        final Analyzer indexAnalyzer;
        if (nGramSize <= 1) {
            // plain unigram analysis
            indexAnalyzer = english;
        } else { //<co id="luc.index.shingle"/>
            final int minShingleSize = nGramSize;
            final int maxShingleSize = nGramSize;
            final String tokenSeparator = "-";
            final boolean outputUnigrams = true;
            final boolean outputUnigramsIfNoShingles = true;
            indexAnalyzer = new ShingleAnalyzerWrapper(english, minShingleSize, maxShingleSize, tokenSeparator,
                    outputUnigrams, outputUnigramsIfNoShingles);
        }

        final IndexWriterConfig writerConfig //<co id="luc.index.create"/>
                = new IndexWriterConfig(Version.LUCENE_36, indexAnalyzer);
        writerConfig.setOpenMode(OpenMode.CREATE);
        this.writer = new IndexWriter(indexDirectory, writerConfig);
        /* <calloutlist>
        <callout arearefs="luc.index.dir">Create Index Directory</callout>
        <callout arearefs="luc.index.analyzer">Setup Analyzer</callout>
        <callout arearefs="luc.index.shingle">Setup Shingle Filter</callout>
        <callout arearefs="luc.index.create">Create <classname>IndexWriter</classname></callout>
        </calloutlist> */
        //<end id="lucene.examples.index.setup"/>
    }

    /**
     * Optimizes the index and closes the current index writer.
     *
     * <p>The writer is closed even when optimizing fails, and the {@code writer} field is
     * cleared in either case so that no stale writer is kept around.
     *
     * @throws IOException if optimizing or closing the index fails
     */
    protected void closeIndexWriter() throws IOException {
        log.info("Starting optimize");

        // Detach the writer from the field first so that a failure below cannot leave a stale reference behind.
        IndexWriter closing = writer;
        writer = null;

        // optimize and close the index.
        try {
            closing.optimize();
        } finally {
            // Always close, otherwise a failed optimize would keep the writer open.
            closing.close();
        }

        log.info("Optimize complete, index closed");
    }

    /**
     * Builds the commit user data that records which categories the index contains.
     *
     * @param categories category names; joined in iteration order with '|' as separator
     * @return a new map with a single entry, {@link #CATEGORY_KEY} mapped to the joined
     *         names (the empty string when {@code categories} is empty)
     */
    protected static Map<String, String> generateUserData(Collection<String> categories) {
        StringBuilder b = new StringBuilder();
        for (String cat : categories) {
            b.append(cat).append('|');
        }
        // Drop the trailing separator; with no categories there is nothing to trim,
        // and setLength(-1) would throw.
        if (b.length() > 0) {
            b.setLength(b.length() - 1);
        }
        Map<String, String> userData = new HashMap<String, String>();
        userData.put(CATEGORY_KEY, b.toString());
        return userData;
    }

    /**
     * Command line entry point: parses the options, configures a trainer and builds the index.
     *
     * <p>Options: {@code --input/-i} (required), {@code --output/-o} (required),
     * {@code --gramSize/-ng} (optional, default 1) and {@code --classifierType/-type}
     * ({@code knn} or {@code tfidf}). When the options cannot be parsed, the error is
     * logged and the usage help is printed.
     *
     * @param args command line arguments
     * @throws IllegalArgumentException if the classifier type is missing or not recognized,
     *         or if the gram size is not an integer
     * @throws Exception if training fails
     */
    public static void main(String[] args) throws Exception {
        DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
        ArgumentBuilder abuilder = new ArgumentBuilder();
        GroupBuilder gbuilder = new GroupBuilder();

        Option helpOpt = DefaultOptionCreator.helpOption();

        Option inputDirOpt = obuilder.withLongName("input").withRequired(true)
                .withArgument(abuilder.withName("input").withMinimum(1).withMaximum(1).create())
                .withDescription("The input directory, containing properly formatted files: "
                        + "One doc per line, first entry on the line is the label, rest is the evidence")
                .withShortName("i").create();

        Option outputOpt = obuilder.withLongName("output").withRequired(true)
                .withArgument(abuilder.withName("output").withMinimum(1).withMaximum(1).create())
                .withDescription("The output directory").withShortName("o").create();

        Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(false)
                .withArgument(abuilder.withName("gramSize").withMinimum(1).withMaximum(1).create())
                .withDescription("Size of the n-gram. Default Value: 1 ").withShortName("ng").create();

        Option typeOpt = obuilder.withLongName("classifierType").withRequired(false)
                .withArgument(abuilder.withName("classifierType").withMinimum(1).withMaximum(1).create())
                .withDescription("Type of classifier: knn|tfidf.").withShortName("type").create();

        Group group = gbuilder.withName("Options").withOption(gramSizeOpt).withOption(helpOpt)
                .withOption(inputDirOpt).withOption(outputOpt).withOption(typeOpt).create();

        try {
            Parser parser = new Parser();

            parser.setGroup(group);
            parser.setHelpOption(helpOpt);
            CommandLine cmdLine = parser.parse(args);
            if (cmdLine.hasOption(helpOpt)) {
                CommandLineUtil.printHelp(group);
                return;
            }

            String classifierType = (String) cmdLine.getValue(typeOpt);

            int gramSize = 1;
            if (cmdLine.hasOption(gramSizeOpt)) {
                String rawGramSize = (String) cmdLine.getValue(gramSizeOpt);
                try {
                    gramSize = Integer.parseInt(rawGramSize);
                } catch (NumberFormatException e) {
                    // Same exception type as before, but the message names the offending option.
                    NumberFormatException explained = new NumberFormatException(
                            "gramSize must be an integer, but was: " + rawGramSize);
                    explained.initCause(e);
                    throw explained;
                }
            }

            String inputPath = (String) cmdLine.getValue(inputDirOpt);
            String outputPath = (String) cmdLine.getValue(outputOpt);
            TrainMoreLikeThis trainer = new TrainMoreLikeThis();
            MatchMode mode = parseMatchMode(classifierType);

            if (gramSize > 1) {
                trainer.setNGramSize(gramSize);
            }

            trainer.train(inputPath, outputPath, mode);

        } catch (OptionException e) {
            log.error("Error while parsing options", e);
            // Show the usage so the user can see which options are expected.
            CommandLineUtil.printHelp(group);
        }
    }

    /** Maps the classifierType option value (case-insensitive "knn" or "tfidf") to a MatchMode. */
    private static MatchMode parseMatchMode(String classifierType) {
        if ("knn".equalsIgnoreCase(classifierType)) {
            return MatchMode.KNN;
        }
        if ("tfidf".equalsIgnoreCase(classifierType)) {
            return MatchMode.TFIDF;
        }
        throw new IllegalArgumentException(
                "Unknown classifierType: " + classifierType + " (expected knn or tfidf)");
    }
}