com.tdunning.ch16.train.TrainNewsGroups.java Source code


Introduction

Here is the source code for com.tdunning.ch16.train.TrainNewsGroups.java, an example that trains Apache Mahout's adaptive logistic regression (SGD) classifier on the classic 20 newsgroups data set.
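
The program takes the path of the training data directory as its first argument and an optional leakType code as its second; the class Javadoc below explains each value. A typical invocation might look like the following (a sketch only: the classpath placeholder and the standard 20 newsgroups directory name are illustrative, not part of the original example):

    java -cp <mahout-and-guava-jars> com.tdunning.ch16.train.TrainNewsGroups 20news-bydate-train 1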

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.tdunning.ch16.train;

import com.google.common.base.Charsets;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import com.google.common.io.Files;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.util.Version;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.CrossFoldLearner;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.ModelDissector;
import org.apache.mahout.classifier.sgd.ModelSerializer;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.TextValueEncoder;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.text.SimpleDateFormat;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;

/**
 * Reads and trains an adaptive logistic regression model on the 20 newsgroups data.
 * The first command line argument gives the path of the directory holding the training
 * data.  The optional second argument, leakType, defines which classes of features to use.
 * Importantly, leakType controls whether a synthetic date is injected into the data as
 * a target leak and if so, how.
 * <p>
 * The value of leakType % 3 determines whether, and in what form, the target leak is
 * injected, as shown in the following table:
 * <p>
 * <table>
 * <tr><td valign='top'>0</td><td>No leak injected</td></tr>
 * <tr><td valign='top'>1</td><td>Synthetic date injected in MMM-yyyy format. This will be a single token and
 * is a perfect target leak since each newsgroup is given a different month</td></tr>
 * <tr><td valign='top'>2</td><td>Synthetic date injected in dd-MMM-yyyy HH:mm:ss format.  The day varies
 * and thus there are more leak symbols that need to be learned.  Ultimately this is just
 * as big a leak as case 1.</td></tr>
 * </table>
 * <p>
 * The leakType value also determines what other text is used for features.  If leakType
 * is 6 or more, neither headers nor body text are used and the leak is the only source
 * of features.  If leakType is 3, 4, or 5, only subject words are used as features.
 * If leakType is less than 3, both subject and body text are used as features.
 * <p>
 * A leakType of 0 gives no leak and all textual features.
 * <p>
 * See the following table for a summary of commonly used values for leakType
 * <p>
 * <table>
 * <tr><td><b>leakType</b></td><td><b>Leak?</b></td><td><b>Subject?</b></td><td><b>Body?</b></td></tr>
 * <tr><td colspan=4><hr></td></tr>
 * <tr><td>0</td><td>no</td><td>yes</td><td>yes</td></tr>
 * <tr><td>1</td><td>mmm-yyyy</td><td>yes</td><td>yes</td></tr>
 * <tr><td>2</td><td>dd-mmm-yyyy</td><td>yes</td><td>yes</td></tr>
 * <tr><td colspan=4><hr></td></tr>
 * <tr><td>3</td><td>no</td><td>yes</td><td>no</td></tr>
 * <tr><td>4</td><td>mmm-yyyy</td><td>yes</td><td>no</td></tr>
 * <tr><td>5</td><td>dd-mmm-yyyy</td><td>yes</td><td>no</td></tr>
 * <tr><td colspan=4><hr></td></tr>
 * <tr><td>6</td><td>no</td><td>no</td><td>no</td></tr>
 * <tr><td>7</td><td>mmm-yyyy</td><td>no</td><td>no</td></tr>
 * <tr><td>8</td><td>dd-mmm-yyyy</td><td>no</td><td>no</td></tr>
 * <tr><td colspan=4><hr></td></tr>
 * </table>
 */
public final class TrainNewsGroups {
    private static final int FEATURES = 10000;

    // Constants for the synthetic date leak described in the class Javadoc.
    // All values are in seconds; DATE_REFERENCE is 1997-01-15 00:01:00 GMT.
    private static final long DATE_REFERENCE = 853286460;
    private static final long MONTH = 30 * 24 * 3600;
    private static final long WEEK = 7 * 24 * 3600;

    private static final Random rand = RandomUtils.getRandom();

    private static final String[] LEAK_LABELS = { "none", "month-year", "day-month-year" };
    private static final SimpleDateFormat[] DATE_FORMATS = { new SimpleDateFormat("", Locale.ENGLISH),
            new SimpleDateFormat("MMM-yyyy", Locale.ENGLISH),
            new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss", Locale.ENGLISH) };

    private static final Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_30);
    private static final TextValueEncoder encoder = new TextValueEncoder("body");
    private static final FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");

    private TrainNewsGroups() {
    }

    public static void main(String[] args) throws IOException {
        File base = new File(args[0]);

        int leakType = 0;
        if (args.length > 1) {
            leakType = Integer.parseInt(args[1]);
        }

        Dictionary newsGroups = new Dictionary();

        encoder.setProbes(2);  // hash each token into two locations to soften collisions
        // 20 target categories, FEATURES-dimensional hashed vectors, L1 prior.
        AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1());
        learningAlgorithm.setInterval(800);         // adapt hyper-parameters every 800 examples
        learningAlgorithm.setAveragingWindow(500);  // average metrics over the last 500 examples

        List<File> files = Lists.newArrayList();
        File[] directories = base.listFiles();
        Arrays.sort(directories, Ordering.usingToString());
        for (File newsgroup : directories) {
            if (newsgroup.isDirectory()) {
                newsGroups.intern(newsgroup.getName());
                files.addAll(Arrays.asList(newsgroup.listFiles()));
            }
        }
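        // Shuffle so the online learner sees newsgroups in random order rather
        // than one group at a time.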
        Collections.shuffle(files);
        System.out.printf("%d training files\n", files.size());
        System.out.printf("%s\n", Arrays.asList(directories));

        double averageLL = 0;
        double averageCorrect = 0;

        // k counts training examples; step and bumps implement a
        // 1-2-5-10-20-50-... reporting and checkpointing schedule.
        int k = 0;
        double step = 0;
        int[] bumps = { 1, 2, 5 };
        for (File file : files) {
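            // The target category is the newsgroup directory containing the file.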
            String ng = file.getParentFile().getName();
            int actual = newsGroups.intern(ng);

            Vector v = encodeFeatureVector(file);
            learningAlgorithm.train(actual, v);

            k++;

            int bump = bumps[(int) Math.floor(step) % bumps.length];
            int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
            State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
            double maxBeta;
            double nonZeros;
            double positive;
            double norm;

            double lambda = 0;
            double mu = 0;

            if (best != null) {
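                // Metrics come from the best learner found so far by the
                // evolutionary hyper-parameter search.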
                CrossFoldLearner state = best.getPayload().getLearner();
                averageCorrect = state.percentCorrect();
                averageLL = state.logLikelihood();

                OnlineLogisticRegression model = state.getModels().get(0);
                // finish off pending regularization
                model.close();

                // Summarize the coefficient matrix: largest absolute weight, number of
                // effectively non-zero weights, number of positive weights, and L1 norm.
                Matrix beta = model.getBeta();
                maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
                nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
                    @Override
                    public double apply(double v) {
                        return Math.abs(v) > 1.0e-6 ? 1 : 0;
                    }
                });
                positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
                    @Override
                    public double apply(double v) {
                        return v > 0 ? 1 : 0;
                    }
                });
                norm = beta.aggregate(Functions.PLUS, Functions.ABS);

                // Hyper-parameters (lambda, mu) chosen by the evolutionary search.
                lambda = best.getMappedParams()[0];
                mu = best.getMappedParams()[1];
            } else {
                maxBeta = 0;
                nonZeros = 0;
                positive = 0;
                norm = 0;
            }
            if (k % (bump * scale) == 0) {
                // Checkpoint the best model so far and print progress.
                if (learningAlgorithm.getBest() != null) {
                    ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
                            learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
                }

                step += 0.25;
                System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda,
                        mu);
                System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, averageLL, averageCorrect * 100,
                        LEAK_LABELS[leakType % 3]);
            }
        }
        learningAlgorithm.close();
        dissect(newsGroups, learningAlgorithm, files);
        System.out.println("exiting main");

        ModelSerializer.writeBinary("/tmp/news-group.model",
                learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
    }

    private static void dissect(Dictionary newsGroups, AdaptiveLogisticRegression learningAlgorithm,
            Iterable<File> files) throws IOException {
        CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
        model.close();

        Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
        ModelDissector md = new ModelDissector();

        encoder.setTraceDictionary(traceDictionary);
        bias.setTraceDictionary(traceDictionary);

        // Trace 500 randomly chosen documents through the model to collect
        // feature-to-weight attributions.
        for (File file : permute(files, rand).subList(0, 500)) {
            traceDictionary.clear();
            Vector v = encodeFeatureVector(file);
            md.update(v, traceDictionary, model);
        }

        List<String> ngNames = Lists.newArrayList(newsGroups.values());
        List<ModelDissector.Weight> weights = md.summary(100);
        for (ModelDissector.Weight w : weights) {
            System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n", w.getFeature(), w.getWeight(),
                    ngNames.get(w.getMaxImpact() + 1), w.getCategory(1), w.getWeight(1), w.getCategory(2),
                    w.getWeight(2));
        }
    }

    private static Vector encodeFeatureVector(File file) throws IOException {
        BufferedReader reader = Files.newReader(file, Charsets.UTF_8);
        try {
            // Scan the message headers, encoding only the topical ones (From,
            // Subject, Keywords, Summary).  A header value may continue onto
            // following lines that begin with whitespace.
            String line = reader.readLine();
            while (line != null && line.length() > 0) {
                boolean countHeader = (line.startsWith("From:") || line.startsWith("Subject:")
                        || line.startsWith("Keywords:") || line.startsWith("Summary:"));
                do {
                    if (countHeader) {
                        line = line.replaceAll(".*:", "");  // drop the header name, keep the value
                        encoder.addText(line.toLowerCase());
                    }
                    line = reader.readLine();
                } while (line != null && line.startsWith(" "));
            }

            // Skip the blank line separating headers from the body, then feed
            // every body line to the encoder.
            if (line != null) {
                line = reader.readLine();
            }

            while (line != null) {
                encoder.addText(line.toLowerCase());
                line = reader.readLine();
            }
        } finally {
            reader.close();
        }

        Vector v = new RandomAccessSparseVector(FEATURES);
        // The (byte[]) null cast selects the byte[] overload; ConstantValueEncoder
        // ignores the original form and simply adds the intercept term.
        bias.addToVector((byte[]) null, 1, v);
        encoder.flush(1, v);  // hash the accumulated text into the feature vector
        return v;
    }

    private static List<File> permute(Iterable<File> files, Random rand) {
        List<File> r = Lists.newArrayList();
        for (File file : files) {
            int i = rand.nextInt(r.size() + 1);
            if (i == r.size()) {
                r.add(file);
            } else {
                r.add(r.get(i));
                r.set(i, file);
            }
        }
        return r;
    }
}
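
Usage

The trainer periodically writes checkpoints to /tmp/news-group-<k>.model and a final model to /tmp/news-group.model. The sketch below shows one way to load that model and score a new document. It is a minimal sketch, not part of the original example: the ClassifyNewsGroups class name is hypothetical, and the document encoding is deliberately simplified (the whole file is hashed as body text) rather than reusing the header-aware encoding above, so scores will differ somewhat. It assumes the ModelSerializer.readBinary and classifyFull calls of Mahout's 0.x SGD API.

package com.tdunning.ch16.train;

import com.google.common.base.Charsets;
import com.google.common.io.Files;
import org.apache.mahout.classifier.sgd.ModelSerializer;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.TextValueEncoder;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;

/** Hypothetical companion to TrainNewsGroups: loads the saved model and scores one file. */
public final class ClassifyNewsGroups {
    public static void main(String[] args) throws IOException {
        // Load the final model written by TrainNewsGroups.
        OnlineLogisticRegression model = ModelSerializer.readBinary(
                new FileInputStream("/tmp/news-group.model"), OnlineLogisticRegression.class);

        // Re-create the encoders with the same settings used during training.
        TextValueEncoder encoder = new TextValueEncoder("body");
        encoder.setProbes(2);
        ConstantValueEncoder bias = new ConstantValueEncoder("Intercept");

        // Simplified encoding: hash the entire file as body text.
        Vector v = new RandomAccessSparseVector(10000);
        encoder.addText(Files.toString(new File(args[0]), Charsets.UTF_8).toLowerCase());
        bias.addToVector((byte[]) null, 1, v);
        encoder.flush(1, v);

        // classifyFull returns one probability per category; indexes follow the
        // order in which the trainer interned newsgroup names (sorted directories).
        Vector scores = model.classifyFull(v);
        System.out.printf("best category index = %d%n", scores.maxValueIndex());
    }
}

To reproduce training-time features faithfully, factor encodeFeatureVector out of TrainNewsGroups and share it between the two programs; to map a category index back to a newsgroup name, re-intern the training directory names in the same sorted order, since the trainer does not persist its Dictionary.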