Example usage for org.apache.mahout.ep State getPayload

List of usage examples for org.apache.mahout.ep State getPayload

Introduction

On this page you can find example usage for org.apache.mahout.ep State getPayload.

Prototype

public T getPayload() 

Source Link

Usage

From source file:com.memonews.mahout.sentiment.SGDHelper.java

License:Apache License

/**
 * Computes and periodically prints diagnostics for the current best learner state.
 * The report cadence follows a 1-2-5 sequence scaled by powers of ten, driven by
 * {@code info.getStep()}; each printed report advances the step by 0.25 and also
 * snapshots the current best model to /tmp/news-group-&lt;k&gt;.model.
 *
 * @param info     mutable accumulator for step position and running accuracy stats
 * @param leakType selects the label printed with each report (only leakType % 3 is used)
 * @param k        number of training examples seen so far
 * @param best     best evolved state so far, or null if none exists yet
 * @throws IOException if writing the model snapshot fails
 */
static void analyzeState(final SGDInfo info, final int leakType, final int k,
        final State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best) throws IOException {
    final int bump = info.getBumps()[(int) Math.floor(info.getStep()) % info.getBumps().length];
    final int scale = (int) Math.pow(10, Math.floor(info.getStep() / info.getBumps().length));

    // Coefficient-matrix summary statistics; all remain zero when no best state exists yet.
    double maxBeta = 0;
    double nonZeros = 0;
    double positive = 0;
    double norm = 0;
    double lambda = 0;
    double mu = 0;

    if (best != null) {
        final CrossFoldLearner learner = best.getPayload().getLearner();
        info.setAverageCorrect(learner.percentCorrect());
        info.setAverageLL(learner.logLikelihood());

        final OnlineLogisticRegression model = learner.getModels().get(0);
        // finish off pending regularization
        model.close();

        final Matrix beta = model.getBeta();
        maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
        // Count of coefficients whose magnitude is effectively non-zero.
        nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
            @Override
            public double apply(final double coefficient) {
                return Math.abs(coefficient) > 1.0e-6 ? 1 : 0;
            }
        });
        // Count of strictly positive coefficients.
        positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
            @Override
            public double apply(final double coefficient) {
                return coefficient > 0 ? 1 : 0;
            }
        });
        // L1 norm of the coefficient matrix.
        norm = beta.aggregate(Functions.PLUS, Functions.ABS);

        lambda = best.getMappedParams()[0];
        mu = best.getMappedParams()[1];
    }

    if (k % (bump * scale) == 0) {
        if (best != null) {
            ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
                    best.getPayload().getLearner().getModels().get(0));
        }

        info.setStep(info.getStep() + 0.25);
        System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda,
                mu);
        System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, info.getAverageLL(), info.getAverageCorrect() * 100,
                LEAK_LABELS[leakType % 3]);
    }
}

From source file:com.tdunning.ch16.train.TrainNewsGroups.java

License:Apache License

/**
 * Trains an adaptive logistic-regression classifier over the 20-newsgroups corpus.
 * args[0] is the corpus root (one sub-directory per newsgroup, one file per message);
 * the optional args[1] selects the leak label printed in progress reports (only
 * leakType % 3 is used). Progress is reported on a 1-2-5-10-20-50... schedule, with
 * model snapshots written to /tmp along the way and a final model at the end.
 */
public static void main(String[] args) throws IOException {
    File base = new File(args[0]);

    int leakType = 0;
    if (args.length > 1) {
        leakType = Integer.parseInt(args[1]);
    }

    Dictionary newsGroups = new Dictionary();

    encoder.setProbes(2);
    AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1());
    learningAlgorithm.setInterval(800);
    learningAlgorithm.setAveragingWindow(500);

    // Collect every message file, interning group names as target categories.
    List<File> files = Lists.newArrayList();
    File[] directories = base.listFiles();
    // NOTE(review): listFiles() returns null when base is not a directory; this
    // would NPE on the next line — consider validating args[0] up front.
    Arrays.sort(directories, Ordering.usingToString());
    for (File newsgroup : directories) {
        if (newsgroup.isDirectory()) {
            newsGroups.intern(newsgroup.getName());
            files.addAll(Arrays.asList(newsgroup.listFiles()));
        }
    }
    // Shuffle so the on-line learner does not see whole groups in contiguous blocks.
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());
    System.out.printf("%s\n", Arrays.asList(directories));

    double averageLL = 0;
    double averageCorrect = 0;

    int k = 0;           // training examples seen so far
    double step = 0;     // log-scale position used to thin out progress reports
    int[] bumps = { 1, 2, 5 };
    for (File file : files) {
        // The enclosing directory name is the true category.
        String ng = file.getParentFile().getName();
        int actual = newsGroups.intern(ng);

        Vector v = encodeFeatureVector(file);
        learningAlgorithm.train(actual, v);

        k++;

        int bump = bumps[(int) Math.floor(step) % bumps.length];
        int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
        // FIX: capture the best state ONCE and reuse it everywhere below. The
        // original called learningAlgorithm.getBest() again for lambda/mu and for
        // the snapshot, which could observe a different (or null) state than the
        // one these statistics were computed from.
        State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();

        // Coefficient-matrix summary statistics; all remain zero when no best state exists.
        double maxBeta = 0;
        double nonZeros = 0;
        double positive = 0;
        double norm = 0;
        double lambda = 0;
        double mu = 0;

        if (best != null) {
            CrossFoldLearner state = best.getPayload().getLearner();
            averageCorrect = state.percentCorrect();
            averageLL = state.logLikelihood();

            OnlineLogisticRegression model = state.getModels().get(0);
            // finish off pending regularization
            model.close();

            Matrix beta = model.getBeta();
            maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
            // Count of effectively non-zero coefficients.
            nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
                @Override
                public double apply(double v) {
                    return Math.abs(v) > 1.0e-6 ? 1 : 0;
                }
            });
            // Count of strictly positive coefficients.
            positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
                @Override
                public double apply(double v) {
                    return v > 0 ? 1 : 0;
                }
            });
            // L1 norm of the coefficient matrix.
            norm = beta.aggregate(Functions.PLUS, Functions.ABS);

            lambda = best.getMappedParams()[0];
            mu = best.getMappedParams()[1];
        }
        if (k % (bump * scale) == 0) {
            if (best != null) {
                ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
                        best.getPayload().getLearner().getModels().get(0));
            }

            step += 0.25;
            System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda,
                    mu);
            System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, averageLL, averageCorrect * 100,
                    LEAK_LABELS[leakType % 3]);
        }
    }
    learningAlgorithm.close();
    dissect(newsGroups, learningAlgorithm, files);
    System.out.println("exiting main");

    // FIX: guard the final write — getBest() can be null (e.g. no training files),
    // which previously caused an unconditional NullPointerException here.
    State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> finalBest = learningAlgorithm.getBest();
    if (finalBest != null) {
        ModelSerializer.writeBinary("/tmp/news-group.model",
                finalBest.getPayload().getLearner().getModels().get(0));
    } else {
        System.err.println("No model was trained; skipping /tmp/news-group.model");
    }
}

From source file:gov.llnl.ontology.mains.ExtendWordNet.java

License:Open Source License

/**
 * Trains the hypernym predictor model based on the evidence gathered for
 * positive and negative relationships.
 *
 * <p>NOTE: cousin-predictor training is currently disabled (see the
 * commented-out block below), so only the hypernym model is produced.
 *
 * @param numPasses number of full passes to make over the known positive and
 *                  negative evidence while training
 * @return true if a best hypernym predictor could be extracted from the
 *         trainer; false when no trained state was available
 */
public boolean trainModel(int numPasses) {
    // Train the hypernym predictor.
    // NOTE(review): the constructor args appear to be (numCategories=2,
    // numFeatures=basis.numDimensions(), prior=L1) — confirm against the
    // AdaptiveLogisticRegression API.
    AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2, basis.numDimensions(), new L1());
    for (int i = 0; i < numPasses; ++i) {
        trainHypernyms(knownPositives.map(), model, 1); // label 1: known positive evidence
        trainHypernyms(knownNegatives.map(), model, 0); // label 0: known negative evidence
    }

    // Get the best predictor for hypernyms from the trainer.  If no trainer
    // could be found, return false.
    State<Wrapper, CrossFoldLearner> best = model.getBest();
    if (best == null)
        return false;
    hypernymPredictor = best.getPayload().getLearner().getModels().get(0);

    /*
    // Train the cousin predictor using the similarity scores.
    model = new AdaptiveLogisticRegression(
        2, numSimilarityScores, new L1());
    for (int i = 0; i < numPasses; ++i) {
    trainCousins(knownPositives.map(), model);
    trainCousins(knownNegatives.map(), model);
    }

    // Get the best cousin predictor model from the trainer.  If no trainer
    // could be found, return false.
    best = model.getBest();
    if (best == null)
    return false;
    cousinPredictor = best.getPayload().getLearner().getModels().get(0);
    */

    return true;
}

From source file:gov.llnl.ontology.mains.TrainLogisticRegression.java

License:Open Source License

/**
 * Trains a logistic-regression hypernym classifier from dependency-path
 * evidence stored in an evidence table, then serializes the best model to the
 * output file given as the single positional argument.
 */
public static void main(String[] args) throws Exception {
    // Command-line setup: a wordnet directory and a serialized basis mapping are required.
    MRArgOptions options = new MRArgOptions();
    options.addOption('w', "wordnetDir", "Specifies the wordnet directory", true, "PATH", "Required");
    options.addOption('b', "basisMapping", "Specifies a serialzied basis mapping", true, "FILE", "Required");
    options.addOption('n', "numPasses", "Specifies the number of training passes to make. " + "(Default: 5)",
            true, "INT", "Optional");
    options.parseOptions(args);

    if (options.numPositionalArgs() != 1 || !options.hasOption('w') || !options.hasOption('b')) {
        System.out.println("usage: java TrainLogisticRegression [OPTIONS] <out>\n" + options.prettyPrint());
        System.exit(1);
    }

    // Scan over the evidence rows for the selected source corpus.
    EvidenceTable table = options.evidenceTable();
    Scan scan = new Scan();
    table.setupScan(scan, options.sourceCorpus());

    // Load the (read-only) mapping from dependency-path strings to feature dimensions.
    StringBasisMapping basis = SerializableUtil.load(new File(options.getStringOption('b')));
    basis.setReadOnly(true);
    int numDimensions = basis.numDimensions();

    AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2, numDimensions, new L1());

    int numPasses = options.getIntOption('n');
    for (int pass = 0; pass < numPasses; ++pass) {
        // A fresh iteration over the table for every pass.
        for (Iterator<Result> rows = table.iterator(scan); rows.hasNext(); ) {
            Result row = rows.next();

            // Skip rows whose hypernym status makes them unusable as training data.
            HypernymStatus status = table.getHypernymStatus(row);
            boolean unusable = status == HypernymStatus.TERMS_MISSING
                    || status == HypernymStatus.NOVEL_HYPONYM
                    || status == HypernymStatus.NOVEL_HYPERNYM;
            if (unusable) {
                continue;
            }

            // Extract a CompactSparse vector representing the number of
            // dependency paths.
            SparseDoubleVector vector = new CompactSparseVector(numDimensions);
            for (Map.Entry<String, Integer> entry : table.getDependencyPaths(row)) {
                int dimension = basis.getDimension(entry.getKey());
                if (dimension >= 0) {
                    vector.set(dimension, entry.getValue());
                }
            }

            // Binary label: 1 for a known hypernym relation, 0 otherwise.
            int classLabel = (status == HypernymStatus.KNOWN_HYPERNYM) ? 1 : 0;
            model.train(classLabel, new MahoutSparseVector(vector, numDimensions));
        }
    }

    State<Wrapper, CrossFoldLearner> best = model.getBest();
    if (best == null) {
        System.err.println("The Learner could not be learned");
        System.exit(1);
    }

    OnlineLearner classifier = best.getPayload().getLearner().getModels().get(0);
    SerializableUtil.save(classifier, new File(options.getPositionalArg(0)));
}