Example usage for org.apache.mahout.classifier.sgd ModelSerializer readBinary

Introduction

This page shows example usage of org.apache.mahout.classifier.sgd.ModelSerializer.readBinary.

Prototype

public static <T extends Writable> T readBinary(InputStream in, Class<T> clazz) throws IOException 
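
A minimal, self-contained sketch of calling readBinary. The path "model.bin" and the feature cardinality are assumptions for illustration; the file must hold a model previously saved with ModelSerializer.writeBinary, and the cardinality must match the one used at training time.

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;

import org.apache.mahout.classifier.sgd.ModelSerializer;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public class ReadBinaryExample {

    public static void main(String[] args) throws IOException {
        // "model.bin" is a hypothetical path to a previously serialized model
        try (InputStream in = new FileInputStream("model.bin")) {
            OnlineLogisticRegression classifier = ModelSerializer.readBinary(in, OnlineLogisticRegression.class);

            // the vector cardinality must match the model's feature count (assumed here)
            Vector v = new RandomAccessSparseVector(10000);
            Vector scores = classifier.classifyFull(v);
            System.out.println("most likely category: " + scores.maxValueIndex());
        }
    }
}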

Usage

From source file: com.cloudera.knittingboar.sgd.olr.TestBaseOLRTest20Newsgroups.java

License: Apache License

public void testResults() throws Exception {

    OnlineLogisticRegression classifier = ModelSerializer
            .readBinary(new FileInputStream(model20News.toString()), OnlineLogisticRegression.class);

    Text value = new Text();
    long batch_vec_factory_time = 0;
    int k = 0;
    int num_correct = 0;

    // ---- this all needs to be done in 
    JobConf job = new JobConf(defaultConf);

    // TODO: work on this, splits are generating for everything in dir
    //    InputSplit[] splits = generateDebugSplits(inputDir, job);

    //fullRCV1Dir
    InputSplit[] splits = generateDebugSplits(testData20News, job);

    System.out.println("split count: " + splits.length);

    InputRecordsSplit custom_reader_0 = new InputRecordsSplit(job, splits[0]);

    TwentyNewsgroupsRecordFactory vectorFactory = new TwentyNewsgroupsRecordFactory("\t");

    for (int x = 0; x < 8000; x++) {

        if (custom_reader_0.next(value)) {

            long startTime = System.currentTimeMillis();

            Vector v = new RandomAccessSparseVector(FEATURES);
            int actual = vectorFactory.processLine(value.toString(), v);

            long endTime = System.currentTimeMillis();

            //System.out.println("That took " + (endTime - startTime) + " milliseconds");
            batch_vec_factory_time += (endTime - startTime);

            String ng = vectorFactory.GetClassnameByID(actual);

            // calc stats ---------

            // running average with an effective window capped at 200 samples
            double mu = Math.min(k + 1, 200);
            double ll = classifier.logLikelihood(actual, v);
            metrics.AvgLogLikelihood = metrics.AvgLogLikelihood + (ll - metrics.AvgLogLikelihood) / mu;

            Vector p = new DenseVector(20);
            classifier.classifyFull(p, v);
            int estimated = p.maxValueIndex();

            int correct = (estimated == actual ? 1 : 0);
            if (estimated == actual) {
                num_correct++;
            }
            metrics.AvgCorrect = metrics.AvgCorrect + (correct - metrics.AvgCorrect) / mu;

            k++;

            // report progress at roughly logarithmically spaced intervals
            int bump = bumps[(int) Math.floor(step) % bumps.length];
            int scale = (int) Math.pow(10, Math.floor(step / bumps.length));

            if (k % (bump * scale) == 0) {
                step += 0.25;

                System.out.printf(
                        "Worker %s:\t Tested Recs: %10d, numCorrect: %d, AvgLL: %10.3f, Percent Correct: %10.2f, VF: %d\n",
                        "OLR-standard-test", k, num_correct, metrics.AvgLogLikelihood, metrics.AvgCorrect * 100,
                        batch_vec_factory_time);

            }

        } else {

            // nothing else to process in split!
            break;

        } // if

    } // for the number of passes in the run

    // close the classifier once, after all records have been scored
    classifier.close();

}

From source file: com.memonews.mahout.sentiment.SentimentModelTester.java

License: Apache License

public void run(final PrintWriter output) throws IOException {

    final File base = new File(inputFile);
    // contains the best model
    final OnlineLogisticRegression classifier = ModelSerializer.readBinary(new FileInputStream(modelFile),
            OnlineLogisticRegression.class);

    final Dictionary newsGroups = new Dictionary();
    final Multiset<String> overallCounts = HashMultiset.create();

    final List<File> files = Lists.newArrayList();
    for (final File newsgroup : base.listFiles()) {
        if (newsgroup.isDirectory()) {
            newsGroups.intern(newsgroup.getName());
            files.addAll(Arrays.asList(newsgroup.listFiles()));
        }
    }
    System.out.printf("%d test files\n", files.size());
    final ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT");
    for (final File file : files) {
        final String ng = file.getParentFile().getName();

        final int actual = newsGroups.intern(ng);
        final SentimentModelHelper helper = new SentimentModelHelper();
        // no leak type ensures this is a normal vector
        final Vector input = helper.encodeFeatureVector(file, overallCounts);
        final Vector result = classifier.classifyFull(input);
        final int cat = result.maxValueIndex();
        final double score = result.maxValue();
        final double ll = classifier.logLikelihood(actual, input);
        final ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), score, ll);
        ra.addInstance(newsGroups.values().get(actual), cr);

    }
    output.printf("%s\n\n", ra.toString());
}
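
Both examples assume a serialized model already exists on disk. For completeness, a minimal counterpart sketch that produces such a file, assuming the matching ModelSerializer.writeBinary(String, OnlineLogisticRegression) overload; the category count, feature cardinality, and the single training call are placeholders:

import java.io.IOException;

import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.ModelSerializer;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public class WriteBinaryExample {

    public static void main(String[] args) throws IOException {
        int numCategories = 20; // assumption: a 20-class problem, as in the tests above
        int numFeatures = 10000; // assumption: feature vector cardinality
        OnlineLogisticRegression model = new OnlineLogisticRegression(numCategories, numFeatures, new L1());

        // placeholder training step; real code loops over encoded training vectors
        Vector v = new RandomAccessSparseVector(numFeatures);
        v.set(42, 1.0);
        model.train(0, v);

        // persist the model so it can later be restored with readBinary
        ModelSerializer.writeBinary("model.bin", model);
        model.close();
    }
}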