Example usage for org.apache.mahout.classifier.sgd OnlineLogisticRegression classifyScalar

List of usage examples for org.apache.mahout.classifier.sgd OnlineLogisticRegression classifyScalar

Introduction

This page shows example usage of org.apache.mahout.classifier.sgd OnlineLogisticRegression.classifyScalar.

Prototype

@Override
public double classifyScalar(Vector instance) 

Source Link

Document

Returns a single scalar probability in the case where we have two categories.

Usage

From source file:TrainLogistic.java

License:Apache License

/**
 * Trains an online logistic regression model over every ".svm" file found in
 * the directory named by {@code inputFile}, then saves the trained model to
 * {@code outputFile}.
 *
 * <p>Each input line is expected in LIBSVM format: a target marker (a token
 * starting with "+" means the positive class, anything else the negative
 * class) followed by space-separated {@code index:value} feature pairs with
 * 1-based indices.
 *
 * @param args   command-line arguments, parsed by {@link #parseArgs}
 * @param output destination for optional per-sample progress scores
 * @throws Exception on any I/O or parsing failure
 */
static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (parseArgs(args)) {
        double logPEstimate = 0;
        int samples = 0;

        // Collect the names of all *.svm files in the input directory.
        File dir = new File(inputFile);
        String[] fileNames = dir.list(new FilenameFilter() {
            public boolean accept(File d, String name) {
                return name.endsWith(".svm");
            }
        });

        OnlineLogisticRegression lr = lmp.createRegression();
        for (int fi = 0; fi < fileNames.length; fi++) {
            for (int pass = 0; pass < passes; pass++) {
                // NOTE(review): assumes inputFile ends with a path separator;
                // otherwise this concatenation builds a wrong path — verify.
                BufferedReader in = open(inputFile + fileNames[fi]);
                System.out.println(pass + 1);
                try {
                    String line = in.readLine();
                    int lineCount = 1;
                    while (line != null) {
                        // Parse one LIBSVM line into a sparse feature vector.
                        Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                        String[] ss = line.split(" ");
                        int targetValue = ss[0].startsWith("+") ? 1 : 0;
                        for (int k = 1; k < ss.length; k++) {
                            String[] iv = ss[k].split(":");
                            // LIBSVM indices are 1-based; shift to 0-based.
                            input.setQuick(Integer.valueOf(iv[0]) - 1, Double.valueOf(iv[1]));
                        }
                        // Bias term in the last slot. NOTE(review): this
                        // overwrites feature numFeatures-1 if the data uses it.
                        input.setQuick(lmp.getNumFeatures() - 1, 1);

                        // Evaluate before training on this example so the
                        // running estimate reflects unseen-data performance
                        // ("test-then-train").
                        double logP = lr.logLikelihood(targetValue, input);
                        if (!Double.isInfinite(logP)) {
                            if (samples < 20) {
                                // Plain running mean for the first few samples...
                                logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
                            } else {
                                // ...then an exponential moving average.
                                logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
                            }
                            samples++;
                        }
                        double p = lr.classifyScalar(input);
                        if (scores) {
                            output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f\n", samples,
                                    targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
                        }
                        // Now update the model with this example.
                        lr.train(targetValue, input);
                        if (lineCount % 1000 == 0) {
                            System.out.printf("%d\t", lineCount);
                        }
                        line = in.readLine();
                        lineCount++;
                    }
                } finally {
                    Closeables.closeQuietly(in);
                }
                System.out.println();
            }
        }

        // Persist the trained model.
        FileOutputStream modelOutput = new FileOutputStream(outputFile);
        try {
            saveTo(modelOutput, lr);
        } finally {
            Closeables.closeQuietly(modelOutput);
        }
    }
}

From source file:br.com.sitedoph.mahout_examples.BankMarketingClassificationMain.java

License:Apache License

/**
 * Repeatedly trains a two-category logistic regression on the bank-marketing
 * data set, printing held-out AUC per pass and the model's score for one
 * deliberately unseen call.
 */
public static void main(String[] args) throws Exception {
    List<TelephoneCall> calls = Lists.newArrayList(new TelephoneCallParser("bank-full.csv"));

    double heldOutPercentage = 0.10;
    double biggestScore = 0.0;

    for (int run = 0; run < 20; run++) {
        // Reshuffle so every run sees a different train/test split.
        Collections.shuffle(calls);
        int holdOut = (int) (heldOutPercentage * calls.size());
        List<TelephoneCall> evaluationSet = calls.subList(0, holdOut);
        List<TelephoneCall> trainingSet = calls.subList(holdOut, calls.size());

        List<TelephoneCall> unseenCalls = new ArrayList<>();
        unseenCalls.add(getUnknownTelephoneCall(trainingSet));

        OnlineLogisticRegression lr = new OnlineLogisticRegression(NUM_CATEGORIES, TelephoneCall.FEATURES,
                new L1()).learningRate(1).alpha(1).lambda(0.000001).stepOffset(10000).decayExponent(0.2);

        for (int pass = 0; pass < 20; pass++) {
            // One full sweep over the training data.
            for (TelephoneCall sample : trainingSet) {
                lr.train(sample.getTarget(), sample.asVector());
            }

            // Score the held-out data, tracking the largest score seen so far.
            Auc eval = new Auc(0.5);
            for (TelephoneCall held : evaluationSet) {
                biggestScore = evaluateTheCallAndGetBiggestScore(biggestScore, lr, eval, held);
            }
            System.out.printf("run: %-5d pass: %-5d current learning rate: %-5.4f \teval auc %-5.4f\n", run,
                    pass, lr.currentLearningRate(), eval.auc());

            // Report the scalar probability for the unseen call.
            for (TelephoneCall unseen : unseenCalls) {
                final double score = lr.classifyScalar(unseen.asVector());
                System.out.println(" score: " + score + " accuracy " + eval.auc() + " call fields: "
                        + unseen.getFields());
            }
        }
    }
}

From source file:br.com.sitedoph.mahout_examples.BankMarketingClassificationMain.java

License:Apache License

/**
 * Scores one call with the model, records the (target, score) pair in the
 * AUC collector, and returns the running maximum score — logging whenever a
 * new maximum is observed.
 */
private static double evaluateTheCallAndGetBiggestScore(double biggestScore, OnlineLogisticRegression lr,
        Auc eval, TelephoneCall call) {
    final double score = lr.classifyScalar(call.asVector());
    eval.add(call.getTarget(), score);
    // Guard clause: nothing to report unless this beats the previous maximum.
    if (score <= biggestScore) {
        return biggestScore;
    }
    System.out.println("### SCORE > BIGGESTSCORE ### score: " + score + " accuracy " + eval.auc()
            + " call fields: " + call.getFields());
    return score;
}

From source file:com.ml.ira.algos.RunLogistic.java

License:Apache License

/**
 * Scores a CSV file against a previously trained logistic regression model
 * (loaded from HDFS or the local filesystem) and reports per-line scores,
 * AUC, confusion matrix and entropy as requested by the flags.
 *
 * @param args   command-line arguments, parsed by {@link #parseArgs}
 * @param output destination for scores and summary metrics
 * @throws Exception on any I/O or parsing failure
 */
static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (parseArgs(args)) {
        // Default to the summary metrics when nothing specific was requested.
        if (!showAuc && !showConfusion && !showScores) {
            showAuc = true;
            showConfusion = true;
        }

        Auc collector = new Auc();
        // The model may live either on HDFS or on the local filesystem.
        LogisticModelParameters lmp;
        if (modelFile.startsWith("hdfs://")) {
            lmp = LogisticModelParameters.loadFrom(new Path(modelFile));
        } else {
            lmp = LogisticModelParameters.loadFrom(new File(modelFile));
        }
        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        OnlineLogisticRegression lr = lmp.createRegression();
        BufferedReader in = TrainLogistic.open(inputFile);
        try {
            // Field names come from the model ("internal") or from the CSV header.
            if (fieldNames != null && fieldNames.equalsIgnoreCase("internal")) {
                csv.firstLine(lmp.getFieldNames());
            } else {
                csv.firstLine(in.readLine());
            }
            String line = in.readLine();
            if (showScores) {
                output.println("\"target\",\"model-output\",\"log-likelihood\"");
            }
            while (line != null) {
                Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
                int target = csv.processLine(line, v);

                double score = lr.classifyScalar(v);
                if (showScores) {
                    output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v));
                }
                collector.add(target, score);
                line = in.readLine();
            }
        } finally {
            // BUG FIX: the reader was previously never closed.
            in.close();
        }

        if (showAuc) {
            output.printf(Locale.ENGLISH, "AUC = %.2f%n", collector.auc());
        }
        if (showConfusion) {
            Matrix m = collector.confusion();
            output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0), m.get(1, 0),
                    m.get(0, 1), m.get(1, 1));
            m = collector.entropy();
            output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0), m.get(1, 0),
                    m.get(0, 1), m.get(1, 1));
        }
    }
}

From source file:com.ml.ira.algos.TrainLogistic.java

License:Apache License

/**
 * Trains a logistic regression model from a CSV input file over several
 * passes, saves the model (to HDFS or the local filesystem), and dumps the
 * model formula plus the raw beta matrix to {@code output}.
 *
 * @param args   command-line arguments, parsed by {@link #parseArgs}
 * @param output destination for progress scores and the model dump
 * @throws Exception on any I/O or parsing failure
 */
static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (parseArgs(args)) {
        double logPEstimate = 0;
        int samples = 0;

        System.out.println("fieldNames: " + fieldNames);
        long ts = System.currentTimeMillis();
        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        OnlineLogisticRegression lr = lmp.createRegression();
        for (int pass = 0; pass < passes; pass++) {
            System.out.println("at Round: " + pass);
            BufferedReader in = open(inputFile);
            try {
                // Read variable names either from the configured fieldNames
                // string or from the first line of the input file.
                String line;
                if (fieldNames != null && fieldNames.length() > 0) {
                    csv.firstLine(fieldNames);
                } else {
                    csv.firstLine(in.readLine());
                }
                line = in.readLine();
                while (line != null) {
                    // for each new line, get target and predictors
                    Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                    int targetValue = csv.processLine(line, input);

                    // Evaluate before training on this example, so the
                    // running estimate reflects performance on unseen data.
                    double logP = lr.logLikelihood(targetValue, input);
                    if (!Double.isInfinite(logP)) {
                        if (samples < 20) {
                            // Plain running mean for the first few samples...
                            logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
                        } else {
                            // ...then an exponential moving average.
                            logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
                        }
                        samples++;
                    }
                    double p = lr.classifyScalar(input);
                    if (scores) {
                        output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", samples,
                                targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
                    }

                    // now update model
                    lr.train(targetValue, input);

                    line = in.readLine();
                }
            } finally {
                Closeables.close(in, true);
            }
            // Cumulative elapsed time, printed once per pass.
            output.println("duration: " + (System.currentTimeMillis() - ts));
        }

        // Persist the model; target may be HDFS or a local file.
        if (outputFile.startsWith("hdfs://")) {
            lmp.saveTo(new Path(outputFile));
        } else {
            OutputStream modelOutput = new FileOutputStream(outputFile);
            try {
                lmp.saveTo(modelOutput);
            } finally {
                Closeables.close(modelOutput, false);
            }
        }

        output.println("duration: " + (System.currentTimeMillis() - ts));

        // Dump a readable model summary: feature count, a formula built from
        // the non-zero weights of the trace dictionary, then per-row weights
        // and the full beta matrix.
        output.println(lmp.getNumFeatures());
        output.println(lmp.getTargetVariable() + " ~ ");
        String sep = "";
        for (String v : csv.getTraceDictionary().keySet()) {
            double weight = predictorWeight(lr, 0, csv, v);
            if (weight != 0) {
                output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
                sep = " + ";
            }
        }
        output.printf("%n");
        model = lr;
        for (int row = 0; row < lr.getBeta().numRows(); row++) {
            for (String key : csv.getTraceDictionary().keySet()) {
                double weight = predictorWeight(lr, row, csv, key);
                if (weight != 0) {
                    output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight);
                }
            }
            for (int column = 0; column < lr.getBeta().numCols(); column++) {
                output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
            }
            output.println();
        }
    }
}

From source file:haflow.component.mahout.logistic.RunLogistic.java

License:Apache License

/**
 * Scores a CSV file on HDFS against a previously trained logistic regression
 * model and writes per-line scores to {@code outputFile} plus AUC /
 * confusion / entropy summaries to {@code accurateFile}.
 *
 * @param args command-line arguments, parsed by {@link #parseArgs}
 * @throws Exception on any I/O or parsing failure
 */
static void mainToOutput(String[] args) throws Exception {
    if (parseArgs(args)) {
        // Default to the summary metrics when nothing specific was requested.
        if (!showAuc && !showConfusion && !showScores) {
            showAuc = true;
            showConfusion = true;
        }

        PrintWriter output = new PrintWriter(HdfsUtil.writeHdfs(outputFile), true);
        PrintWriter acc_output = new PrintWriter(HdfsUtil.writeHdfs(accurateFile), true);
        try {
            Auc collector = new Auc();
            LogisticModelParameters lmp = LogisticModelParameters.loadFrom(HdfsUtil.open(modelFile));

            CsvRecordFactory csv = lmp.getCsvRecordFactory();
            OnlineLogisticRegression lr = lmp.createRegression();
            BufferedReader in = new BufferedReader(new InputStreamReader(HdfsUtil.open(inputFile)));
            try {
                // First line is the CSV header with the variable names.
                String line = in.readLine();
                csv.firstLine(line);
                line = in.readLine();
                if (showScores) {
                    output.println("\"target\",\"model-output\",\"log-likelihood\"");
                }
                while (line != null) {
                    Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
                    int target = csv.processLine(line, v);

                    double score = lr.classifyScalar(v);
                    if (showScores) {
                        output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score,
                                lr.logLikelihood(target, v));
                    }
                    collector.add(target, score);
                    line = in.readLine();
                }
            } finally {
                // BUG FIX: the reader was previously never closed.
                in.close();
            }

            if (showAuc) {
                acc_output.printf(Locale.ENGLISH, "AUC , %.2f%n", collector.auc());
            }
            if (showConfusion) {
                Matrix m = collector.confusion();
                acc_output.printf(Locale.ENGLISH, "confusion, [[%.1f  %.1f], [%.1f  %.1f]]%n", m.get(0, 0),
                        m.get(1, 0), m.get(0, 1), m.get(1, 1));
                m = collector.entropy();
                acc_output.printf(Locale.ENGLISH, "entropy, [[%.1f  %.1f], [%.1f  %.1f]]%n", m.get(0, 0),
                        m.get(1, 0), m.get(0, 1), m.get(1, 1));
            }
        } finally {
            // BUG FIX: writers were not closed when an exception was thrown.
            output.close();
            acc_output.close();
        }
    }
}

From source file:haflow.component.mahout.logistic.TrainLogistic.java

License:Apache License

/**
 * Trains a logistic regression model from a CSV file on HDFS over several
 * passes, saves the model to HDFS, and writes a readable model summary
 * (feature count, formula, per-row weights and the beta matrix) to
 * {@code inforFile}.
 *
 * @param args command-line arguments, parsed by {@link #parseArgs}
 * @throws Exception on any I/O or parsing failure
 */
static void mainToOutput(String[] args) throws Exception {
    if (parseArgs(args)) {

        double logPEstimate = 0;
        int samples = 0;

        OutputStream o = HdfsUtil.writeHdfs(inforFile);
        PrintWriter output = new PrintWriter(o, true);
        try {
            CsvRecordFactory csv = lmp.getCsvRecordFactory();
            OnlineLogisticRegression lr = lmp.createRegression();
            for (int pass = 0; pass < passes; pass++) {
                BufferedReader in = new BufferedReader(new InputStreamReader(HdfsUtil.open(inputFile)));
                try {
                    // First line is the CSV header with the variable names.
                    csv.firstLine(in.readLine());

                    String line = in.readLine();

                    while (line != null) {
                        // for each new line, get target and predictors
                        Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                        int targetValue = csv.processLine(line, input);

                        // Evaluate before training on this example, so the
                        // estimate reflects performance on unseen data.
                        double logP = lr.logLikelihood(targetValue, input);
                        if (!Double.isInfinite(logP)) {
                            if (samples < 20) {
                                // Plain running mean for the first few samples...
                                logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
                            } else {
                                // ...then an exponential moving average.
                                logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
                            }
                            samples++;
                        }
                        double p = lr.classifyScalar(input);
                        if (scores) {
                            output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", samples,
                                    targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
                        }

                        // now update model
                        lr.train(targetValue, input);

                        line = in.readLine();
                    }
                } finally {
                    Closeables.close(in, true);
                }
            }

            // Persist the trained model to HDFS.
            OutputStream modelOutput = HdfsUtil.writeHdfs(outputFile);
            try {
                lmp.saveTo(modelOutput);
            } finally {
                Closeables.close(modelOutput, false);
            }

            // Dump a readable model summary: feature count, a formula built
            // from the non-zero weights, then per-row weights and beta matrix.
            output.println(lmp.getNumFeatures());
            output.println(lmp.getTargetVariable() + " ~ ");
            String sep = "";
            for (String v : csv.getTraceDictionary().keySet()) {
                double weight = predictorWeight(lr, 0, csv, v);
                if (weight != 0) {
                    output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
                    sep = " + ";
                }
            }
            output.printf("%n");
            model = lr;
            for (int row = 0; row < lr.getBeta().numRows(); row++) {
                for (String key : csv.getTraceDictionary().keySet()) {
                    double weight = predictorWeight(lr, row, csv, key);
                    if (weight != 0) {
                        output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight);
                    }
                }
                for (int column = 0; column < lr.getBeta().numCols(); column++) {
                    output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
                }
                output.println();
            }
        } finally {
            // BUG FIX: output was not closed when an exception was thrown.
            output.close();
        }
    }

}

From source file:OpioidePrescriberClassification.Driver.java

/**
 * Trains a two-category logistic regression on opioid-prescriber records
 * and prints the current learning rate and held-out AUC.
 */
public static void main(String args[]) throws Exception {
    List<Opioides> calls = Lists.newArrayList(new Parser("/input1/try.csv"));
    double heldOutPercentage = 0.10;
    {
        // Shuffle, then carve off the first 10% as the held-out evaluation set.
        Collections.shuffle(calls);
        int splitPoint = (int) (heldOutPercentage * calls.size());
        List<Opioides> heldOut = calls.subList(0, splitPoint);
        List<Opioides> trainingData = calls.subList(splitPoint, calls.size());

        OnlineLogisticRegression lr = new OnlineLogisticRegression(NUM_CATEGORIES, Opioides.FEATURES, new L1())
                .learningRate(1).alpha(1).lambda(0.000001).stepOffset(10000).decayExponent(0.2);

        {
            System.err.println("pass");
            // Single training sweep over the training records.
            for (Opioides record : trainingData) {
                lr.train(record.getTarget(), record.asVector());
            }
            {
                // Evaluate on the held-out records.
                Auc eval = new Auc(0.5);
                for (Opioides evalRecord : heldOut) {
                    eval.add(evalRecord.getTarget(), lr.classifyScalar(evalRecord.asVector()));
                }
                System.out.printf("%d, %.4f, %.4f\n", 1, lr.currentLearningRate(), eval.auc());
            }
        }
    }
}