Example usage for org.apache.mahout.classifier.sgd.SGDInfo.getStep()

List of usage examples for org.apache.mahout.classifier.sgd.SGDInfo.getStep()

Introduction

On this page you can find an example usage of org.apache.mahout.classifier.sgd.SGDInfo.getStep().

Prototype

double getStep() 

Source Link

Usage

From source file: com.memonews.mahout.sentiment.SGDHelper.java

License: Apache License

/**
 * Records the accuracy of the current best learner in {@code info} and, at progressively
 * sparser values of {@code k}, prints a tab-separated summary line and saves the best model.
 *
 * <p>A summary is emitted when {@code k} is a multiple of {@code bump * scale}. Here
 * {@code bump} is taken cyclically from {@code info.getBumps()}, indexed by
 * {@code floor(step)}. {@code scale} is {@code 10^floor(step / bumps.length)}. Each emitted
 * summary advances the step by 0.25, so the reporting interval grows as {@code k} grows.
 *
 * <p>Whenever {@code best} is non-null, the average correct / log-likelihood in {@code info}
 * are updated and the first model of the best learner is closed (finishing pending
 * regularization). This happens on every call, including calls that print nothing.
 *
 * <p>Summary columns: max |beta|, number of coefficients with |beta| &gt; 1e-6, number of
 * positive coefficients, sum of |beta|, lambda, mu (the first two mapped parameters of
 * {@code best}), {@code k}, average log-likelihood, percent correct, leak label. The
 * coefficient columns and lambda/mu are printed as 0 when {@code best} is null.
 *
 * @param info     progress bookkeeping; its step and running averages are updated
 * @param leakType selects the label printed in the last column via {@code leakType % 3}
 * @param k        progress counter used to decide when to report (presumably the number of
 *                 examples seen so far; confirm with caller)
 * @param best     best state found so far, or {@code null} if none is available yet
 * @throws IOException if the model snapshot under {@code /tmp} cannot be written
 */
static void analyzeState(final SGDInfo info, final int leakType, final int k,
        final State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best) throws IOException {
    final double step = info.getStep();
    final int[] bumps = info.getBumps();
    final int bump = bumps[(int) Math.floor(step) % bumps.length];
    final int scale = (int) Math.pow(10, Math.floor(step / bumps.length));

    OnlineLogisticRegression model = null;
    if (best != null) {
        final CrossFoldLearner state = best.getPayload().getLearner();
        info.setAverageCorrect(state.percentCorrect());
        info.setAverageLL(state.logLikelihood());

        model = state.getModels().get(0);
        // finish off pending regularization (done on every call, not only when reporting)
        model.close();
    }

    if (k % (bump * scale) != 0) {
        // Not a reporting point. Skip the aggregations below: each one scans every
        // coefficient, and their results are only needed for the printed summary.
        return;
    }

    double maxBeta = 0;
    double nonZeros = 0;
    double positive = 0;
    double norm = 0;
    double lambda = 0;
    double mu = 0;

    if (model != null) {
        final Matrix beta = model.getBeta();
        maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
        // count coefficients whose magnitude exceeds a small tolerance
        nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
            @Override
            public double apply(final double v) {
                return Math.abs(v) > 1.0e-6 ? 1 : 0;
            }
        });
        positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
            @Override
            public double apply(final double v) {
                return v > 0 ? 1 : 0;
            }
        });
        // sum of absolute coefficient values (L1 norm)
        norm = beta.aggregate(Functions.PLUS, Functions.ABS);

        final double[] params = best.getMappedParams();
        lambda = params[0];
        mu = params[1];

        ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model", model);
    }

    info.setStep(info.getStep() + 0.25);
    System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda,
            mu);
    // NOTE(review): leakType % 3 is negative for negative leakType; assumes callers pass >= 0.
    System.out.printf("%d\t%.3f\t%.2f\t%s%n", k, info.getAverageLL(), info.getAverageCorrect() * 100,
            LEAK_LABELS[leakType % 3]);
}