Example usage for org.apache.mahout.classifier.sgd AdaptiveLogisticRegression getBest

List of usage examples for org.apache.mahout.classifier.sgd AdaptiveLogisticRegression getBest

Introduction

In this page you can find the example usage for org.apache.mahout.classifier.sgd AdaptiveLogisticRegression getBest.

Prototype

public State<Wrapper, CrossFoldLearner> getBest() 

Source Link

Usage

From source file:com.memonews.mahout.sentiment.SentimentModelTrainer.java

License:Apache License

/**
 * Trains an AdaptiveLogisticRegression sentiment model from a directory of
 * labelled documents (one immediate sub-directory per category) and writes the
 * first model of the best cross-fold learner to disk.
 *
 * @param args args[0] = base directory of training data; args[1] (optional) =
 *             output model path, defaults to "target/model"
 * @throws IOException if the training directory is unreadable, a training file
 *                     cannot be read, or the model cannot be written
 */
public static void main(final String[] args) throws IOException {
    final File base = new File(args[0]);
    final String modelPath = args.length > 1 ? args[1] : "target/model";

    // Accumulates global word counts across all documents (reported at the end).
    final Multiset<String> overallCounts = HashMultiset.create();

    final Dictionary newsGroups = new Dictionary();

    final SentimentModelHelper helper = new SentimentModelHelper();
    helper.getEncoder().setProbes(2);
    final AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(2,
            SentimentModelHelper.FEATURES, new L1());
    learningAlgorithm.setInterval(800);
    learningAlgorithm.setAveragingWindow(500);

    // Collect all training files; each sub-directory name is a category label.
    final List<File> files = Lists.newArrayList();
    final File[] categories = base.listFiles();
    if (categories == null) {
        // File.listFiles() returns null when the path is not a readable directory.
        throw new IOException("Not a readable directory: " + base);
    }
    for (final File newsgroup : categories) {
        if (newsgroup.isDirectory()) {
            newsGroups.intern(newsgroup.getName());
            final File[] samples = newsgroup.listFiles();
            if (samples != null) {
                files.addAll(Arrays.asList(samples));
            }
        }
    }
    // Shuffle so the online learner does not see all of one category in a row.
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());
    final SGDInfo info = new SGDInfo();

    int k = 0;

    for (final File file : files) {
        // The parent directory name is the target category for this document.
        final String ng = file.getParentFile().getName();
        final int actual = newsGroups.intern(ng);

        final Vector v = helper.encodeFeatureVector(file, overallCounts);
        learningAlgorithm.train(actual, v);

        k++;
        final State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();

        SGDHelper.analyzeState(info, 0, k, best);
    }
    learningAlgorithm.close();
    SGDHelper.dissect(0, newsGroups, learningAlgorithm, files, overallCounts);
    System.out.println("exiting main");

    ModelSerializer.writeBinary(modelPath,
            learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));

    // Report the 1001 highest word counts, largest first.
    final List<Integer> counts = Lists.newArrayList();
    System.out.printf("Word counts\n");
    for (final String count : overallCounts.elementSet()) {
        counts.add(overallCounts.count(count));
    }
    Collections.sort(counts, Ordering.natural().reverse());
    k = 0;
    for (final Integer count : counts) {
        System.out.printf("%d\t%d\n", k, count);
        k++;
        if (k > 1000) {
            break;
        }
    }
}

From source file:com.memonews.mahout.sentiment.SGDHelper.java

License:Apache License

/**
 * Prints a dissection of the best model found so far: the 100 most influential
 * features with their weights and most-impacted categories.
 *
 * @param leakType          unused in this method; kept in the signature for callers
 * @param dictionary        maps category names to the integer ids used in training
 * @param learningAlgorithm trained regression whose best learner is dissected
 * @param files             training files; a random 500-element sample is re-encoded
 * @param overallCounts     word-frequency accumulator passed through to the encoder
 * @throws IOException if a sampled file cannot be read
 */
public static void dissect(final int leakType, final Dictionary dictionary,
        final AdaptiveLogisticRegression learningAlgorithm, final Iterable<File> files,
        final Multiset<String> overallCounts) throws IOException {
    // NOTE(review): getBest() can return null when too little training has
    // happened (other usages in this file guard for it) — confirm callers
    // only invoke dissect() after training, else this line NPEs.
    final CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
    model.close(); // finish off pending regularization before scoring

    final Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
    final ModelDissector md = new ModelDissector();

    final SentimentModelHelper helper = new SentimentModelHelper();
    helper.getEncoder().setTraceDictionary(traceDictionary);
    helper.getBias().setTraceDictionary(traceDictionary);

    // Re-encode a random sample of 500 documents so the dissector can map
    // hashed feature positions back to human-readable feature names.
    for (final File file : permute(files, helper.getRandom()).subList(0, 500)) {
        traceDictionary.clear();
        final Vector v = helper.encodeFeatureVector(file, overallCounts);
        md.update(v, traceDictionary, model);
    }

    final List<String> ngNames = Lists.newArrayList(dictionary.values());
    final List<ModelDissector.Weight> weights = md.summary(100);
    System.out.println("============");
    System.out.println("Model Dissection");
    for (final ModelDissector.Weight w : weights) {
        // NOTE(review): the 4th format conversion is %.1f but the 4th argument
        // is w.getCategory(0); if getCategory returns an int/String this throws
        // IllegalFormatConversionException at runtime — verify the intended
        // pairing is (weight -> %.1f, category -> %s) and not the reverse.
        System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\n", w.getFeature(), w.getWeight(),
                ngNames.get(w.getMaxImpact()), w.getCategory(0), w.getWeight(0));
    }
}

From source file:com.tdunning.ch16.train.TrainNewsGroups.java

License:Apache License

/**
 * Trains a 20-way newsgroup classifier with adaptive logistic regression,
 * reporting model statistics on a roughly logarithmic schedule and writing
 * intermediate model snapshots to /tmp.
 *
 * @param args args[0] = base directory with one sub-directory per newsgroup;
 *             args[1] (optional) = leak-type index used to label output rows
 * @throws IOException if training data cannot be read or a model cannot be written
 */
public static void main(String[] args) throws IOException {
    File base = new File(args[0]);

    int leakType = 0;
    if (args.length > 1) {
        leakType = Integer.parseInt(args[1]);
    }

    Dictionary newsGroups = new Dictionary();

    encoder.setProbes(2);
    AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1());
    learningAlgorithm.setInterval(800);
    learningAlgorithm.setAveragingWindow(500);

    // Collect training files; each sub-directory name is a newsgroup label.
    List<File> files = Lists.newArrayList();
    File[] directories = base.listFiles();
    if (directories == null) {
        // File.listFiles() returns null when base is not a readable directory;
        // without this check Arrays.sort below throws an uninformative NPE.
        throw new IOException("Not a readable directory: " + base);
    }
    Arrays.sort(directories, Ordering.usingToString());
    for (File newsgroup : directories) {
        if (newsgroup.isDirectory()) {
            newsGroups.intern(newsgroup.getName());
            File[] samples = newsgroup.listFiles();
            if (samples != null) {
                files.addAll(Arrays.asList(samples));
            }
        }
    }
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());
    System.out.printf("%s\n", Arrays.asList(directories));

    double averageLL = 0;
    double averageCorrect = 0;

    int k = 0;
    // step/bumps produce reporting intervals of 1, 2, 5, 10, 20, 50, 100, ...
    double step = 0;
    int[] bumps = { 1, 2, 5 };
    for (File file : files) {
        // The parent directory name is the target newsgroup for this document.
        String ng = file.getParentFile().getName();
        int actual = newsGroups.intern(ng);

        Vector v = encodeFeatureVector(file);
        learningAlgorithm.train(actual, v);

        k++;

        int bump = bumps[(int) Math.floor(step) % bumps.length];
        int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
        // Fetch the best state once per iteration and reuse it below instead of
        // calling getBest() repeatedly (the original re-queried it four times).
        State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
        double maxBeta;
        double nonZeros;
        double positive;
        double norm;

        double lambda = 0;
        double mu = 0;

        if (best != null) {
            CrossFoldLearner state = best.getPayload().getLearner();
            averageCorrect = state.percentCorrect();
            averageLL = state.logLikelihood();

            OnlineLogisticRegression model = state.getModels().get(0);
            // finish off pending regularization
            model.close();

            Matrix beta = model.getBeta();
            maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
            // Count coefficients that are effectively non-zero.
            nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
                @Override
                public double apply(double v) {
                    return Math.abs(v) > 1.0e-6 ? 1 : 0;
                }
            });
            // Count strictly positive coefficients.
            positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
                @Override
                public double apply(double v) {
                    return v > 0 ? 1 : 0;
                }
            });
            norm = beta.aggregate(Functions.PLUS, Functions.ABS);

            lambda = best.getMappedParams()[0];
            mu = best.getMappedParams()[1];
        } else {
            maxBeta = 0;
            nonZeros = 0;
            positive = 0;
            norm = 0;
        }
        if (k % (bump * scale) == 0) {
            if (best != null) {
                ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
                        best.getPayload().getLearner().getModels().get(0));
            }

            step += 0.25;
            System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda,
                    mu);
            System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, averageLL, averageCorrect * 100,
                    LEAK_LABELS[leakType % 3]);
        }
    }
    learningAlgorithm.close();
    dissect(newsGroups, learningAlgorithm, files);
    System.out.println("exiting main");

    ModelSerializer.writeBinary("/tmp/news-group.model",
            learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
}

From source file:com.tdunning.ch16.train.TrainNewsGroups.java

License:Apache License

/**
 * Prints the 100 most influential model features, with the weight of each and
 * the categories it most affects.
 *
 * @param newsGroups        maps newsgroup names to the integer ids used in training
 * @param learningAlgorithm trained regression whose best learner is dissected
 * @param files             training files; a random 500-element sample is re-encoded
 * @throws IOException if a sampled file cannot be read
 */
private static void dissect(Dictionary newsGroups, AdaptiveLogisticRegression learningAlgorithm,
        Iterable<File> files) throws IOException {
    // NOTE(review): getBest() can return null when too little training has
    // happened (main() guards for this) — confirm dissect() is only called
    // after a full training run, else this line NPEs.
    CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
    model.close(); // finish off pending regularization before scoring

    Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
    ModelDissector md = new ModelDissector();

    encoder.setTraceDictionary(traceDictionary);
    bias.setTraceDictionary(traceDictionary);

    // Re-encode a random sample of 500 documents so the dissector can map
    // hashed feature positions back to human-readable feature names.
    for (File file : permute(files, rand).subList(0, 500)) {
        traceDictionary.clear();
        Vector v = encodeFeatureVector(file);
        md.update(v, traceDictionary, model);
    }

    List<String> ngNames = Lists.newArrayList(newsGroups.values());
    List<ModelDissector.Weight> weights = md.summary(100);
    for (ModelDissector.Weight w : weights) {
        // NOTE(review): the "+ 1" offset into ngNames is inconsistent with the
        // sibling dissect() elsewhere in this file, which uses getMaxImpact()
        // directly — verify which indexing is correct; the +1 can also throw
        // IndexOutOfBoundsException for the last category.  Also confirm the
        // %.1f conversions line up with getCategory(n)/getWeight(n) argument
        // types, as %.1f on a non-floating argument fails at runtime.
        System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n", w.getFeature(), w.getWeight(),
                ngNames.get(w.getMaxImpact() + 1), w.getCategory(1), w.getWeight(1), w.getCategory(2),
                w.getWeight(2));
    }
}

From source file:gov.llnl.ontology.mains.ExtendWordNet.java

License:Open Source License

/**
 * Trains the hypernym and cousin predictor models based on the evidence
 * gathered for positive and negative relationships.  Returns true if both
 * models could be trained./*  w  w  w.j  av a  2  s.c  om*/
 */
public boolean trainModel(int numPasses) {
    // Train the hypernym predictor.
    AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2, basis.numDimensions(), new L1());
    for (int i = 0; i < numPasses; ++i) {
        trainHypernyms(knownPositives.map(), model, 1);
        trainHypernyms(knownNegatives.map(), model, 0);
    }

    // Get the best predictor for hypernyms from the trainer.  If no trainer
    // could be found, return false.
    State<Wrapper, CrossFoldLearner> best = model.getBest();
    if (best == null)
        return false;
    hypernymPredictor = best.getPayload().getLearner().getModels().get(0);

    /*
    // Train the cousin predictor using the similarity scores.
    model = new AdaptiveLogisticRegression(
        2, numSimilarityScores, new L1());
    for (int i = 0; i < numPasses; ++i) {
    trainCousins(knownPositives.map(), model);
    trainCousins(knownNegatives.map(), model);
    }
            
    // Get the best cousin predictor model from the trainer.  If no trainer
    // could be found, return false.
    best = model.getBest();
    if (best == null)
    return false;
    cousinPredictor = best.getPayload().getLearner().getModels().get(0);
    */

    return true;
}

From source file:gov.llnl.ontology.mains.TrainLogisticRegression.java

License:Open Source License

/**
 * Command-line driver that trains a logistic-regression hypernym classifier
 * from dependency-path evidence stored in an evidence table and serializes the
 * resulting model to the given output path.
 *
 * <p>Usage: {@code java TrainLogisticRegression [OPTIONS] <out>} with required
 * options -w (wordnet directory) and -b (serialized basis mapping).
 *
 * @param args command line arguments; see usage message
 * @throws Exception if option parsing, table access, or serialization fails
 */
public static void main(String[] args) throws Exception {
    MRArgOptions options = new MRArgOptions();
    options.addOption('w', "wordnetDir", "Specifies the wordnet directory", true, "PATH", "Required");
    // Fixed typo in the help text: "serialzied" -> "serialized".
    options.addOption('b', "basisMapping", "Specifies a serialized basis mapping", true, "FILE", "Required");
    options.addOption('n', "numPasses", "Specifies the number of training passes to make. " + "(Default: 5)",
            true, "INT", "Optional");
    options.parseOptions(args);

    if (options.numPositionalArgs() != 1 || !options.hasOption('w') || !options.hasOption('b')) {
        System.out.println("usage: java TrainLogisticRegression [OPTIONS] <out>\n" + options.prettyPrint());
        System.exit(1);
    }

    EvidenceTable table = options.evidenceTable();
    Scan scan = new Scan();
    table.setupScan(scan, options.sourceCorpus());

    // Load the basis mapping that fixes the feature space; read-only so that
    // unseen dependency paths map to a negative (ignored) dimension below.
    StringBasisMapping basis = SerializableUtil.load(new File(options.getStringOption('b')));
    basis.setReadOnly(true);
    int numDimensions = basis.numDimensions();

    AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2, numDimensions, new L1());

    // NOTE(review): the help text advertises a default of 5 for -n, but this
    // call passes no default — confirm getIntOption('n') supplies one when the
    // option is absent.
    int numPasses = options.getIntOption('n');
    for (int i = 0; i < numPasses; ++i) {
        Iterator<Result> resultIter = table.iterator(scan);
        while (resultIter.hasNext()) {
            Result row = resultIter.next();

            // Skip rows that cannot serve as labelled training examples.
            HypernymStatus status = table.getHypernymStatus(row);
            if (status == HypernymStatus.TERMS_MISSING || status == HypernymStatus.NOVEL_HYPONYM
                    || status == HypernymStatus.NOVEL_HYPERNYM)
                continue;

            // Extract a CompactSparse vector representing the number of
            // dependency paths.
            SparseDoubleVector vector = new CompactSparseVector(numDimensions);
            Counter<String> pathCounts = table.getDependencyPaths(row);
            for (Map.Entry<String, Integer> entry : pathCounts) {
                int dimension = basis.getDimension(entry.getKey());
                if (dimension >= 0)
                    vector.set(dimension, entry.getValue());
            }

            // Known hypernym pairs are the positive class.
            int classLabel = (status == HypernymStatus.KNOWN_HYPERNYM) ? 1 : 0;
            model.train(classLabel, new MahoutSparseVector(vector, numDimensions));
        }
    }

    // getBest() returns null when no candidate model was ever evaluated.
    State<Wrapper, CrossFoldLearner> best = model.getBest();
    if (best == null) {
        System.err.println("The Learner could not be learned");
        System.exit(1);
    }

    OnlineLearner classifier = best.getPayload().getLearner().getModels().get(0);
    SerializableUtil.save(classifier, new File(options.getPositionalArg(0)));
}

From source file:opennlp.addons.mahout.AdaptiveLogisticRegressionTrainer.java

License:Apache License

/**
 * Trains an adaptive logistic regression model over the indexed events and
 * wraps the best learner found as a MaxentModel.
 *
 * @param indexer provides the outcome labels, predicate labels, and events
 * @return the trained classifier model
 * @throws IOException if training completed without producing a usable model
 */
@Override
public MaxentModel doTrain(DataIndexer indexer) throws IOException {

    // TODO: Lets use the predMap here as well for encoding
    int numberOfOutcomes = indexer.getOutcomeLabels().length;
    int numberOfFeatures = indexer.getPredLabels().length;

    AdaptiveLogisticRegression pa = new AdaptiveLogisticRegression(numberOfOutcomes, numberOfFeatures,
            new L1());

    // TODO: Make these parameters configurable ...
    //  what are good values ?!
    pa.setInterval(800);
    pa.setAveragingWindow(500);

    for (int k = 0; k < iterations; k++) {
        trainOnlineLearner(indexer, pa);

        // What should be reported at the end of every iteration ?!
        System.out.println("Iteration " + (k + 1));
    }

    pa.close();

    // getBest() returns null when no candidate model was ever evaluated
    // (e.g. too little training data); fail with a clear message instead
    // of letting the chained call below throw an NPE.
    if (pa.getBest() == null) {
        throw new IOException("Training did not produce a usable model");
    }

    return new VectorClassifierModel(pa.getBest().getPayload().getLearner(), indexer.getOutcomeLabels(),
            createPrepMap(indexer));
}

From source file:opennlp.addons.mahout.LogisticRegressionTrainer.java

License:Apache License

/**
 * Trains an adaptive logistic regression model over the indexed events and
 * wraps the best learner found as a MaxentModel.
 *
 * @param indexer provides the outcome labels, predicate labels, and events
 * @return the trained classifier model
 * @throws IOException if training completed without producing a usable model
 */
@Override
public MaxentModel doTrain(DataIndexer indexer) throws IOException {

    // TODO: Lets use the predMap here as well for encoding
    // (removed an unused local holding indexer.getOutcomeList();
    //  removed commented-out PassiveAggressive/OnlineLogisticRegression
    //  experiments — see version history for the alternatives tried)

    int cardinality = indexer.getPredLabels().length;

    AdaptiveLogisticRegression pa = new AdaptiveLogisticRegression(indexer.getOutcomeLabels().length,
            cardinality, new L1());

    pa.setInterval(800);
    pa.setAveragingWindow(500);

    // TODO: Should we do both ?! AdaptiveLogisticRegression ?!

    for (int k = 0; k < iterations; k++) {
        trainOnlineLearner(indexer, pa);

        // What should be reported at the end of every iteration ?!
        System.out.println("Iteration " + (k + 1));
    }

    pa.close();

    // Map each predicate label to its index so the model can encode features.
    Map<String, Integer> predMap = new HashMap<String, Integer>();

    String predLabels[] = indexer.getPredLabels();
    for (int i = 0; i < predLabels.length; i++) {
        predMap.put(predLabels[i], i);
    }

    // getBest() returns null when no candidate model was ever evaluated;
    // fail clearly instead of letting the chained call below throw an NPE.
    if (pa.getBest() == null) {
        throw new IOException("Training did not produce a usable model");
    }

    return new VectorClassifierModel(pa.getBest().getPayload().getLearner(), indexer.getOutcomeLabels(),
            predMap);
}