TrainLogistic.java Source code

Java tutorial

Introduction

Here is the source code for TrainLogistic.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//package org.apache.mahout.classifier.sgd;

import com.google.common.base.Charsets;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import com.google.common.io.Resources;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.classifier.sgd.*;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FilenameFilter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;

/**
 * Trains a logistic regression ({@link OnlineLogisticRegression}) on every {@code *.svm} file
 * found in the {@code --input} directory. It then writes row 0 of the learned coefficient matrix
 * to {@code --output}, one value per line.
 *
 * <p>Each input line is expected in svmlight style: {@code <label> <index>:<value> ... [# comment]},
 * with 1-based feature indices. A label that starts with {@code '+'} or parses to a positive
 * number is target 1; any other label is target 0. The last feature slot
 * ({@code numFeatures - 1}) is reserved for a constant bias value of 1.
 *
 * <p>Adapted from the examples in Chapter 13 of Mahout in Action.
 */
public final class TrainLogistic {

    /** Separator between the label and the feature tokens of an input line. */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    private static String inputFile;
    private static String outputFile;
    private static LogisticModelParameters lmp;
    private static int passes;
    private static boolean scores;
    private static OnlineLogisticRegression model;

    private TrainLogistic() {
    }

    /**
     * Command-line entry point: trains and saves a model, then prints the elapsed wall-clock time.
     *
     * @param args command-line arguments (see {@link #parseArgs(String[])})
     * @throws Exception if the input cannot be read or parsed, or the model cannot be written
     */
    public static void main(String[] args) throws Exception {
        long start = System.currentTimeMillis();
        mainToOutput(args, new PrintWriter(System.out, true));
        System.out.printf(Locale.ENGLISH, "time:%.2f s\n", (System.currentTimeMillis() - start) / 1000.0);
    }

    /**
     * Parses {@code args} and trains a model on every {@code *.svm} file in the input directory.
     * It then writes the model's coefficients to the output file. Returns without doing anything
     * else if argument parsing fails or help was requested.
     *
     * <p>Files are processed in name order. Each file is read {@code passes} times in a row
     * before the next file is started.
     *
     * @param args   command-line arguments
     * @param output receives per-example score diagnostics when {@code --scores} is given
     * @throws Exception if the input directory cannot be listed, a file cannot be read or parsed,
     *                   or the model cannot be written
     */
    static void mainToOutput(String[] args, PrintWriter output) throws Exception {
        if (!parseArgs(args)) {
            return;
        }

        File inputDir = new File(inputFile);
        String[] fileNames = listSvmFiles(inputDir);
        int numFeatures = lmp.getNumFeatures();
        OnlineLogisticRegression lr = lmp.createRegression();

        double logPEstimate = 0;
        int samples = 0;
        for (String fileName : fileNames) {
            // Resolve against the directory so --input works with or without a trailing separator.
            String path = new File(inputDir, fileName).getPath();
            for (int pass = 0; pass < passes; pass++) {
                BufferedReader in = open(path);
                System.out.println(pass + 1);
                try {
                    int lineCount = 0;
                    for (String line = in.readLine(); line != null; line = in.readLine()) {
                        lineCount++;
                        String[] tokens = tokenize(line);
                        if (tokens.length == 0) {
                            continue; // blank or comment-only line
                        }
                        int targetValue = parseTarget(tokens[0]);
                        Vector input = new RandomAccessSparseVector(numFeatures);
                        parseFeatures(tokens, input, numFeatures, path, lineCount);
                        // Constant bias term in the reserved last slot.
                        // NOTE(review): this is set regardless of --noBias.
                        input.setQuick(numFeatures - 1, 1);

                        // check performance while this is still news
                        double logP = lr.logLikelihood(targetValue, input);
                        if (!Double.isInfinite(logP)) {
                            if (samples < 20) {
                                // exact running mean over the first samples ...
                                logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
                            } else {
                                // ... then an exponentially weighted moving average
                                logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
                            }
                            samples++;
                        }
                        double p = lr.classifyScalar(input);
                        if (scores) {
                            output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f\n", samples,
                                    targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
                        }
                        // now update model
                        lr.train(targetValue, input);
                        if (lineCount % 1000 == 0) {
                            System.out.printf("%d\t", lineCount);
                        }
                    }
                } finally {
                    Closeables.closeQuietly(in);
                }
                System.out.println();
            }
        }

        // Expose the trained model through getModel().
        model = lr;

        FileOutputStream modelOutput = new FileOutputStream(outputFile);
        try {
            saveTo(modelOutput, lr);
        } finally {
            Closeables.closeQuietly(modelOutput);
        }
    }

    /**
     * Lists the {@code *.svm} files directly inside {@code dir}. The names are sorted so that
     * training order is reproducible.
     *
     * @param dir directory to scan
     * @return the sorted file names (not paths); may be empty
     * @throws IOException if {@code dir} is not a directory or cannot be read
     */
    private static String[] listSvmFiles(File dir) throws IOException {
        String[] names = dir.list(new FilenameFilter() {
            @Override
            public boolean accept(File parent, String name) {
                return name.endsWith(".svm");
            }
        });
        if (names == null) {
            // File.list signals "not a directory" / I/O failure by returning null, not by throwing.
            throw new IOException("Input is not a readable directory: " + dir);
        }
        Arrays.sort(names);
        return names;
    }

    /**
     * Splits an svmlight line into whitespace-separated tokens. Anything from a {@code '#'} to
     * the end of the line is ignored.
     *
     * @param line raw input line
     * @return the tokens, or an empty array for a blank or comment-only line
     */
    static String[] tokenize(String line) {
        int hash = line.indexOf('#');
        String content = (hash >= 0 ? line.substring(0, hash) : line).trim();
        return content.isEmpty() ? new String[0] : WHITESPACE.split(content);
    }

    /**
     * Maps an svmlight label to a binary target.
     *
     * @param label first token of an input line
     * @return 1 if the label starts with {@code '+'} or parses to a number greater than 0,
     *         otherwise 0
     */
    static int parseTarget(String label) {
        if (label.startsWith("+")) {
            return 1;
        }
        try {
            return Double.parseDouble(label) > 0 ? 1 : 0;
        } catch (NumberFormatException ignored) {
            // A non-numeric label without a leading '+' is treated as the negative class.
            return 0;
        }
    }

    /**
     * Stores every {@code index:value} token in {@code tokens[1..]} into {@code input}. Each value
     * goes at zero-based position {@code index - 1}.
     *
     * @param tokens      tokens of one line; element 0 is the label and is skipped
     * @param input       vector to fill
     * @param numFeatures cardinality of {@code input}; its last slot is reserved for the bias
     * @param source      file path, used only in error messages
     * @param lineNumber  1-based line number, used only in error messages
     * @throws IllegalArgumentException if a token is malformed, or if its index is outside
     *                                  {@code 1..numFeatures-1}
     */
    private static void parseFeatures(String[] tokens, Vector input, int numFeatures, String source,
            int lineNumber) {
        for (int k = 1; k < tokens.length; k++) {
            String token = tokens[k];
            int colon = token.indexOf(':');
            if (colon <= 0) {
                throw new IllegalArgumentException(
                        "Malformed feature '" + token + "' at " + source + ":" + lineNumber);
            }
            int index;
            double value;
            try {
                index = Integer.parseInt(token.substring(0, colon)) - 1;
                value = Double.parseDouble(token.substring(colon + 1));
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException(
                        "Malformed feature '" + token + "' at " + source + ":" + lineNumber, e);
            }
            // The last slot holds the bias; an index landing there would be silently overwritten.
            if (index < 0 || index >= numFeatures - 1) {
                throw new IllegalArgumentException("Feature index " + (index + 1) + " at " + source + ":"
                        + lineNumber + " is outside 1.." + (numFeatures - 1)
                        + " (increase --features; the last slot is reserved for the bias)");
            }
            input.setQuick(index, value);
        }
    }

    /**
     * Writes row 0 of {@code lr}'s coefficient matrix to {@code modelOutput}, one {@code %f} value
     * per line. Values are formatted in {@link Locale#ENGLISH}, so the decimal separator is always
     * {@code '.'}. The stream is closed when done.
     *
     * @throws IOException if writing or closing the stream fails
     */
    private static void saveTo(FileOutputStream modelOutput, OnlineLogisticRegression lr) throws IOException {
        System.out.printf("Now, writing file...\n");
        PrintWriter w = new PrintWriter(new OutputStreamWriter(modelOutput, Charsets.UTF_8));
        int columns = lr.getBeta().numCols();
        for (int column = 0; column < columns; column++) {
            w.printf(Locale.ENGLISH, "%f\n", lr.getBeta().get(0, column));
        }
        w.close();
        // PrintWriter never throws IOException; it only records failures, so surface them here.
        if (w.checkError()) {
            throw new IOException("Failed to write model to " + outputFile);
        }
    }

    /**
     * Sums the coefficients in {@code row} for every hashed column that {@code csv}'s trace
     * dictionary associates with {@code predictor}.
     *
     * <p>Not called by the svmlight training path above; kept for predictor-weight reporting with
     * a {@link RecordFactory}.
     */
    private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv,
            String predictor) {
        double weight = 0;
        for (Integer column : csv.getTraceDictionary().get(predictor)) {
            weight += lr.getBeta().get(row, column);
        }
        return weight;
    }

    /**
     * Parses the command line into the static training configuration.
     *
     * @param args command-line arguments
     * @return {@code false} if parsing failed or help was printed; {@code true} otherwise
     */
    private static boolean parseArgs(String[] args) {
        DefaultOptionBuilder builder = new DefaultOptionBuilder();

        Option help = builder.withLongName("help").withDescription("print this list").create();

        Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();
        Option scoresOption = builder.withLongName("scores")
                .withDescription("output score diagnostics during training").create();

        ArgumentBuilder argumentBuilder = new ArgumentBuilder();
        Option inputOption = builder.withLongName("input").withRequired(true)
                .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
                .withDescription("where to get training data").create();

        Option outputOption = builder.withLongName("output").withRequired(true)
                .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
                .withDescription("where to write the trained model coefficients").create();

        Option predictors = builder.withLongName("predictors").withRequired(true)
                .withArgument(argumentBuilder.withName("p").create())
                .withDescription("a list of predictor variables").create();

        Option types = builder.withLongName("types").withRequired(true)
                .withArgument(argumentBuilder.withName("t").create())
                .withDescription("a list of predictor variable types (numeric, word, or text)").create();

        Option target = builder.withLongName("target").withRequired(true)
                .withArgument(argumentBuilder.withName("target").withMaximum(1).create())
                .withDescription("the name of the target variable").create();

        Option features = builder.withLongName("features")
                .withArgument(argumentBuilder.withName("numFeatures").withDefault("1000").withMaximum(1).create())
                .withDescription("the number of internal hashed features to use").create();

        Option passesOption = builder.withLongName("passes")
                .withArgument(argumentBuilder.withName("passes").withDefault("2").withMaximum(1).create())
                .withDescription("the number of times to pass over the input data").create();

        Option lambda = builder.withLongName("lambda")
                .withArgument(argumentBuilder.withName("lambda").withDefault("1e-4").withMaximum(1).create())
                .withDescription("the amount of coefficient decay to use").create();

        Option rate = builder.withLongName("rate")
                .withArgument(argumentBuilder.withName("learningRate").withDefault("1e-3").withMaximum(1).create())
                .withDescription("the learning rate").create();

        Option noBias = builder.withLongName("noBias").withDescription("don't include a bias term").create();

        Option targetCategories = builder.withLongName("categories").withRequired(true)
                .withArgument(argumentBuilder.withName("number").withMaximum(1).create())
                .withDescription("the number of target categories to be considered").create();

        // --scores must be registered here, otherwise the parser rejects it and it can never be enabled.
        Group normalArgs = new GroupBuilder().withOption(help).withOption(quiet).withOption(scoresOption)
                .withOption(inputOption).withOption(outputOption).withOption(target)
                .withOption(targetCategories).withOption(predictors).withOption(types)
                .withOption(passesOption).withOption(lambda).withOption(rate).withOption(noBias)
                .withOption(features).create();

        Parser parser = new Parser();
        parser.setHelpOption(help);
        parser.setHelpTrigger("--help");
        parser.setGroup(normalArgs);
        parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
        CommandLine cmdLine = parser.parseAndHelp(args);

        if (cmdLine == null) {
            System.out.println(Arrays.toString(args));
            return false;
        }

        inputFile = getStringArgument(cmdLine, inputOption);
        outputFile = getStringArgument(cmdLine, outputOption);

        List<String> typeList = Lists.newArrayList();
        for (Object x : cmdLine.getValues(types)) {
            typeList.add(x.toString());
        }

        List<String> predictorList = Lists.newArrayList();
        for (Object x : cmdLine.getValues(predictors)) {
            predictorList.add(x.toString());
        }

        lmp = new LogisticModelParameters();
        lmp.setTargetVariable(getStringArgument(cmdLine, target));
        lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories));
        lmp.setNumFeatures(getIntegerArgument(cmdLine, features));
        // NOTE(review): mainToOutput always fills the bias slot; confirm whether this flag has any effect.
        lmp.setUseBias(!getBooleanArgument(cmdLine, noBias));
        lmp.setTypeMap(predictorList, typeList);

        lmp.setLambda(getDoubleArgument(cmdLine, lambda));
        lmp.setLearningRate(getDoubleArgument(cmdLine, rate));

        scores = getBooleanArgument(cmdLine, scoresOption);
        passes = getIntegerArgument(cmdLine, passesOption);

        return true;
    }

    /** Returns the single string value of {@code option}, or null if it is absent. */
    private static String getStringArgument(CommandLine cmdLine, Option option) {
        return (String) cmdLine.getValue(option);
    }

    /** Returns whether the flag {@code option} was given. */
    private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
        return cmdLine.hasOption(option);
    }

    /** Parses the value of {@code option} as an int. */
    private static int getIntegerArgument(CommandLine cmdLine, Option option) {
        return Integer.parseInt((String) cmdLine.getValue(option));
    }

    /** Parses the value of {@code option} as a double. */
    private static double getDoubleArgument(CommandLine cmdLine, Option option) {
        return Double.parseDouble((String) cmdLine.getValue(option));
    }

    /**
     * Returns the model trained by the most recent completed call to
     * {@link #mainToOutput(String[], PrintWriter)}, or null if there was none.
     */
    public static OnlineLogisticRegression getModel() {
        return model;
    }

    /** Returns the parameters parsed by the most recent successful argument parse, or null. */
    public static LogisticModelParameters getParameters() {
        return lmp;
    }

    /**
     * Opens {@code inputFile} as a UTF-8 reader. A classpath resource with that name is used if it
     * exists; otherwise the name is treated as a filesystem path.
     *
     * @throws IOException if neither the resource nor the file can be opened
     */
    static BufferedReader open(String inputFile) throws IOException {
        InputStream in;
        try {
            in = Resources.getResource(inputFile).openStream();
        } catch (IllegalArgumentException e) {
            // Resources.getResource throws IllegalArgumentException when the resource is not found.
            in = new FileInputStream(new File(inputFile));
        }
        return new BufferedReader(new InputStreamReader(in, Charsets.UTF_8));
    }
}