Example usage for org.apache.mahout.ep State get

List of usage examples for org.apache.mahout.ep State get

Introduction

On this page you can find example usage for org.apache.mahout.ep.State.get.

Prototype

public double get(int i) 

Document

Returns a transformed parameter.
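The examples below read the transformed hyperparameters through getMappedParams() and pull learners out of the payload; get(i) returns a single transformed parameter by index, so best.get(0) and best.get(1) correspond to the lambda and mu values read in the first example. A minimal sketch of that usage follows; the helper name printBestHyperparameters and the assumption that the model has already been trained are illustrative, not taken from the examples.

import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.CrossFoldLearner;
import org.apache.mahout.ep.State;

// Minimal sketch: read individual transformed hyperparameters from the best
// evolutionary state. best.get(i) yields the same value as
// best.getMappedParams()[i] (lambda at index 0, mu at index 1 for
// AdaptiveLogisticRegression, as read in the SGDHelper example below).
static void printBestHyperparameters(AdaptiveLogisticRegression model) {
    State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = model.getBest();
    if (best != null) {
        double lambda = best.get(0);
        double mu = best.get(1);
        System.out.printf("lambda=%.8g\tmu=%.8g%n", lambda, mu);
    }
}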

Usage

From source file:com.memonews.mahout.sentiment.SGDHelper.java

License:Apache License

static void analyzeState(final SGDInfo info, final int leakType, final int k,
        final State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best) throws IOException {
    final int bump = info.getBumps()[(int) Math.floor(info.getStep()) % info.getBumps().length];
    final int scale = (int) Math.pow(10, Math.floor(info.getStep() / info.getBumps().length));
    double maxBeta;
    double nonZeros;
    double positive;
    double norm;

    double lambda = 0;
    double mu = 0;

    if (best != null) {
        final CrossFoldLearner state = best.getPayload().getLearner();
        info.setAverageCorrect(state.percentCorrect());
        info.setAverageLL(state.logLikelihood());

        final OnlineLogisticRegression model = state.getModels().get(0);
        // finish off pending regularization
        model.close();

        final Matrix beta = model.getBeta();
        maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
        nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
            @Override
            public double apply(final double v) {
                return Math.abs(v) > 1.0e-6 ? 1 : 0;
            }
        });
        positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
            @Override
            public double apply(final double v) {
                return v > 0 ? 1 : 0;
            }
        });
        norm = beta.aggregate(Functions.PLUS, Functions.ABS);

        lambda = best.getMappedParams()[0];
        mu = best.getMappedParams()[1];
    } else {
        maxBeta = 0;
        nonZeros = 0;
        positive = 0;
        norm = 0;
    }
    if (k % (bump * scale) == 0) {
        if (best != null) {
            ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
                    best.getPayload().getLearner().getModels().get(0));
        }

        info.setStep(info.getStep() + 0.25);
        System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda,
                mu);
        System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, info.getAverageLL(), info.getAverageCorrect() * 100,
                LEAK_LABELS[leakType % 3]);
    }
}

From source file:gov.llnl.ontology.mains.ExtendWordNet.java

License:Open Source License

/**
 * Trains the hypernym and cousin predictor models based on the evidence
 * gathered for positive and negative relationships.  Returns true if both
 * models could be trained.
 */
public boolean trainModel(int numPasses) {
    // Train the hypernym predictor.
    AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2, basis.numDimensions(), new L1());
    for (int i = 0; i < numPasses; ++i) {
        trainHypernyms(knownPositives.map(), model, 1);
        trainHypernyms(knownNegatives.map(), model, 0);
    }

    // Get the best predictor for hypernyms from the trainer.  If no trainer
    // could be found, return false.
    State<Wrapper, CrossFoldLearner> best = model.getBest();
    if (best == null)
        return false;
    hypernymPredictor = best.getPayload().getLearner().getModels().get(0);

    /*
    // Train the cousin predictor using the similarity scores.
    model = new AdaptiveLogisticRegression(
        2, numSimilarityScores, new L1());
    for (int i = 0; i < numPasses; ++i) {
    trainCousins(knownPositives.map(), model);
    trainCousins(knownNegatives.map(), model);
    }
            
    // Get the best cousin predictor model from the trainer.  If no trainer
    // could be found, return false.
    best = model.getBest();
    if (best == null)
    return false;
    cousinPredictor = best.getPayload().getLearner().getModels().get(0);
    */

    return true;
}

From source file:gov.llnl.ontology.mains.TrainLogisticRegression.java

License:Open Source License

public static void main(String[] args) throws Exception {
    MRArgOptions options = new MRArgOptions();
    options.addOption('w', "wordnetDir", "Specifies the wordnet directory", true, "PATH", "Required");
    options.addOption('b', "basisMapping", "Specifies a serialzied basis mapping", true, "FILE", "Required");
    options.addOption('n', "numPasses", "Specifies the number of training passes to make. " + "(Default: 5)",
            true, "INT", "Optional");
    options.parseOptions(args);

    if (options.numPositionalArgs() != 1 || !options.hasOption('w') || !options.hasOption('b')) {
        System.out.println("usage: java TrainLogisticRegression [OPTIONS] <out>\n" + options.prettyPrint());
        System.exit(1);
    }

    EvidenceTable table = options.evidenceTable();
    Scan scan = new Scan();
    table.setupScan(scan, options.sourceCorpus());

    StringBasisMapping basis = SerializableUtil.load(new File(options.getStringOption('b')));
    basis.setReadOnly(true);
    int numDimensions = basis.numDimensions();

    AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2, numDimensions, new L1());

    int numPasses = options.getIntOption('n');
    for (int i = 0; i < numPasses; ++i) {
        Iterator<Result> resultIter = table.iterator(scan);
        while (resultIter.hasNext()) {
            Result row = resultIter.next();

            HypernymStatus status = table.getHypernymStatus(row);
            if (status == HypernymStatus.TERMS_MISSING || status == HypernymStatus.NOVEL_HYPONYM
                    || status == HypernymStatus.NOVEL_HYPERNYM)
                continue;

            // Extract a CompactSparse vector representing the number of
            // dependency paths.
            SparseDoubleVector vector = new CompactSparseVector(numDimensions);
            Counter<String> pathCounts = table.getDependencyPaths(row);
            for (Map.Entry<String, Integer> entry : pathCounts) {
                int dimension = basis.getDimension(entry.getKey());
                if (dimension >= 0)
                    vector.set(dimension, entry.getValue());
            }

            int classLabel = (status == HypernymStatus.KNOWN_HYPERNYM) ? 1 : 0;
            model.train(classLabel, new MahoutSparseVector(vector, numDimensions));
        }
    }

    State<Wrapper, CrossFoldLearner> best = model.getBest();
    if (best == null) {
        System.err.println("The Learner could not be learned");
        System.exit(1);
    }

    OnlineLearner classifier = best.getPayload().getLearner().getModels().get(0);
    SerializableUtil.save(classifier, new File(options.getPositionalArg(0)));
}