Example usage for org.apache.mahout.classifier.sgd AdaptiveLogisticRegression close

List of usage examples for org.apache.mahout.classifier.sgd AdaptiveLogisticRegression close

Introduction

On this page you can find example usages of the close() method of org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.

Prototype

@Override
    public void close() 

Source Link

Usage

From source file:com.memonews.mahout.sentiment.SentimentModelTrainer.java

License:Apache License

/**
 * Trains an {@link AdaptiveLogisticRegression} on a directory tree of example files and writes
 * the first model of the best learner to disk.
 *
 * <p>{@code args[0]} is the base directory: the name of every sub-directory is interned as a
 * label, and every file inside a sub-directory becomes one training example labelled with its
 * parent directory's name. {@code args[1]}, when present, is the path the model is written to;
 * otherwise {@code "target/model"} is used. After training, the per-word counts collected
 * while encoding are printed in descending order (at most 1001 rows).
 *
 * @param args base directory, optionally followed by the model output path
 * @throws IllegalArgumentException if no base directory is given
 * @throws IOException if the base directory or one of its sub-directories cannot be listed, or
 *         if an I/O error is propagated from the helpers called here
 */
public static void main(final String[] args) throws IOException {
    if (args.length == 0) {
        throw new IllegalArgumentException("usage: SentimentModelTrainer <base-directory> [model-path]");
    }
    final File base = new File(args[0]);
    final String modelPath = args.length > 1 ? args[1] : "target/model";

    final Multiset<String> overallCounts = HashMultiset.create();

    final Dictionary newsGroups = new Dictionary();

    final SentimentModelHelper helper = new SentimentModelHelper();
    helper.getEncoder().setProbes(2);
    final AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(2,
            SentimentModelHelper.FEATURES, new L1());
    learningAlgorithm.setInterval(800);
    learningAlgorithm.setAveragingWindow(500);

    // File.listFiles() returns null (not an empty array) when the path is not a directory
    // or cannot be read, so it must be checked before iterating.
    final File[] labelDirs = base.listFiles();
    if (labelDirs == null) {
        throw new IOException("Cannot list base directory: " + base);
    }

    final List<File> files = Lists.newArrayList();
    for (final File newsgroup : labelDirs) {
        if (newsgroup.isDirectory()) {
            final File[] examples = newsgroup.listFiles();
            if (examples == null) {
                throw new IOException("Cannot list directory: " + newsgroup);
            }
            newsGroups.intern(newsgroup.getName());
            files.addAll(Arrays.asList(examples));
        }
    }
    // Present the examples in random order rather than grouped by label.
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());
    final SGDInfo info = new SGDInfo();

    int k = 0;

    for (final File file : files) {
        // The label of an example is the name of the directory that contains it.
        final String ng = file.getParentFile().getName();
        final int actual = newsGroups.intern(ng);

        final Vector v = helper.encodeFeatureVector(file, overallCounts);
        learningAlgorithm.train(actual, v);

        k++;
        final State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();

        SGDHelper.analyzeState(info, 0, k, best);
    }
    learningAlgorithm.close();
    SGDHelper.dissect(0, newsGroups, learningAlgorithm, files, overallCounts);
    System.out.println("exiting main");

    ModelSerializer.writeBinary(modelPath,
            learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));

    // Print the distinct word counts, largest first, numbered from 0.
    final List<Integer> counts = Lists.newArrayList();
    System.out.printf("Word counts\n");
    for (final String count : overallCounts.elementSet()) {
        counts.add(overallCounts.count(count));
    }
    Collections.sort(counts, Ordering.natural().reverse());
    k = 0;
    for (final Integer count : counts) {
        System.out.printf("%d\t%d\n", k, count);
        k++;
        if (k > 1000) {
            break;
        }
    }
}

From source file:com.tdunning.ch16.train.TrainNewsGroups.java

License:Apache License

/**
 * Trains an {@link AdaptiveLogisticRegression} on a directory tree of example files, printing
 * progress statistics and writing model snapshots along the way, then writes the final model to
 * {@code /tmp/news-group.model}.
 *
 * <p>{@code args[0]} is the base directory: the name of every sub-directory is interned as a
 * label, and every file inside a sub-directory becomes one training example labelled with its
 * parent directory's name. {@code args[1]}, when present, is parsed as an integer leak type
 * that selects the label printed at the end of each progress row.
 *
 * @param args base directory, optionally followed by the integer leak type
 * @throws IllegalArgumentException if no base directory is given
 * @throws IOException if the base directory or one of its sub-directories cannot be listed, or
 *         if an I/O error is propagated from the helpers called here
 */
public static void main(String[] args) throws IOException {
    if (args.length == 0) {
        throw new IllegalArgumentException("usage: TrainNewsGroups <base-directory> [leak-type]");
    }
    File base = new File(args[0]);

    int leakType = 0;
    if (args.length > 1) {
        leakType = Integer.parseInt(args[1]);
    }

    Dictionary newsGroups = new Dictionary();

    encoder.setProbes(2);
    AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1());
    learningAlgorithm.setInterval(800);
    learningAlgorithm.setAveragingWindow(500);

    // File.listFiles() returns null (not an empty array) when the path is not a directory
    // or cannot be read, so it must be checked before sorting or iterating.
    File[] directories = base.listFiles();
    if (directories == null) {
        throw new IOException("Cannot list base directory: " + base);
    }
    // Sort so that labels are interned in a stable order regardless of listing order.
    Arrays.sort(directories, Ordering.usingToString());

    List<File> files = Lists.newArrayList();
    for (File newsgroup : directories) {
        if (newsgroup.isDirectory()) {
            File[] examples = newsgroup.listFiles();
            if (examples == null) {
                throw new IOException("Cannot list directory: " + newsgroup);
            }
            newsGroups.intern(newsgroup.getName());
            files.addAll(Arrays.asList(examples));
        }
    }
    // Present the examples in random order rather than grouped by label.
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());
    System.out.printf("%s\n", Arrays.asList(directories));

    double averageLL = 0;
    double averageCorrect = 0;

    int k = 0;
    // Progress is reported whenever k is a multiple of bump * scale. bump cycles through
    // {1, 2, 5} and scale is a power of ten; step advances by 0.25 per report, so each
    // interval is used for four reports before the next, larger one takes over.
    double step = 0;
    int[] bumps = { 1, 2, 5 };
    for (File file : files) {
        // The label of an example is the name of the directory that contains it.
        String ng = file.getParentFile().getName();
        int actual = newsGroups.intern(ng);

        Vector v = encodeFeatureVector(file);
        learningAlgorithm.train(actual, v);

        k++;

        int bump = bumps[(int) Math.floor(step) % bumps.length];
        int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
        State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();

        // All statistics stay at zero until a best state is available.
        double maxBeta = 0;
        double nonZeros = 0;
        double positive = 0;
        double norm = 0;

        double lambda = 0;
        double mu = 0;

        if (best != null) {
            CrossFoldLearner state = best.getPayload().getLearner();
            averageCorrect = state.percentCorrect();
            averageLL = state.logLikelihood();

            OnlineLogisticRegression model = state.getModels().get(0);
            // finish off pending regularization
            // NOTE: this runs on every example, not only when a row is printed.
            model.close();

            Matrix beta = model.getBeta();
            // Largest absolute coefficient.
            maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
            // Number of coefficients whose magnitude exceeds 1e-6.
            nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
                @Override
                public double apply(double v) {
                    return Math.abs(v) > 1.0e-6 ? 1 : 0;
                }
            });
            // Number of strictly positive coefficients.
            positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
                @Override
                public double apply(double v) {
                    return v > 0 ? 1 : 0;
                }
            });
            // Sum of absolute coefficients.
            norm = beta.aggregate(Functions.PLUS, Functions.ABS);

            double[] mappedParams = best.getMappedParams();
            lambda = mappedParams[0];
            mu = mappedParams[1];
        }
        if (k % (bump * scale) == 0) {
            if (best != null) {
                // Snapshot the current best model, tagged with the example count.
                ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
                        best.getPayload().getLearner().getModels().get(0));
            }

            step += 0.25;
            System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda,
                    mu);
            System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, averageLL, averageCorrect * 100,
                    LEAK_LABELS[leakType % 3]);
        }
    }
    learningAlgorithm.close();
    dissect(newsGroups, learningAlgorithm, files);
    System.out.println("exiting main");

    ModelSerializer.writeBinary("/tmp/news-group.model",
            learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
}

From source file:opennlp.addons.mahout.AdaptiveLogisticRegressionTrainer.java

License:Apache License

/**
 * Trains an {@link AdaptiveLogisticRegression} over the indexed events and wraps the learner
 * of its best state in a {@link VectorClassifierModel}.
 *
 * <p>The learner is sized from the indexer: one category per outcome label and one feature per
 * predicate label. Training makes {@code iterations} passes, printing a progress line after
 * each pass, and the learner is closed before its best state is read.
 *
 * @param indexer source of the outcome labels, predicate labels and training events
 * @return a model built from the best learner, the outcome labels and the predicate map
 * @throws IOException declared by the overridden method
 */
@Override
public MaxentModel doTrain(DataIndexer indexer) throws IOException {

    // TODO: Lets use the predMap here as well for encoding
    final int outcomeCount = indexer.getOutcomeLabels().length;
    final int featureCount = indexer.getPredLabels().length;

    final AdaptiveLogisticRegression learner =
            new AdaptiveLogisticRegression(outcomeCount, featureCount, new L1());

    // TODO: Make these parameters configurable ...
    //  what are good values ?!
    learner.setInterval(800);
    learner.setAveragingWindow(500);

    int pass = 1;
    while (pass <= iterations) {
        trainOnlineLearner(indexer, learner);

        // What should be reported at the end of every iteration ?!
        System.out.println("Iteration " + pass);
        pass++;
    }

    learner.close();

    final CrossFoldLearner bestLearner = learner.getBest().getPayload().getLearner();
    return new VectorClassifierModel(bestLearner, indexer.getOutcomeLabels(), createPrepMap(indexer));
}

From source file:opennlp.addons.mahout.LogisticRegressionTrainer.java

License:Apache License

/**
 * Trains an {@link AdaptiveLogisticRegression} over the indexed events and wraps the learner
 * of its best state in a {@link VectorClassifierModel}.
 *
 * <p>The learner is sized from the indexer: one category per outcome label and one feature per
 * predicate label. Training makes {@code iterations} passes, printing a progress line after
 * each pass, and the learner is closed before its best state is read. The returned model maps
 * each predicate label to its index in the indexer's predicate-label array.
 *
 * @param indexer source of the outcome labels, predicate labels and training events
 * @return a model built from the best learner, the outcome labels and the predicate map
 * @throws IOException declared by the overridden method
 */
@Override
public MaxentModel doTrain(DataIndexer indexer) throws IOException {

    // TODO: Lets use the predMap here as well for encoding

    String[] predLabels = indexer.getPredLabels();
    int cardinality = predLabels.length;

    AdaptiveLogisticRegression pa = new AdaptiveLogisticRegression(indexer.getOutcomeLabels().length,
            cardinality, new L1());

    pa.setInterval(800);
    pa.setAveragingWindow(500);

    // Alternatives that were tried here before and left disabled:
    //  - PassiveAggressive(outcomes, cardinality) with learningRate(10000)
    //  - OnlineLogisticRegression(outcomes, cardinality, new L1()) with alpha(1),
    //    stepOffset(250), decayExponent(0.9), lambda(3.0e-5), learningRate(3000)
    // TODO: Should we do both ?! AdaptiveLogisticRegression ?!

    for (int k = 0; k < iterations; k++) {
        trainOnlineLearner(indexer, pa);

        // What should be reported at the end of every iteration ?!
        System.out.println("Iteration " + (k + 1));
    }

    pa.close();

    // Map every predicate label to its position in the label array.
    Map<String, Integer> predMap = new HashMap<String, Integer>();
    for (int i = 0; i < predLabels.length; i++) {
        predMap.put(predLabels[i], i);
    }

    return new VectorClassifierModel(pa.getBest().getPayload().getLearner(), indexer.getOutcomeLabels(),
            predMap);
}