Example usage for org.apache.mahout.classifier.sgd SGDInfo setStep

List of usage examples for org.apache.mahout.classifier.sgd SGDInfo setStep

Introduction

In this page you can find the example usage for org.apache.mahout.classifier.sgd SGDInfo setStep.

Prototype

void setStep(double step) 

Source Link

Usage

From source file:com.memonews.mahout.sentiment.SGDHelper.java

License:Apache License

static void analyzeState(final SGDInfo info, final int leakType, final int k,
        final State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best) throws IOException {
    final int bump = info.getBumps()[(int) Math.floor(info.getStep()) % info.getBumps().length];
    final int scale = (int) Math.pow(10, Math.floor(info.getStep() / info.getBumps().length));
    double maxBeta;
    double nonZeros;
    double positive;
    double norm;/*from   ww  w . j  av a2 s.  c o m*/

    double lambda = 0;
    double mu = 0;

    if (best != null) {
        final CrossFoldLearner state = best.getPayload().getLearner();
        info.setAverageCorrect(state.percentCorrect());
        info.setAverageLL(state.logLikelihood());

        final OnlineLogisticRegression model = state.getModels().get(0);
        // finish off pending regularization
        model.close();

        final Matrix beta = model.getBeta();
        maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
        nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
            @Override
            public double apply(final double v) {
                return Math.abs(v) > 1.0e-6 ? 1 : 0;
            }
        });
        positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
            @Override
            public double apply(final double v) {
                return v > 0 ? 1 : 0;
            }
        });
        norm = beta.aggregate(Functions.PLUS, Functions.ABS);

        lambda = best.getMappedParams()[0];
        mu = best.getMappedParams()[1];
    } else {
        maxBeta = 0;
        nonZeros = 0;
        positive = 0;
        norm = 0;
    }
    if (k % (bump * scale) == 0) {
        if (best != null) {
            ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
                    best.getPayload().getLearner().getModels().get(0));
        }

        info.setStep(info.getStep() + 0.25);
        System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda,
                mu);
        System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, info.getAverageLL(), info.getAverageCorrect() * 100,
                LEAK_LABELS[leakType % 3]);
    }
}