Example usage for the org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression constructor

List of usage examples for the org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression constructor

Introduction

On this page you can find example usages of the org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression constructor.

Prototype

public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) 

Source Link

Document

Uses DEFAULT_THREAD_COUNT and DEFAULT_POOL_SIZE for the thread count and pool size.

Usage

From source file:com.memonews.mahout.sentiment.SentimentModelTrainer.java

License:Apache License

/**
 * Trains a two-category model from a directory tree and writes the best model to disk.
 *
 * <p>{@code args[0]} is the base directory: each of its sub-directories is interned as one
 * category and the files inside it become training documents. {@code args[1]}, when given, is
 * the path the model is written to; otherwise {@code "target/model"} is used. After training,
 * the word counts collected during encoding are printed in descending order (at most 1001 rows).
 *
 * @param args command-line arguments as described above
 * @throws IOException if the base directory or one of its category directories cannot be
 *         listed, or if one of the helper calls that read documents or write the model fails
 */
public static void main(final String[] args) throws IOException {
    final File base = new File(args[0]);
    final String modelPath = args.length > 1 ? args[1] : "target/model";

    final Multiset<String> overallCounts = HashMultiset.create();

    final Dictionary newsGroups = new Dictionary();

    final SentimentModelHelper helper = new SentimentModelHelper();
    helper.getEncoder().setProbes(2);
    final AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(2,
            SentimentModelHelper.FEATURES, new L1());
    learningAlgorithm.setInterval(800);
    learningAlgorithm.setAveragingWindow(500);

    // File.listFiles() returns null (not an empty array) when the path is not a directory or
    // cannot be read; report that explicitly instead of failing with a NullPointerException.
    final File[] categoryDirs = base.listFiles();
    if (categoryDirs == null) {
        throw new IOException("Cannot list training directory: " + base);
    }

    final List<File> files = Lists.newArrayList();
    for (final File newsgroup : categoryDirs) {
        if (newsgroup.isDirectory()) {
            newsGroups.intern(newsgroup.getName());
            final File[] documents = newsgroup.listFiles();
            if (documents == null) {
                throw new IOException("Cannot list category directory: " + newsgroup);
            }
            files.addAll(Arrays.asList(documents));
        }
    }
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());
    final SGDInfo info = new SGDInfo();

    int k = 0;

    for (final File file : files) {
        // The category label is the name of the directory that contains the document.
        final String ng = file.getParentFile().getName();
        final int actual = newsGroups.intern(ng);

        final Vector v = helper.encodeFeatureVector(file, overallCounts);
        learningAlgorithm.train(actual, v);

        k++;
        final State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();

        SGDHelper.analyzeState(info, 0, k, best);
    }
    learningAlgorithm.close();
    SGDHelper.dissect(0, newsGroups, learningAlgorithm, files, overallCounts);
    System.out.println("exiting main");

    ModelSerializer.writeBinary(modelPath,
            learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));

    // Print the per-word counts, largest first, as "rank<TAB>count".
    final List<Integer> counts = Lists.newArrayList();
    System.out.printf("Word counts\n");
    for (final String count : overallCounts.elementSet()) {
        counts.add(overallCounts.count(count));
    }
    Collections.sort(counts, Ordering.natural().reverse());
    k = 0;
    for (final Integer count : counts) {
        System.out.printf("%d\t%d\n", k, count);
        k++;
        if (k > 1000) {
            break;
        }
    }
}

From source file:com.ml.ira.algos.AdaptiveLogisticModelParameters.java

License:Apache License

/**
 * Returns the learner cached in {@code alr}, building and configuring it on the first call.
 *
 * <p>On first use the learner is constructed from the target-category count, the feature count
 * and the configured prior, and then given the interval, averaging window, thread count and AUC
 * evaluator held by this object. Later calls return the same instance unchanged.
 *
 * @return the cached, configured learner
 */
public AdaptiveLogisticRegression createAdaptiveLogisticRegression() {
    // Already built: hand back the cached instance untouched.
    if (alr != null) {
        return alr;
    }

    alr = new AdaptiveLogisticRegression(getMaxTargetCategories(), getNumFeatures(),
            createPrior(prior, priorOption));

    alr.setInterval(interval);
    alr.setAveragingWindow(averageWindow);
    alr.setThreadCount(threads);
    alr.setAucEvaluator(createAUC(auc));

    return alr;
}

From source file:com.tdunning.ch16.train.TrainNewsGroups.java

License:Apache License

/**
 * Trains a 20-category newsgroup classifier and periodically reports and snapshots progress.
 *
 * <p>{@code args[0]} is the base directory whose sub-directories are the categories.
 * {@code args[1]}, when given, is parsed as an integer leak type; in this method it only
 * selects the label printed at the end of each progress line. Intermediate models are written
 * to {@code /tmp/news-group-<k>.model} and the final one to {@code /tmp/news-group.model}.
 *
 * @param args command-line arguments as described above
 * @throws IOException if the base directory or one of its category directories cannot be
 *         listed, or if one of the helper calls that read documents or write models fails
 */
public static void main(String[] args) throws IOException {
    File base = new File(args[0]);

    int leakType = 0;
    if (args.length > 1) {
        leakType = Integer.parseInt(args[1]);
    }

    Dictionary newsGroups = new Dictionary();

    encoder.setProbes(2);
    AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1());
    learningAlgorithm.setInterval(800);
    learningAlgorithm.setAveragingWindow(500);

    List<File> files = Lists.newArrayList();
    // File.listFiles() returns null when the path is not a directory or cannot be read; fail
    // with a clear message instead of a NullPointerException inside Arrays.sort.
    File[] directories = base.listFiles();
    if (directories == null) {
        throw new IOException("Cannot list training directory: " + base);
    }
    Arrays.sort(directories, Ordering.usingToString());
    for (File newsgroup : directories) {
        if (newsgroup.isDirectory()) {
            newsGroups.intern(newsgroup.getName());
            File[] documents = newsgroup.listFiles();
            if (documents == null) {
                throw new IOException("Cannot list category directory: " + newsgroup);
            }
            files.addAll(Arrays.asList(documents));
        }
    }
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());
    System.out.printf("%s\n", Arrays.asList(directories));

    // Element-wise indicators summed over the coefficient matrix; they are stateless, so one
    // instance of each is shared by every iteration.
    DoubleFunction nonZeroIndicator = new DoubleFunction() {
        @Override
        public double apply(double x) {
            return Math.abs(x) > 1.0e-6 ? 1 : 0;
        }
    };
    DoubleFunction positiveIndicator = new DoubleFunction() {
        @Override
        public double apply(double x) {
            return x > 0 ? 1 : 0;
        }
    };

    double averageLL = 0;
    double averageCorrect = 0;

    int k = 0;
    double step = 0;
    int[] bumps = { 1, 2, 5 };
    for (File file : files) {
        // The category label is the name of the directory that contains the document.
        String ng = file.getParentFile().getName();
        int actual = newsGroups.intern(ng);

        Vector v = encodeFeatureVector(file);
        learningAlgorithm.train(actual, v);

        k++;

        // Reporting period is bump * scale; step advances by 0.25 after every report.
        int bump = bumps[(int) Math.floor(step) % bumps.length];
        int scale = (int) Math.pow(10, Math.floor(step / bumps.length));

        // Read the current best state once and use that same snapshot for the statistics and
        // for the model written below.
        State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
        double maxBeta = 0;
        double nonZeros = 0;
        double positive = 0;
        double norm = 0;

        double lambda = 0;
        double mu = 0;

        if (best != null) {
            CrossFoldLearner state = best.getPayload().getLearner();
            averageCorrect = state.percentCorrect();
            averageLL = state.logLikelihood();

            OnlineLogisticRegression model = state.getModels().get(0);
            // finish off pending regularization
            model.close();

            Matrix beta = model.getBeta();
            maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
            nonZeros = beta.aggregate(Functions.PLUS, nonZeroIndicator);
            positive = beta.aggregate(Functions.PLUS, positiveIndicator);
            norm = beta.aggregate(Functions.PLUS, Functions.ABS);

            lambda = best.getMappedParams()[0];
            mu = best.getMappedParams()[1];
        }
        if (k % (bump * scale) == 0) {
            if (best != null) {
                ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
                        best.getPayload().getLearner().getModels().get(0));
            }

            step += 0.25;
            System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda,
                    mu);
            System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, averageLL, averageCorrect * 100,
                    LEAK_LABELS[leakType % 3]);
        }
    }
    learningAlgorithm.close();
    dissect(newsGroups, learningAlgorithm, files);
    System.out.println("exiting main");

    ModelSerializer.writeBinary("/tmp/news-group.model",
            learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
}

From source file:gov.llnl.ontology.mains.ExtendWordNet.java

License:Open Source License

/**
 * Trains the hypernym predictor model based on the evidence gathered for
 * positive and negative relationships.  Training of the cousin predictor is
 * currently disabled (see the commented-out section below), so the result
 * reflects the hypernym predictor only.
 *
 * @param numPasses the number of times the known positive and known negative
 *        evidence is fed to the trainer
 * @return {@code true} if a best hypernym model was obtained from the
 *         trainer, {@code false} otherwise
 */
public boolean trainModel(int numPasses) {
    // Train the hypernym predictor.
    AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2, basis.numDimensions(), new L1());
    for (int i = 0; i < numPasses; ++i) {
        trainHypernyms(knownPositives.map(), model, 1);
        trainHypernyms(knownNegatives.map(), model, 0);
    }

    // Close the trainer before asking for the best model: the adaptive
    // learner buffers examples and only trains on them periodically, so
    // without close() a small training set can leave getBest() null even
    // though evidence was supplied.  Closing also releases the trainer's
    // worker threads.
    model.close();

    // Get the best predictor for hypernyms from the trainer.  If no trainer
    // could be found, return false.
    State<Wrapper, CrossFoldLearner> best = model.getBest();
    if (best == null) {
        return false;
    }
    hypernymPredictor = best.getPayload().getLearner().getModels().get(0);

    /*
    // Train the cousin predictor using the similarity scores.
    model = new AdaptiveLogisticRegression(
        2, numSimilarityScores, new L1());
    for (int i = 0; i < numPasses; ++i) {
    trainCousins(knownPositives.map(), model);
    trainCousins(knownNegatives.map(), model);
    }

    // Get the best cousin predictor model from the trainer.  If no trainer
    // could be found, return false.
    best = model.getBest();
    if (best == null)
    return false;
    cousinPredictor = best.getPayload().getLearner().getModels().get(0);
    */

    return true;
}

From source file:gov.llnl.ontology.mains.TrainLogisticRegression.java

License:Open Source License

/**
 * Trains a two-class logistic regression model over dependency-path counts read from an
 * evidence table and saves the best model to the single positional output argument.
 *
 * <p>Rows whose hypernym status is {@code TERMS_MISSING}, {@code NOVEL_HYPONYM} or
 * {@code NOVEL_HYPERNYM} are skipped; rows with status {@code KNOWN_HYPERNYM} are trained as
 * class 1 and all remaining rows as class 0. The process exits with status 1 when the required
 * options are missing or when no model could be learned.
 *
 * @param args command-line options followed by the output file
 * @throws Exception if option parsing, table access, basis loading or saving fails
 */
public static void main(String[] args) throws Exception {
    MRArgOptions options = new MRArgOptions();
    options.addOption('w', "wordnetDir", "Specifies the wordnet directory", true, "PATH", "Required");
    options.addOption('b', "basisMapping", "Specifies a serialzied basis mapping", true, "FILE", "Required");
    options.addOption('n', "numPasses", "Specifies the number of training passes to make. " + "(Default: 5)",
            true, "INT", "Optional");
    options.parseOptions(args);

    if (options.numPositionalArgs() != 1 || !options.hasOption('w') || !options.hasOption('b')) {
        System.out.println("usage: java TrainLogisticRegression [OPTIONS] <out>\n" + options.prettyPrint());
        System.exit(1);
    }

    EvidenceTable table = options.evidenceTable();
    Scan scan = new Scan();
    table.setupScan(scan, options.sourceCorpus());

    StringBasisMapping basis = SerializableUtil.load(new File(options.getStringOption('b')));
    basis.setReadOnly(true);
    int numDimensions = basis.numDimensions();

    AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2, numDimensions, new L1());

    // -n is declared optional with a documented default of 5, so honour that default
    // instead of requiring the option to be present.
    int numPasses = options.hasOption('n') ? options.getIntOption('n') : 5;
    for (int i = 0; i < numPasses; ++i) {
        Iterator<Result> resultIter = table.iterator(scan);
        while (resultIter.hasNext()) {
            Result row = resultIter.next();

            HypernymStatus status = table.getHypernymStatus(row);
            if (status == HypernymStatus.TERMS_MISSING || status == HypernymStatus.NOVEL_HYPONYM
                    || status == HypernymStatus.NOVEL_HYPERNYM) {
                continue;
            }

            // Extract a CompactSparse vector representing the number of
            // dependency paths.
            SparseDoubleVector vector = new CompactSparseVector(numDimensions);
            Counter<String> pathCounts = table.getDependencyPaths(row);
            for (Map.Entry<String, Integer> entry : pathCounts) {
                int dimension = basis.getDimension(entry.getKey());
                // Paths with a negative dimension are not in the (read-only) basis; skip them.
                if (dimension >= 0) {
                    vector.set(dimension, entry.getValue());
                }
            }

            int classLabel = (status == HypernymStatus.KNOWN_HYPERNYM) ? 1 : 0;
            model.train(classLabel, new MahoutSparseVector(vector, numDimensions));
        }
    }

    // Close the trainer before asking for the best model: the adaptive learner buffers
    // examples and only trains on them periodically, so without close() the examples still
    // in the buffer are never used and getBest() can be null for small inputs. Closing also
    // releases the trainer's worker threads.
    model.close();

    State<Wrapper, CrossFoldLearner> best = model.getBest();
    if (best == null) {
        System.err.println("The Learner could not be learned");
        System.exit(1);
    }

    OnlineLearner classifier = best.getPayload().getLearner().getModels().get(0);
    SerializableUtil.save(classifier, new File(options.getPositionalArg(0)));
}

From source file:opennlp.addons.mahout.AdaptiveLogisticRegressionTrainer.java

License:Apache License

/**
 * Trains an adaptive logistic regression learner on the indexed events and wraps the best
 * cross-fold learner in a {@code VectorClassifierModel}.
 *
 * <p>The learner is sized from the indexer's outcome and predicate label counts, trained for
 * {@code iterations} passes (printing the pass number after each one), then closed before the
 * best learner is read.
 *
 * @param indexer source of outcome labels, predicate labels and training events
 * @return the model built from the best learner, the outcome labels and the predicate map
 * @throws IOException declared by the overridden method
 */
@Override
public MaxentModel doTrain(DataIndexer indexer) throws IOException {

    // TODO: Lets use the predMap here as well for encoding
    final int outcomeCount = indexer.getOutcomeLabels().length;
    final int featureCount = indexer.getPredLabels().length;

    final AdaptiveLogisticRegression learner =
            new AdaptiveLogisticRegression(outcomeCount, featureCount, new L1());

    // TODO: Make these parameters configurable ...
    //  what are good values ?!
    learner.setInterval(800);
    learner.setAveragingWindow(500);

    int pass = 1;
    while (pass <= iterations) {
        trainOnlineLearner(indexer, learner);

        // What should be reported at the end of every iteration ?!
        System.out.println("Iteration " + pass);
        pass++;
    }

    learner.close();

    return new VectorClassifierModel(
            learner.getBest().getPayload().getLearner(),
            indexer.getOutcomeLabels(),
            createPrepMap(indexer));
}

From source file:opennlp.addons.mahout.LogisticRegressionTrainer.java

License:Apache License

/**
 * Trains an adaptive logistic regression learner on the indexed events and wraps the best
 * cross-fold learner in a {@code VectorClassifierModel}.
 *
 * <p>The learner is sized from the indexer's outcome and predicate label counts, trained for
 * {@code iterations} passes (printing the pass number after each one), then closed before the
 * best learner is read. The predicate map handed to the model maps each predicate label to
 * its index in {@code indexer.getPredLabels()}.
 *
 * @param indexer source of outcome labels, predicate labels and training events
 * @return the model built from the best learner, the outcome labels and the predicate map
 * @throws IOException declared by the overridden method
 */
@Override
public MaxentModel doTrain(DataIndexer indexer) throws IOException {

    // TODO: Lets use the predMap here as well for encoding

    int cardinality = indexer.getPredLabels().length;

    AdaptiveLogisticRegression pa = new AdaptiveLogisticRegression(indexer.getOutcomeLabels().length,
            cardinality, new L1());

    pa.setInterval(800);
    pa.setAveragingWindow(500);

    // Alternatives previously tried in place of the adaptive learner:
    //  - PassiveAggressive(outcomes, cardinality) with learningRate(10000)
    //  - OnlineLogisticRegression(outcomes, cardinality, new L1()) with alpha(1),
    //    stepOffset(250), decayExponent(0.9), lambda(3.0e-5), learningRate(3000)
    // TODO: Should we do both ?! AdaptiveLogisticRegression ?!

    for (int k = 0; k < iterations; k++) {
        trainOnlineLearner(indexer, pa);

        // What should be reported at the end of every iteration ?!
        System.out.println("Iteration " + (k + 1));
    }

    pa.close();

    // Map every predicate label to its position in the indexer's label array.
    Map<String, Integer> predMap = new HashMap<String, Integer>();

    String[] predLabels = indexer.getPredLabels();
    for (int i = 0; i < predLabels.length; i++) {
        predMap.put(predLabels[i], i);
    }

    return new VectorClassifierModel(pa.getBest().getPayload().getLearner(), indexer.getOutcomeLabels(),
            predMap);
}