Example usage for org.apache.mahout.classifier.sgd AdaptiveLogisticRegression train

List of usage examples for org.apache.mahout.classifier.sgd AdaptiveLogisticRegression train

Introduction

In this page you can find the example usage for org.apache.mahout.classifier.sgd AdaptiveLogisticRegression train.

Prototype

@Override
    public void train(int actual, Vector instance) 

Source Link

Usage

From source file:com.memonews.mahout.sentiment.SentimentModelTrainer.java

License:Apache License

/**
 * Trains an adaptive logistic-regression sentiment classifier from a directory
 * of training files (one sub-directory per class label) and serializes the best
 * learned model to disk, then prints a descending word-count report.
 *
 * <p>Usage: {@code SentimentModelTrainer <trainingDir> [modelPath]}
 *
 * @param args args[0] = base directory whose sub-directories name the classes
 *             and contain the training documents; args[1] (optional) = output
 *             path for the serialized model, defaulting to "target/model"
 * @throws IOException if the base directory is unreadable, a training file
 *                     cannot be read, or the model cannot be written
 */
public static void main(final String[] args) throws IOException {
    final File base = new File(args[0]);
    final String modelPath = args.length > 1 ? args[1] : "target/model";

    // Global token frequencies across all documents, used both during feature
    // encoding and for the word-count report printed at the end.
    final Multiset<String> overallCounts = HashMultiset.create();

    // Maps class (sub-directory) names to contiguous integer ids.
    final Dictionary newsGroups = new Dictionary();

    final SentimentModelHelper helper = new SentimentModelHelper();
    helper.getEncoder().setProbes(2);
    // 2 target categories, L1 prior to encourage sparse coefficients.
    final AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(2,
            SentimentModelHelper.FEATURES, new L1());
    learningAlgorithm.setInterval(800);
    learningAlgorithm.setAveragingWindow(500);

    // listFiles() returns null when the path is not a readable directory;
    // fail fast with a clear message rather than an opaque NPE below.
    final File[] categories = base.listFiles();
    if (categories == null) {
        throw new IOException("Not a readable directory: " + base);
    }

    final List<File> files = Lists.newArrayList();
    for (final File newsgroup : categories) {
        if (newsgroup.isDirectory()) {
            newsGroups.intern(newsgroup.getName());
            final File[] children = newsgroup.listFiles();
            if (children != null) {
                files.addAll(Arrays.asList(children));
            }
        }
    }
    // Shuffle so the online learner does not see all of one class first.
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());
    final SGDInfo info = new SGDInfo();

    int k = 0;

    for (final File file : files) {
        // The parent directory name is the document's class label.
        final String ng = file.getParentFile().getName();
        final int actual = newsGroups.intern(ng);

        final Vector v = helper.encodeFeatureVector(file, overallCounts);
        learningAlgorithm.train(actual, v);

        k++;
        final State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();

        // Periodic progress reporting on the best learner found so far.
        SGDHelper.analyzeState(info, 0, k, best);
    }
    learningAlgorithm.close();
    SGDHelper.dissect(0, newsGroups, learningAlgorithm, files, overallCounts);
    System.out.println("exiting main");

    // Persist only the first model of the best cross-fold learner.
    ModelSerializer.writeBinary(modelPath,
            learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));

    // Report word frequencies in descending order (roughly the top 1000).
    final List<Integer> counts = Lists.newArrayList();
    System.out.printf("Word counts\n");
    for (final String count : overallCounts.elementSet()) {
        counts.add(overallCounts.count(count));
    }
    Collections.sort(counts, Ordering.natural().reverse());
    k = 0;
    for (final Integer count : counts) {
        System.out.printf("%d\t%d\n", k, count);
        k++;
        if (k > 1000) {
            break;
        }
    }
}

From source file:com.tdunning.ch16.train.TrainNewsGroups.java

License:Apache License

/**
 * Trains a 20-class adaptive logistic-regression model over the 20-newsgroups
 * corpus, periodically reporting coefficient statistics and checkpointing the
 * best model, then serializes the final best model to /tmp/news-group.model.
 *
 * <p>Usage: {@code TrainNewsGroups <corpusDir> [leakType]}
 *
 * @param args args[0] = base directory with one sub-directory per newsgroup;
 *             args[1] (optional) = integer "leak type" used only to label the
 *             progress output via LEAK_LABELS
 * @throws IOException if training files cannot be read or the model cannot be written
 */
public static void main(String[] args) throws IOException {
    File base = new File(args[0]);

    int leakType = 0;
    if (args.length > 1) {
        leakType = Integer.parseInt(args[1]);
    }

    // Maps newsgroup directory names to contiguous integer class ids.
    Dictionary newsGroups = new Dictionary();

    // 'encoder' is a field defined elsewhere in this class — presumably the
    // shared feature encoder used by encodeFeatureVector; confirm in full source.
    encoder.setProbes(2);
    AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1());
    learningAlgorithm.setInterval(800);
    learningAlgorithm.setAveragingWindow(500);

    List<File> files = Lists.newArrayList();
    // NOTE(review): listFiles() returns null if 'base' is not a readable
    // directory, which would make Arrays.sort throw an NPE — worth a guard.
    File[] directories = base.listFiles();
    Arrays.sort(directories, Ordering.usingToString());
    for (File newsgroup : directories) {
        if (newsgroup.isDirectory()) {
            newsGroups.intern(newsgroup.getName());
            files.addAll(Arrays.asList(newsgroup.listFiles()));
        }
    }
    // Shuffle so the online learner does not see all of one class in a row.
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());
    System.out.printf("%s\n", Arrays.asList(directories));

    double averageLL = 0;
    double averageCorrect = 0;

    int k = 0;
    // 'step', 'bumps', and 'scale' implement a 1-2-5 logarithmic reporting
    // schedule: status lines are printed at k = 1, 2, 5, 10, 20, 50, 100, ...
    double step = 0;
    int[] bumps = { 1, 2, 5 };
    for (File file : files) {
        // The parent directory name is the training example's class label.
        String ng = file.getParentFile().getName();
        int actual = newsGroups.intern(ng);

        Vector v = encodeFeatureVector(file);
        learningAlgorithm.train(actual, v);

        k++;

        int bump = bumps[(int) Math.floor(step) % bumps.length];
        int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
        State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
        // Coefficient statistics for the current best model, all 0 until a
        // best learner exists.
        double maxBeta;
        double nonZeros;
        double positive;
        double norm;

        double lambda = 0;
        double mu = 0;

        if (best != null) {
            CrossFoldLearner state = best.getPayload().getLearner();
            averageCorrect = state.percentCorrect();
            averageLL = state.logLikelihood();

            OnlineLogisticRegression model = state.getModels().get(0);
            // finish off pending regularization
            model.close();

            Matrix beta = model.getBeta();
            // Largest absolute coefficient.
            maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
            // Count of coefficients that are effectively non-zero.
            nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
                @Override
                public double apply(double v) {
                    return Math.abs(v) > 1.0e-6 ? 1 : 0;
                }
            });
            // Count of strictly positive coefficients.
            positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
                @Override
                public double apply(double v) {
                    return v > 0 ? 1 : 0;
                }
            });
            // L1 norm of the coefficient matrix.
            norm = beta.aggregate(Functions.PLUS, Functions.ABS);

            // Hyper-parameters of the best learner found by the evolutionary search.
            lambda = learningAlgorithm.getBest().getMappedParams()[0];
            mu = learningAlgorithm.getBest().getMappedParams()[1];
        } else {
            maxBeta = 0;
            nonZeros = 0;
            positive = 0;
            norm = 0;
        }
        // Report and checkpoint on the 1-2-5 schedule; advancing 'step' by
        // 0.25 here (4 sub-steps per bump) controls the schedule's pace.
        if (k % (bump * scale) == 0) {
            if (learningAlgorithm.getBest() != null) {
                ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
                        learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
            }

            step += 0.25;
            System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda,
                    mu);
            System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, averageLL, averageCorrect * 100,
                    LEAK_LABELS[leakType % 3]);
        }
    }
    learningAlgorithm.close();
    dissect(newsGroups, learningAlgorithm, files);
    System.out.println("exiting main");

    // Persist the first model of the best cross-fold learner.
    ModelSerializer.writeBinary("/tmp/news-group.model",
            learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
}

From source file:gov.llnl.ontology.mains.TrainLogisticRegression.java

License:Open Source License

/**
 * Trains a binary adaptive logistic-regression hypernym classifier from
 * dependency-path evidence stored in an HBase-style evidence table and saves
 * the best learned model to the given output file.
 *
 * <p>Usage: {@code java TrainLogisticRegression [OPTIONS] <out>} with required
 * options {@code -w <wordnetDir>} and {@code -b <basisMappingFile>}.
 *
 * @param args command-line arguments parsed by {@link MRArgOptions}
 * @throws Exception if option parsing, table scanning, or serialization fails
 */
public static void main(String[] args) throws Exception {
    MRArgOptions options = new MRArgOptions();
    options.addOption('w', "wordnetDir", "Specifies the wordnet directory", true, "PATH", "Required");
    // Fixed typo in the user-facing help text: "serialzied" -> "serialized".
    options.addOption('b', "basisMapping", "Specifies a serialized basis mapping", true, "FILE", "Required");
    options.addOption('n', "numPasses", "Specifies the number of training passes to make. " + "(Default: 5)",
            true, "INT", "Optional");
    options.parseOptions(args);

    if (options.numPositionalArgs() != 1 || !options.hasOption('w') || !options.hasOption('b')) {
        System.out.println("usage: java TrainLogisticRegression [OPTIONS] <out>\n" + options.prettyPrint());
        System.exit(1);
    }

    EvidenceTable table = options.evidenceTable();
    Scan scan = new Scan();
    table.setupScan(scan, options.sourceCorpus());

    // The basis mapping fixes the feature space; read-only so unseen paths map
    // to no dimension instead of growing the basis during training.
    StringBasisMapping basis = SerializableUtil.load(new File(options.getStringOption('b')));
    basis.setReadOnly(true);
    int numDimensions = basis.numDimensions();

    // Binary classifier (hypernym vs. not) with an L1 prior for sparsity.
    AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2, numDimensions, new L1());

    // NOTE(review): the help text advertises a default of 5 passes, but no
    // default is supplied here — confirm MRArgOptions.getIntOption's behavior
    // when -n is absent (it may throw or return 0, disabling training).
    int numPasses = options.getIntOption('n');
    for (int i = 0; i < numPasses; ++i) {
        Iterator<Result> resultIter = table.iterator(scan);
        while (resultIter.hasNext()) {
            Result row = resultIter.next();

            // Skip rows that cannot serve as labeled training examples.
            HypernymStatus status = table.getHypernymStatus(row);
            if (status == HypernymStatus.TERMS_MISSING || status == HypernymStatus.NOVEL_HYPONYM
                    || status == HypernymStatus.NOVEL_HYPERNYM)
                continue;

            // Extract a CompactSparse vector representing the number of
            // dependency paths.
            SparseDoubleVector vector = new CompactSparseVector(numDimensions);
            Counter<String> pathCounts = table.getDependencyPaths(row);
            for (Map.Entry<String, Integer> entry : pathCounts) {
                int dimension = basis.getDimension(entry.getKey());
                // Paths absent from the read-only basis yield -1 and are dropped.
                if (dimension >= 0)
                    vector.set(dimension, entry.getValue());
            }

            int classLabel = (status == HypernymStatus.KNOWN_HYPERNYM) ? 1 : 0;
            model.train(classLabel, new MahoutSparseVector(vector, numDimensions));
        }
    }

    State<Wrapper, CrossFoldLearner> best = model.getBest();
    if (best == null) {
        System.err.println("The Learner could not be learned");
        System.exit(1);
        return; // unreachable, but makes the non-null flow explicit to readers
    }

    // Persist the first model of the best cross-fold learner.
    OnlineLearner classifier = best.getPayload().getLearner().getModels().get(0);
    SerializableUtil.save(classifier, new File(options.getPositionalArg(0)));
}