Example usage for org.apache.mahout.classifier.sgd OnlineLogisticRegression logLikelihood

List of usage examples for org.apache.mahout.classifier.sgd OnlineLogisticRegression logLikelihood

Introduction

On this page you can find example usages of org.apache.mahout.classifier.sgd.OnlineLogisticRegression.logLikelihood.

Prototype

public double logLikelihood(int actual, Vector data) 

Source Link

Document

Returns a measure of how good the classification for a particular example actually is.

Usage

From source file:TrainLogistic.java

License:Apache License

/**
 * Trains a logistic-regression model on every {@code .svm} file in the directory named by
 * {@code inputFile} and saves the result to {@code outputFile}.
 *
 * <p>Each input line is parsed as {@code <label> <index>:<value> ...}. A label starting with
 * {@code '+'} is treated as class 1, anything else as class 0. Feature indices in the file are
 * 1-based and are shifted down by one; slot {@code numFeatures - 1} is then set to 1 as a bias
 * term. Blank lines are skipped.
 *
 * <p>Before each training step the example's log-likelihood under the current model is folded into
 * a running estimate (an exact mean over the first 20 samples, an exponential moving average after
 * that). When {@code scores} is set, one line per example is written to {@code output}.
 *
 * @param args command-line arguments, interpreted by {@code parseArgs}
 * @param output destination for the per-example score lines
 * @throws IllegalArgumentException if {@code inputFile} is not a readable directory
 * @throws Exception if reading the input or writing the model fails
 */
static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (parseArgs(args)) {
        double logPEstimate = 0;
        int samples = 0;

        // Collect the training files (*.svm) from the input directory.
        File inputDir = new File(inputFile);
        String[] fns = inputDir.list(new FilenameFilter() {
            public boolean accept(File dir, String name) {
                return name.endsWith(".svm");
            }
        });
        if (fns == null) {
            // File.list() returns null when the path is not a directory or cannot be read.
            throw new IllegalArgumentException("not a readable directory: " + inputFile);
        }

        OnlineLogisticRegression lr = lmp.createRegression();
        for (int fi = 0; fi < fns.length; fi++) {
            for (int pass = 0; pass < passes; pass++) {
                // Resolve the name against the directory so a missing trailing separator in
                // inputFile does not produce a bogus path.
                BufferedReader in = open(new File(inputDir, fns[fi]).getPath());
                System.out.println(pass + 1);
                try {
                    int lineCount = 0;
                    for (String line = in.readLine(); line != null; line = in.readLine()) {
                        lineCount++;
                        // Split on runs of whitespace so leading, trailing or repeated blanks do
                        // not produce empty tokens; a blank line yields a single empty token.
                        String[] ss = line.trim().split("\\s+");
                        if (!ss[0].isEmpty()) {
                            int targetValue = ss[0].startsWith("+") ? 1 : 0;

                            Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                            for (int k = 1; k < ss.length; k++) {
                                String[] iv = ss[k].split(":");
                                // Feature indices in the file are 1-based.
                                input.setQuick(Integer.parseInt(iv[0]) - 1, Double.parseDouble(iv[1]));
                            }
                            // Bias term in the last feature slot.
                            input.setQuick(lmp.getNumFeatures() - 1, 1);

                            // check performance while this is still news
                            double logP = lr.logLikelihood(targetValue, input);
                            if (!Double.isInfinite(logP)) {
                                if (samples < 20) {
                                    logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
                                } else {
                                    logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
                                }
                                samples++;
                            }
                            double p = lr.classifyScalar(input);
                            if (scores) {
                                output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f\n", samples,
                                        targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
                            }
                            // now update model
                            lr.train(targetValue, input);
                        }
                        // Progress indicator every 1000 input lines.
                        if (lineCount % 1000 == 0) {
                            System.out.printf("%d\t", lineCount);
                        }
                    }
                } finally {
                    Closeables.closeQuietly(in);
                }
                System.out.println();
            }
        }

        FileOutputStream modelOutput = new FileOutputStream(outputFile);
        try {
            saveTo(modelOutput, lr);
        } finally {
            Closeables.closeQuietly(modelOutput);
        }
    }
}

From source file:com.cloudera.knittingboar.records.TestTwentyNewsgroupsCustomRecordParseOLRRun.java

License:Apache License

/**
 * Streams the 20 Newsgroups shard under {@code workDir} through
 * {@code TwentyNewsgroupsRecordFactory} and trains an {@code OnlineLogisticRegression} on it,
 * logging the running log-likelihood and accuracy at increasingly sparse intervals.
 *
 * <p>Each record is scored before the model is trained on it, so the averages measure performance
 * on records the model has not yet seen.
 */
@Test
public void testRecordFactoryOnDatasetShard() throws Exception {
    // TODO: a test without assertions is not a test -- this currently only logs metrics.
    // p.270 ----- metrics to track lucene's parsing mechanics, progress,
    // performance of OLR ------------
    double averageLL = 0.0;
    double averageCorrect = 0.0;
    int k = 0;
    double step = 0.0;
    int[] bumps = new int[] { 1, 2, 5 };

    TwentyNewsgroupsRecordFactory rec_factory = new TwentyNewsgroupsRecordFactory("\t");

    JobConf job = new JobConf(defaultConf);

    long block_size = localFs.getDefaultBlockSize(workDir);

    LOG.info("default block size: " + (block_size / 1024 / 1024) + "MB");

    // matches the OLR setup on p.269 ---------------
    // stepOffset, decay, and alpha --- describe how the learning rate decreases
    // lambda: amount of regularization
    // learningRate: amount of initial learning rate
    @SuppressWarnings("resource")
    OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1()).alpha(1)
            .stepOffset(1000).decayExponent(0.9).lambda(3.0e-5).learningRate(20);

    FileInputFormat.setInputPaths(job, workDir);

    // try splitting the file in a variety of sizes
    TextInputFormat format = new TextInputFormat();
    format.configure(job);
    Text value = new Text();

    int numSplits = 1;

    InputSplit[] splits = format.getSplits(job, numSplits);

    LOG.info("requested " + numSplits + " splits, splitting: got =        " + splits.length);
    LOG.info("---- debug splits --------- ");
    rec_factory.Debug();
    int total_read = 0;

    for (int x = 0; x < splits.length; x++) {

        LOG.info("> Split [" + x + "]: " + splits[x].getLength());

        int count = 0;
        InputRecordsSplit custom_reader = new InputRecordsSplit(job, splits[x]);
        while (custom_reader.next(value)) {
            Vector v = new RandomAccessSparseVector(TwentyNewsgroupsRecordFactory.FEATURES);
            int actual = rec_factory.processLine(value.toString(), v);

            String ng = rec_factory.GetNewsgroupNameByID(actual);

            // calc stats: score the record before training on it ---------
            double mu = Math.min(k + 1, 200);
            double ll = learningAlgorithm.logLikelihood(actual, v);
            averageLL = averageLL + (ll - averageLL) / mu;

            Vector p = new DenseVector(20);
            learningAlgorithm.classifyFull(p, v);
            int estimated = p.maxValueIndex();

            int correct = (estimated == actual ? 1 : 0);
            averageCorrect = averageCorrect + (correct - averageCorrect) / mu;
            learningAlgorithm.train(actual, v);
            k++;

            // Log at k = 1, 2, 5, 10, 20, 50, ... (each step is held for four log lines).
            int bump = bumps[(int) Math.floor(step) % bumps.length];
            int scale = (int) Math.pow(10, Math.floor(step / bumps.length));

            if (k % (bump * scale) == 0) {
                step += 0.25;
                LOG.info(String.format("%10d %10.3f %10.3f %10.2f %s %s", k, ll, averageLL,
                        averageCorrect * 100, ng, rec_factory.GetNewsgroupNameByID(estimated)));
            }

            count++;
        }

        LOG.info("read: " + count + " records for split " + x);
        total_read += count;
    } // for each split

    // Finalize the model once, after all training. close() is meant to be called once when
    // training is done. Previously it ran after every record, and train() re-opens the model each
    // time, so the finalization work was repeated per record.
    // NOTE(review): in Mahout, close() also applies pending regularization to all coefficients --
    // confirm against the Mahout version in use.
    learningAlgorithm.close();

    LOG.info("total read across all splits: " + total_read);
    rec_factory.Debug();
}

From source file:com.cloudera.knittingboar.sgd.olr.TestBaseOLRTest20Newsgroups.java

License:Apache License

/**
 * Evaluates the serialized 20 Newsgroups model ({@code model20News}) against the first debug split
 * of {@code testData20News}, printing running accuracy and log-likelihood figures. The model is only
 * read here; it is never trained.
 *
 * <p>{@code metrics.AvgLogLikelihood} and {@code metrics.AvgCorrect} are running averages whose
 * weighting window is capped at 200 records, so the printed "Percent Correct" reflects recent
 * records rather than the whole run. {@code num_correct} is the exact cumulative count.
 */
public void testResults() throws Exception {

    // Load the trained model and release the file handle once it has been read.
    FileInputStream modelIn = new FileInputStream(model20News.toString());
    OnlineLogisticRegression classifier;
    try {
        classifier = ModelSerializer.readBinary(modelIn, OnlineLogisticRegression.class);
    } finally {
        modelIn.close();
    }

    Text value = new Text();
    long batch_vec_factory_time = 0;
    int k = 0;
    int num_correct = 0;

    JobConf job = new JobConf(defaultConf);

    // TODO: work on this, splits are generating for everything in dir
    //    InputSplit[] splits = generateDebugSplits(inputDir, job);
    InputSplit[] splits = generateDebugSplits(testData20News, job);

    System.out.println("split count: " + splits.length);

    // Only the first split is evaluated.
    InputRecordsSplit custom_reader_0 = new InputRecordsSplit(job, splits[0]);

    TwentyNewsgroupsRecordFactory VectorFactory = new TwentyNewsgroupsRecordFactory("\t");

    // Evaluate at most 8000 records, stopping early when the split is exhausted.
    for (int x = 0; x < 8000 && custom_reader_0.next(value); x++) {

        // Time only the vectorization step.
        long startTime = System.currentTimeMillis();
        Vector v = new RandomAccessSparseVector(FEATURES);
        int actual = VectorFactory.processLine(value.toString(), v);
        batch_vec_factory_time += System.currentTimeMillis() - startTime;

        // calc stats ---------
        double mu = Math.min(k + 1, 200);
        double ll = classifier.logLikelihood(actual, v);
        metrics.AvgLogLikelihood = metrics.AvgLogLikelihood + (ll - metrics.AvgLogLikelihood) / mu;

        Vector p = new DenseVector(20);
        classifier.classifyFull(p, v);
        int estimated = p.maxValueIndex();

        int correct = (estimated == actual ? 1 : 0);
        num_correct += correct;
        metrics.AvgCorrect = metrics.AvgCorrect + (correct - metrics.AvgCorrect) / mu;

        k++;

        // Print at k = 1, 2, 5, 10, 20, 50, ... (each step is held for four print lines).
        int bump = bumps[(int) Math.floor(step) % bumps.length];
        int scale = (int) Math.pow(10, Math.floor(step / bumps.length));

        if (k % (bump * scale) == 0) {
            step += 0.25;

            System.out.printf(
                    "Worker %s:\t Tested Recs: %10d, numCorrect: %d, AvgLL: %10.3f, Percent Correct: %10.2f, VF: %d\n",
                    "OLR-standard-test", k, num_correct, metrics.AvgLogLikelihood, metrics.AvgCorrect * 100,
                    batch_vec_factory_time);
        }

        // Intentionally no classifier.close() here: this model is evaluate-only. close() is meant
        // to finalize a model after training, and calling it mid-evaluation could alter the
        // coefficients being evaluated.
    }
}

From source file:com.cloudera.knittingboar.sgd.olr.TestBaseOLR_Train20Newsgroups.java

License:Apache License

/**
 * Trains an {@code OnlineLogisticRegression} on the 20 Newsgroups "bydate" training set, following
 * the SGD example from Mahout in Action (pp. 269-271), and writes the model to
 * {@code /tmp/olr-news-group.model}.
 *
 * <p>Each sub-directory of the training directory is one newsgroup (class). Every file inside it is
 * one message. Words from From/Subject/Keywords/Summary header lines and all body words are encoded
 * into a {@code FEATURES}-wide vector, together with intercept and line-count features. Each message
 * is scored before it is trained on, so the logged averages track performance on unseen data.
 *
 * @throws IOException if the training directory cannot be listed or a message cannot be read
 */
public void testTrainNewsGroups() throws IOException {

    // NOTE(review): machine-specific path; this test only runs where the dataset lives here.
    File base = new File("/Users/jpatterson/Downloads/datasets/20news-bydate/20news-bydate-train/");
    overallCounts = HashMultiset.create();

    long startTime = System.currentTimeMillis();

    // p.269 ---------------------------------------------------------
    Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>();

    // encodes the text content in both the subject and the body of the email
    FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
    encoder.setProbes(2);
    encoder.setTraceDictionary(traceDictionary);

    // provides a constant offset that the model can use to encode the average frequency
    // of each class
    FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
    bias.setTraceDictionary(traceDictionary);

    // used to encode the number of lines in a message
    FeatureVectorEncoder lines = new ConstantValueEncoder("Lines");
    lines.setTraceDictionary(traceDictionary);

    FeatureVectorEncoder logLines = new ConstantValueEncoder("LogLines");
    logLines.setTraceDictionary(traceDictionary);

    Dictionary newsGroups = new Dictionary();

    // matches the OLR setup on p.269 ---------------
    // stepOffset, decay, and alpha --- describe how the learning rate decreases
    // lambda: amount of regularization
    // learningRate: amount of initial learning rate
    OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1()).alpha(1)
            .stepOffset(1000).decayExponent(0.9).lambda(3.0e-5).learningRate(20);

    // bottom of p.269 ------------------------------
    // because OLR expects to get integer class IDs for the target variable during training
    // we need a dictionary to convert the target variable (the newsgroup name)
    // to an integer, which is the newsGroup object
    File[] groupDirs = base.listFiles();
    if (groupDirs == null) {
        // File.listFiles() returns null when the path is missing or not a directory.
        throw new IOException("training directory missing or unreadable: " + base);
    }
    List<File> files = new ArrayList<File>();
    for (File newsgroup : groupDirs) {
        // Skip stray plain files (e.g. .DS_Store): listFiles() would return null for them.
        if (newsgroup.isDirectory()) {
            newsGroups.intern(newsgroup.getName());
            files.addAll(Arrays.asList(newsgroup.listFiles()));
        }
    }

    // mix up the files, helps training in OLR
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());

    // p.270 ----- metrics to track lucene's parsing mechanics, progress, performance of OLR ------------
    double averageLL = 0.0;
    double averageCorrect = 0.0;
    double averageLineCount = 0.0;
    int k = 0;
    double step = 0.0;
    int[] bumps = new int[] { 1, 2, 5 };

    // last line on p.269
    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);

    Splitter onColon = Splitter.on(":").trimResults();

    int input_file_count = 0;

    // ----- p.270 ------------ "reading and tokenizing the data" ---------
    for (File file : files) {
        input_file_count++;

        // identify newsgroup: convert newsgroup name to unique id
        String ng = file.getParentFile().getName();
        int actual = newsGroups.intern(ng);
        Multiset<String> words = ConcurrentHashMultiset.create();

        // Per-message line count. Start from the running average -- the same fallback used when
        // the "Lines:" header cannot be parsed -- instead of inheriting the previous message's value.
        double lineCount = averageLineCount;

        // NOTE(review): FileReader decodes with the platform default charset.
        BufferedReader reader = new BufferedReader(new FileReader(file));
        try {
            // Header section: runs until the first empty line (or end of file).
            String line = reader.readLine();
            while (line != null && line.length() > 0) {

                // if this is a line that has a line count, let's pull that value out ------
                if (line.startsWith("Lines:")) {
                    String count = Iterables.get(onColon.split(line), 1);
                    try {
                        lineCount = Integer.parseInt(count);
                        averageLineCount += (lineCount - averageLineCount) / Math.min(k + 1, 1000);
                    } catch (NumberFormatException e) {
                        // if anything goes wrong in parse: just use the avg count
                        lineCount = averageLineCount;
                    }
                }

                boolean countHeader = (line.startsWith("From:") || line.startsWith("Subject:")
                        || line.startsWith("Keywords:") || line.startsWith("Summary:"));

                // Consume this header line plus any continuation lines (lines starting with " ").
                // The null check stops cleanly if the file ends inside the header.
                do {
                    if (countHeader) {
                        countWords(analyzer, words, new StringReader(line));
                    }
                    line = reader.readLine();
                } while (line != null && line.startsWith(" "));
            }

            //  -------- count words in body (everything after the blank line) ----------
            countWords(analyzer, words, reader);
        } finally {
            reader.close();
        }

        // ----- p.271 -----------
        Vector v = new RandomAccessSparseVector(FEATURES);

        // the original-value string argument does nothing in a ConstantValueEncoder
        bias.addToVector("", 1, v);
        lines.addToVector("", lineCount / 30, v);
        logLines.addToVector("", Math.log(lineCount + 1), v);

        // now scan through all the words and add them
        for (String word : words.elementSet()) {
            encoder.addToVector(word, Math.log(1 + words.count(word)), v);
        }

        // calc stats: score the message before training on it ---------
        double mu = Math.min(k + 1, 200);
        double ll = learningAlgorithm.logLikelihood(actual, v);
        averageLL = averageLL + (ll - averageLL) / mu;

        Vector p = new DenseVector(20);
        learningAlgorithm.classifyFull(p, v);
        int estimated = p.maxValueIndex();

        int correct = (estimated == actual ? 1 : 0);
        averageCorrect = averageCorrect + (correct - averageCorrect) / mu;

        learningAlgorithm.train(actual, v);

        k++;

        // Print at k = 1, 2, 5, 10, 20, 50, ... (each step is held for four print lines).
        int bump = bumps[(int) Math.floor(step) % bumps.length];
        int scale = (int) Math.pow(10, Math.floor(step / bumps.length));

        if (k % (bump * scale) == 0) {
            step += 0.25;
            System.out.printf("%10d %10.3f %10.3f %10.2f %s %s\n", k, ll, averageLL, averageCorrect * 100, ng,
                    newsGroups.values().get(estimated));
        }
    }

    // Finalize the model once after training, before it is inspected and serialized. Previously
    // close() ran after every message, repeating the finalization work per message.
    learningAlgorithm.close();

    Utils.PrintVectorSection(learningAlgorithm.getBeta().viewRow(0), 3);

    long duration = System.currentTimeMillis() - startTime;

    System.out.println("Processed Input Files: " + input_file_count + ", time: " + duration + "ms");

    ModelSerializer.writeBinary("/tmp/olr-news-group.model", learningAlgorithm);
}

From source file:com.memonews.mahout.sentiment.SentimentModelTester.java

License:Apache License

/**
 * Evaluates the logistic-regression model stored in {@code modelFile} against the labelled files
 * under {@code inputFile} and writes a {@code ResultAnalyzer} summary to {@code output}.
 *
 * <p>Each sub-directory of {@code inputFile} is one category, and the files inside it are that
 * category's test examples. Plain files directly under {@code inputFile} are ignored.
 *
 * @param output destination for the summary
 * @throws IOException if the model cannot be read or the test directory cannot be listed
 */
public void run(final PrintWriter output) throws IOException {

    final File base = new File(inputFile);

    // Load the (best) trained model and release the file handle once it has been read.
    final FileInputStream modelIn = new FileInputStream(modelFile);
    final OnlineLogisticRegression classifier;
    try {
        classifier = ModelSerializer.readBinary(modelIn, OnlineLogisticRegression.class);
    } finally {
        modelIn.close();
    }

    final Dictionary newsGroups = new Dictionary();
    final Multiset<String> overallCounts = HashMultiset.create();

    final File[] categoryDirs = base.listFiles();
    if (categoryDirs == null) {
        // File.listFiles() returns null when the path is missing or not a directory.
        throw new IOException("test directory missing or unreadable: " + base);
    }
    final List<File> files = Lists.newArrayList();
    for (final File newsgroup : categoryDirs) {
        if (newsgroup.isDirectory()) {
            newsGroups.intern(newsgroup.getName());
            files.addAll(Arrays.asList(newsgroup.listFiles()));
        }
    }
    System.out.printf("%d test files\n", files.size());
    final ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT");
    for (final File file : files) {
        final String ng = file.getParentFile().getName();

        final int actual = newsGroups.intern(ng);
        final SentimentModelHelper helper = new SentimentModelHelper();
        // no leak type ensures this is a normal vector
        final Vector input = helper.encodeFeatureVector(file, overallCounts);
        final Vector result = classifier.classifyFull(input);
        final int cat = result.maxValueIndex();
        final double score = result.maxValue();
        final double ll = classifier.logLikelihood(actual, input);
        final ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), score, ll);
        ra.addInstance(newsGroups.values().get(actual), cr);
    }
    output.printf("%s\n\n", ra.toString());
}

From source file:com.ml.ira.algos.RunLogistic.java

License:Apache License

/**
 * Scores a CSV file with a saved logistic-regression model and reports AUC, the confusion/entropy
 * matrices, and/or per-row scores to {@code output}.
 *
 * <p>If none of {@code showAuc}, {@code showConfusion}, {@code showScores} is set, AUC and the
 * confusion matrix are shown. When {@code fieldNames} equals {@code "internal"} (ignoring case), the
 * field names stored with the model are used and the first input line is treated as data.
 * Otherwise the first input line is read as the CSV header. The model is loaded from HDFS when
 * {@code modelFile} starts with {@code hdfs://}, and from the local file system otherwise.
 *
 * @param args command-line arguments, interpreted by {@code parseArgs}
 * @param output destination for the report
 * @throws Exception if the model or the input cannot be read
 */
static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (parseArgs(args)) {
        if (!showAuc && !showConfusion && !showScores) {
            showAuc = true;
            showConfusion = true;
        }

        Auc collector = new Auc();
        LogisticModelParameters lmp;
        if (modelFile.startsWith("hdfs://")) {
            lmp = LogisticModelParameters.loadFrom(new Path(modelFile));
        } else {
            lmp = LogisticModelParameters.loadFrom(new File(modelFile));
        }
        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        OnlineLogisticRegression lr = lmp.createRegression();
        BufferedReader in = TrainLogistic.open(inputFile);
        try {
            if (fieldNames != null && fieldNames.equalsIgnoreCase("internal")) {
                csv.firstLine(lmp.getFieldNames());
            } else {
                csv.firstLine(in.readLine());
            }
            String line = in.readLine();
            if (showScores) {
                output.println("\"target\",\"model-output\",\"log-likelihood\"");
            }
            while (line != null) {
                Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
                int target = csv.processLine(line, v);

                double score = lr.classifyScalar(v);
                if (showScores) {
                    output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v));
                }
                collector.add(target, score);
                line = in.readLine();
            }
        } finally {
            // Release the input even if a row fails to parse (it was previously never closed).
            in.close();
        }

        if (showAuc) {
            output.printf(Locale.ENGLISH, "AUC = %.2f%n", collector.auc());
        }
        if (showConfusion) {
            Matrix m = collector.confusion();
            output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0), m.get(1, 0),
                    m.get(0, 1), m.get(1, 1));
            m = collector.entropy();
            output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0), m.get(1, 0),
                    m.get(0, 1), m.get(1, 1));
        }
    }
}

From source file:com.ml.ira.algos.TrainLogistic.java

License:Apache License

/**
 * Trains a logistic-regression model from a CSV file for {@code passes} passes and saves it: to HDFS
 * when {@code outputFile} starts with {@code hdfs://}, to the local file system otherwise. Progress,
 * timings and the learned weights are printed to {@code output}.
 *
 * <p>A non-empty {@code fieldNames} supplies the CSV header. Otherwise the first input line is used.
 * Every example is scored with the current model before the model is trained on it, and a running
 * log-likelihood estimate is kept (an exact mean for the first 20 samples, then an exponential
 * moving average).
 *
 * @param args command-line arguments, interpreted by {@code parseArgs}
 * @param output destination for progress, timings and the model summary
 * @throws Exception if reading the input or writing the model fails
 */
static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (!parseArgs(args)) {
        return;
    }

    double runningLogLikelihood = 0;
    int seen = 0;

    System.out.println("fieldNames: " + fieldNames);
    long startMillis = System.currentTimeMillis();
    CsvRecordFactory factory = lmp.getCsvRecordFactory();
    OnlineLogisticRegression regression = lmp.createRegression();

    for (int round = 0; round < passes; round++) {
        System.out.println("at Round: " + round);
        BufferedReader reader = open(inputFile);
        try {
            // Header comes from fieldNames when given, otherwise from the first input line.
            boolean explicitHeader = fieldNames != null && fieldNames.length() > 0;
            factory.firstLine(explicitHeader ? fieldNames : reader.readLine());

            for (String record = reader.readLine(); record != null; record = reader.readLine()) {
                Vector features = new RandomAccessSparseVector(lmp.getNumFeatures());
                int target = factory.processLine(record, features);

                // Score before training so the estimate reflects data the model has not seen.
                double logLikelihood = regression.logLikelihood(target, features);
                if (!Double.isInfinite(logLikelihood)) {
                    runningLogLikelihood = seen < 20
                            ? (seen * runningLogLikelihood + logLikelihood) / (seen + 1)
                            : 0.95 * runningLogLikelihood + 0.05 * logLikelihood;
                    seen++;
                }
                double probability = regression.classifyScalar(features);
                if (scores) {
                    output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", seen,
                            target, regression.currentLearningRate(), probability, logLikelihood,
                            runningLogLikelihood);
                }

                regression.train(target, features);
            }
        } finally {
            Closeables.close(reader, true);
        }
        output.println("duration: " + (System.currentTimeMillis() - startMillis));
    }

    if (!outputFile.startsWith("hdfs://")) {
        OutputStream modelStream = new FileOutputStream(outputFile);
        try {
            lmp.saveTo(modelStream);
        } finally {
            Closeables.close(modelStream, false);
        }
    } else {
        lmp.saveTo(new Path(outputFile));
    }

    output.println("duration: " + (System.currentTimeMillis() - startMillis));

    // Summary: feature count, then "target ~ w*name + ..." built from row 0 of the coefficients.
    output.println(lmp.getNumFeatures());
    output.println(lmp.getTargetVariable() + " ~ ");
    String separator = "";
    for (String name : factory.getTraceDictionary().keySet()) {
        double weight = predictorWeight(regression, 0, factory, name);
        if (weight != 0) {
            output.printf(Locale.ENGLISH, "%s%.3f*%s", separator, weight, name);
            separator = " + ";
        }
    }
    output.printf("%n");
    model = regression;

    // Per coefficient row: the named non-zero weights, then the raw row values on one line.
    for (int r = 0; r < regression.getBeta().numRows(); r++) {
        for (String name : factory.getTraceDictionary().keySet()) {
            double weight = predictorWeight(regression, r, factory, name);
            if (weight != 0) {
                output.printf(Locale.ENGLISH, "%20s %.5f%n", name, weight);
            }
        }
        for (int c = 0; c < regression.getBeta().numCols(); c++) {
            output.printf(Locale.ENGLISH, "%15.9f ", regression.getBeta().get(r, c));
        }
        output.println();
    }
}

From source file:haflow.component.mahout.logistic.RunLogistic.java

License:Apache License

/**
 * Scores a CSV file stored on HDFS with a saved logistic-regression model.
 *
 * <p>Per-row scores (when {@code showScores} is set) go to {@code outputFile}. AUC and the
 * confusion/entropy matrices go to {@code accurateFile}. If none of {@code showAuc},
 * {@code showConfusion}, {@code showScores} is set, AUC and the confusion matrix are reported.
 * The first input line is read as the CSV header.
 *
 * @param args command-line arguments, interpreted by {@code parseArgs}
 * @throws Exception if the model or input cannot be read, or the outputs cannot be written
 */
static void mainToOutput(String[] args) throws Exception {
    if (parseArgs(args)) {
        if (!showAuc && !showConfusion && !showScores) {
            showAuc = true;
            showConfusion = true;
        }

        PrintWriter output = new PrintWriter(HdfsUtil.writeHdfs(outputFile), true);
        PrintWriter acc_output = new PrintWriter(HdfsUtil.writeHdfs(accurateFile), true);
        try {
            Auc collector = new Auc();

            // Load the model and release the HDFS stream once it has been read.
            java.io.InputStream modelIn = HdfsUtil.open(modelFile);
            LogisticModelParameters lmp;
            try {
                lmp = LogisticModelParameters.loadFrom(modelIn);
            } finally {
                modelIn.close();
            }

            CsvRecordFactory csv = lmp.getCsvRecordFactory();
            OnlineLogisticRegression lr = lmp.createRegression();
            // Decode explicitly as UTF-8 instead of the JVM's platform default charset.
            BufferedReader in = new BufferedReader(new InputStreamReader(HdfsUtil.open(inputFile), "UTF-8"));
            try {
                String line = in.readLine();
                csv.firstLine(line);
                line = in.readLine();
                if (showScores) {
                    output.println("\"target\",\"model-output\",\"log-likelihood\"");
                }
                while (line != null) {
                    Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
                    int target = csv.processLine(line, v);

                    double score = lr.classifyScalar(v);
                    if (showScores) {
                        output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v));
                    }
                    collector.add(target, score);
                    line = in.readLine();
                }
            } finally {
                in.close();
            }

            if (showAuc) {
                acc_output.printf(Locale.ENGLISH, "AUC , %.2f%n", collector.auc());
            }
            if (showConfusion) {
                Matrix m = collector.confusion();
                acc_output.printf(Locale.ENGLISH, "confusion, [[%.1f  %.1f], [%.1f  %.1f]]%n", m.get(0, 0),
                        m.get(1, 0), m.get(0, 1), m.get(1, 1));
                m = collector.entropy();
                acc_output.printf(Locale.ENGLISH, "entropy, [[%.1f  %.1f], [%.1f  %.1f]]%n", m.get(0, 0),
                        m.get(1, 0), m.get(0, 1), m.get(1, 1));
            }
        } finally {
            // PrintWriter.close() does not throw, so both writers are always released.
            output.close();
            acc_output.close();
        }
    }
}

From source file:haflow.component.mahout.logistic.TrainLogistic.java

License:Apache License

/**
 * HDFS variant of logistic-regression training. It reads a CSV file from {@code inputFile}, trains
 * for {@code passes} passes, and writes the model to {@code outputFile}. Per-example scores (when
 * {@code scores} is set) and a model summary go to {@code inforFile}.
 *
 * <p>The first input line is read as the CSV header on every pass. Each example is scored with the
 * current model before the model is trained on it.
 *
 * @param args command-line arguments, interpreted by {@code parseArgs}
 * @throws Exception if reading the input or writing any output fails
 */
static void mainToOutput(String[] args) throws Exception {
    if (parseArgs(args)) {

        double logPEstimate = 0;
        int samples = 0;

        OutputStream o = HdfsUtil.writeHdfs(inforFile);
        PrintWriter output = new PrintWriter(o, true);
        try {
            CsvRecordFactory csv = lmp.getCsvRecordFactory();
            OnlineLogisticRegression lr = lmp.createRegression();
            for (int pass = 0; pass < passes; pass++) {
                // Decode explicitly as UTF-8 instead of the JVM's platform default charset.
                BufferedReader in = new BufferedReader(new InputStreamReader(HdfsUtil.open(inputFile), "UTF-8"));
                try {
                    // read variable names
                    csv.firstLine(in.readLine());

                    String line = in.readLine();

                    while (line != null) {
                        // for each new line, get target and predictors
                        Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                        int targetValue = csv.processLine(line, input);

                        // check performance while this is still news
                        double logP = lr.logLikelihood(targetValue, input);
                        if (!Double.isInfinite(logP)) {
                            if (samples < 20) {
                                logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
                            } else {
                                logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
                            }
                            samples++;
                        }
                        double p = lr.classifyScalar(input);
                        if (scores) {
                            output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", samples,
                                    targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
                        }

                        // now update model
                        lr.train(targetValue, input);

                        line = in.readLine();
                    }
                } finally {
                    Closeables.close(in, true);
                }
            }

            OutputStream modelOutput = HdfsUtil.writeHdfs(outputFile);
            try {
                lmp.saveTo(modelOutput);
            } finally {
                Closeables.close(modelOutput, false);
            }

            // Summary: feature count, then "target ~ w*name + ..." built from row 0 of the coefficients.
            output.println(lmp.getNumFeatures());
            output.println(lmp.getTargetVariable() + " ~ ");
            String sep = "";
            for (String v : csv.getTraceDictionary().keySet()) {
                double weight = predictorWeight(lr, 0, csv, v);
                if (weight != 0) {
                    output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
                    sep = " + ";
                }
            }
            output.printf("%n");
            model = lr;
            for (int row = 0; row < lr.getBeta().numRows(); row++) {
                for (String key : csv.getTraceDictionary().keySet()) {
                    double weight = predictorWeight(lr, row, csv, key);
                    if (weight != 0) {
                        output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight);
                    }
                }
                for (int column = 0; column < lr.getBeta().numCols(); column++) {
                    output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
                }
                output.println();
            }
        } finally {
            // Closing the writer also closes the underlying HDFS stream, even if training failed.
            output.close();
        }
    }

}