Usage examples for `org.apache.mahout.classifier.sgd.OnlineLogisticRegression#classifyScalar(Vector)`, collected from open-source projects.
@Override public double classifyScalar(Vector instance)
From source file:TrainLogistic.java
License:Apache License
static void mainToOutput(String[] args, PrintWriter output) throws Exception { if (parseArgs(args)) { double logPEstimate = 0; int samples = 0; /*read files in dir of inputFile*/ int fi = 0;//file ID File file = new File(inputFile); String[] fns = file.list(new FilenameFilter() { public boolean accept(File dir, String name) { if (name.endsWith(".svm")) { return true; } else { return false; }//w w w. j av a 2 s . c o m } }); String[] ss = new String[lmp.getNumFeatures() + 1]; String[] iv = new String[2]; OnlineLogisticRegression lr = lmp.createRegression(); while (fi < fns.length) { for (int pass = 0; pass < passes; pass++) { BufferedReader in = open(inputFile + fns[fi]); System.out.println(pass + 1); try { // read variable names String line = in.readLine(); int lineCount = 1; while (line != null) { // for each new line, get target and predictors Vector input = new RandomAccessSparseVector(lmp.getNumFeatures()); ss = line.split(" "); int targetValue; if (ss[0].startsWith("+")) targetValue = 1; else targetValue = 0; int k = 1; while (k < ss.length) { iv = ss[k].split(":"); input.setQuick(Integer.valueOf(iv[0]) - 1, Double.valueOf(iv[1])); //System.out.printf("%d-----%d:%.4f====%d\n", k,Integer.valueOf(iv[0])-1,Double.valueOf(iv[1]),lineCount); k++; } input.setQuick(lmp.getNumFeatures() - 1, 1); // check performance while this is still news double logP = lr.logLikelihood(targetValue, input); if (!Double.isInfinite(logP)) { if (samples < 20) { logPEstimate = (samples * logPEstimate + logP) / (samples + 1); } else { logPEstimate = 0.95 * logPEstimate + 0.05 * logP; } samples++; } double p = lr.classifyScalar(input); if (scores) { output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f\n", samples, targetValue, lr.currentLearningRate(), p, logP, logPEstimate); } // now update model lr.train(targetValue, input); if ((lineCount) % 1000 == 0) System.out.printf("%d\t", lineCount); line = in.readLine(); lineCount++; } } finally { Closeables.closeQuietly(in); } 
System.out.println(); } fi++; } FileOutputStream modelOutput = new FileOutputStream(outputFile); try { saveTo(modelOutput, lr); } finally { Closeables.closeQuietly(modelOutput); } /* output.printf(Locale.ENGLISH, "%d\n", lmp.getNumFeatures()); output.printf(Locale.ENGLISH, "%s ~ ", lmp.getTargetVariable()); String sep = ""; for (String v : csv.getTraceDictionary().keySet()) { double weight = predictorWeight(lr, 0, csv, v); if (weight != 0) { output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v); sep = " + "; } } output.printf("\n"); model = lr; for (int row = 0; row < lr.getBeta().numRows(); row++) { for (String key : csv.getTraceDictionary().keySet()) { double weight = predictorWeight(lr, row, csv, key); if (weight != 0) { output.printf(Locale.ENGLISH, "%20s %.5f\n", key, weight); } } for (int column = 0; column < lr.getBeta().numCols(); column++) { output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column)); } output.println(); }*/ } }
From source file:br.com.sitedoph.mahout_examples.BankMarketingClassificationMain.java
License:Apache License
/**
 * Trains a logistic-regression model on the bank-marketing data set and reports
 * AUC on a held-out slice, repeating over 20 shuffled runs of 20 passes each.
 */
public static void main(String[] args) throws Exception {
    List<TelephoneCall> allCalls = Lists.newArrayList(new TelephoneCallParser("bank-full.csv"));

    double holdoutFraction = 0.10;
    double bestScore = 0.0;

    for (int run = 0; run < 20; run++) {
        // Reshuffle so every run sees a different train/test split.
        Collections.shuffle(allCalls);
        int cutoff = (int) (holdoutFraction * allCalls.size());
        List<TelephoneCall> evalSet = allCalls.subList(0, cutoff);
        List<TelephoneCall> trainingSet = allCalls.subList(cutoff, allCalls.size());

        List<TelephoneCall> unknownSet = new ArrayList<>();
        unknownSet.add(getUnknownTelephoneCall(trainingSet));

        OnlineLogisticRegression model =
                new OnlineLogisticRegression(NUM_CATEGORIES, TelephoneCall.FEATURES, new L1())
                        .learningRate(1).alpha(1).lambda(0.000001).stepOffset(10000).decayExponent(0.2);

        for (int pass = 0; pass < 20; pass++) {
            // One full sweep of online updates over the training slice.
            for (TelephoneCall sample : trainingSet) {
                model.train(sample.getTarget(), sample.asVector());
            }

            // Measure AUC on the held-out slice, tracking the largest score seen.
            Auc eval = new Auc(0.5);
            for (TelephoneCall sample : evalSet) {
                bestScore = evaluateTheCallAndGetBiggestScore(bestScore, model, eval, sample);
            }
            System.out.printf("run: %-5d pass: %-5d current learning rate: %-5.4f \teval auc %-5.4f\n",
                    run, pass, model.currentLearningRate(), eval.auc());

            // Score the synthetic "unknown" call against the current model.
            for (TelephoneCall sample : unknownSet) {
                final double score = model.classifyScalar(sample.asVector());
                System.out.println(" score: " + score + " accuracy " + eval.auc()
                        + " call fields: " + sample.getFields());
            }
        }
    }
}
From source file:br.com.sitedoph.mahout_examples.BankMarketingClassificationMain.java
License:Apache License
/**
 * Scores one call with the model, feeds the (target, score) pair into the AUC
 * collector, and returns the running maximum score seen so far. Prints a
 * diagnostic line whenever a new maximum is observed.
 */
private static double evaluateTheCallAndGetBiggestScore(double biggestScore, OnlineLogisticRegression lr,
        Auc eval, TelephoneCall call) {
    final double score = lr.classifyScalar(call.asVector());
    eval.add(call.getTarget(), score);
    if (score <= biggestScore) {
        return biggestScore;
    }
    System.out.println("### SCORE > BIGGESTSCORE ### score: " + score + " accuracy " + eval.auc()
            + " call fields: " + call.getFields());
    return score;
}
From source file:com.ml.ira.algos.RunLogistic.java
License:Apache License
static void mainToOutput(String[] args, PrintWriter output) throws Exception { if (parseArgs(args)) { if (!showAuc && !showConfusion && !showScores) { showAuc = true;/*from w w w. j av a2s .co m*/ showConfusion = true; } Auc collector = new Auc(); LogisticModelParameters lmp; if (modelFile.startsWith("hdfs://")) { lmp = LogisticModelParameters.loadFrom(new Path(modelFile)); } else { lmp = LogisticModelParameters.loadFrom(new File(modelFile)); } CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); BufferedReader in = TrainLogistic.open(inputFile); //String line = in.readLine(); //csv.firstLine(line); String line; if (fieldNames != null && fieldNames.equalsIgnoreCase("internal")) { csv.firstLine(lmp.getFieldNames()); } else { csv.firstLine(in.readLine()); } line = in.readLine(); if (showScores) { output.println("\"target\",\"model-output\",\"log-likelihood\""); } while (line != null) { Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); int target = csv.processLine(line, v); double score = lr.classifyScalar(v); if (showScores) { output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v)); } collector.add(target, score); line = in.readLine(); } if (showAuc) { output.printf(Locale.ENGLISH, "AUC = %.2f%n", collector.auc()); } if (showConfusion) { Matrix m = collector.confusion(); output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); m = collector.entropy(); output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); } } }
From source file:com.ml.ira.algos.TrainLogistic.java
License:Apache License
static void mainToOutput(String[] args, PrintWriter output) throws Exception { if (parseArgs(args)) { double logPEstimate = 0; int samples = 0; System.out.println("fieldNames: " + fieldNames); long ts = System.currentTimeMillis(); CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); for (int pass = 0; pass < passes; pass++) { System.out.println("at Round: " + pass); BufferedReader in = open(inputFile); try { // read variable names String line;/*from w ww . jav a2 s . c o m*/ if (fieldNames != null && fieldNames.length() > 0) { csv.firstLine(fieldNames); } else { csv.firstLine(in.readLine()); } line = in.readLine(); while (line != null) { // for each new line, get target and predictors Vector input = new RandomAccessSparseVector(lmp.getNumFeatures()); int targetValue = csv.processLine(line, input); // check performance while this is still news double logP = lr.logLikelihood(targetValue, input); if (!Double.isInfinite(logP)) { if (samples < 20) { logPEstimate = (samples * logPEstimate + logP) / (samples + 1); } else { logPEstimate = 0.95 * logPEstimate + 0.05 * logP; } samples++; } double p = lr.classifyScalar(input); if (scores) { output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", samples, targetValue, lr.currentLearningRate(), p, logP, logPEstimate); } // now update model lr.train(targetValue, input); line = in.readLine(); } } finally { Closeables.close(in, true); } output.println("duration: " + (System.currentTimeMillis() - ts)); } if (outputFile.startsWith("hdfs://")) { lmp.saveTo(new Path(outputFile)); } else { OutputStream modelOutput = new FileOutputStream(outputFile); try { lmp.saveTo(modelOutput); } finally { Closeables.close(modelOutput, false); } } output.println("duration: " + (System.currentTimeMillis() - ts)); output.println(lmp.getNumFeatures()); output.println(lmp.getTargetVariable() + " ~ "); String sep = ""; for (String v : csv.getTraceDictionary().keySet()) { double weight = 
predictorWeight(lr, 0, csv, v); if (weight != 0) { output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v); sep = " + "; } } output.printf("%n"); model = lr; for (int row = 0; row < lr.getBeta().numRows(); row++) { for (String key : csv.getTraceDictionary().keySet()) { double weight = predictorWeight(lr, row, csv, key); if (weight != 0) { output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight); } } for (int column = 0; column < lr.getBeta().numCols(); column++) { output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column)); } output.println(); } } }
From source file:haflow.component.mahout.logistic.RunLogistic.java
License:Apache License
static void mainToOutput(String[] args) throws Exception { if (parseArgs(args)) { if (!showAuc && !showConfusion && !showScores) { showAuc = true;/*ww w .ja va2s.co m*/ showConfusion = true; } //PrintWriter output=new PrintWriter(new FileOutputStream(outputFile),true); PrintWriter output = new PrintWriter(HdfsUtil.writeHdfs(outputFile), true); PrintWriter acc_output = new PrintWriter(HdfsUtil.writeHdfs(accurateFile), true); Auc collector = new Auc(); LogisticModelParameters lmp = LogisticModelParameters.loadFrom(HdfsUtil.open(modelFile)); CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); BufferedReader in = new BufferedReader(new InputStreamReader(HdfsUtil.open(inputFile))); String line = in.readLine(); csv.firstLine(line); line = in.readLine(); if (showScores) { output.println("\"target\",\"model-output\",\"log-likelihood\""); } while (line != null) { Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); int target = csv.processLine(line, v); double score = lr.classifyScalar(v); if (showScores) { output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v)); } collector.add(target, score); line = in.readLine(); } if (showAuc) { acc_output.printf(Locale.ENGLISH, "AUC , %.2f%n", collector.auc()); } if (showConfusion) { Matrix m = collector.confusion(); acc_output.printf(Locale.ENGLISH, "confusion, [[%.1f %.1f], [%.1f %.1f]]%n", m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); m = collector.entropy(); acc_output.printf(Locale.ENGLISH, "entropy, [[%.1f %.1f], [%.1f %.1f]]%n", m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); } output.close(); acc_output.close(); } }
From source file:haflow.component.mahout.logistic.TrainLogistic.java
License:Apache License
static void mainToOutput(String[] args) throws Exception { if (parseArgs(args)) { double logPEstimate = 0; int samples = 0; OutputStream o = HdfsUtil.writeHdfs(inforFile); PrintWriter output = new PrintWriter(o, true); CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); for (int pass = 0; pass < passes; pass++) { BufferedReader in = new BufferedReader(new InputStreamReader(HdfsUtil.open(inputFile))); try { // read variable names csv.firstLine(in.readLine()); String line = in.readLine(); while (line != null) { // for each new line, get target and predictors Vector input = new RandomAccessSparseVector(lmp.getNumFeatures()); int targetValue = csv.processLine(line, input); // check performance while this is still news double logP = lr.logLikelihood(targetValue, input); if (!Double.isInfinite(logP)) { if (samples < 20) { logPEstimate = (samples * logPEstimate + logP) / (samples + 1); } else { logPEstimate = 0.95 * logPEstimate + 0.05 * logP; }/*from w ww. 
ja v a2 s.c o m*/ samples++; } double p = lr.classifyScalar(input); if (scores) { output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", samples, targetValue, lr.currentLearningRate(), p, logP, logPEstimate); } // now update model lr.train(targetValue, input); line = in.readLine(); } } finally { Closeables.close(in, true); } } //OutputStream modelOutput = new FileOutputStream(outputFile); OutputStream modelOutput = HdfsUtil.writeHdfs(outputFile); try { lmp.saveTo(modelOutput); } finally { Closeables.close(modelOutput, false); } output.println(lmp.getNumFeatures()); output.println(lmp.getTargetVariable() + " ~ "); String sep = ""; for (String v : csv.getTraceDictionary().keySet()) { double weight = predictorWeight(lr, 0, csv, v); if (weight != 0) { output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v); sep = " + "; } } output.printf("%n"); model = lr; for (int row = 0; row < lr.getBeta().numRows(); row++) { for (String key : csv.getTraceDictionary().keySet()) { double weight = predictorWeight(lr, row, csv, key); if (weight != 0) { output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight); } } for (int column = 0; column < lr.getBeta().numCols(); column++) { output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column)); } output.println(); } output.close(); } }
From source file:OpioidePrescriberClassification.Driver.java
public static void main(String args[]) throws Exception { List<Opioides> calls = Lists.newArrayList(new Parser("/input1/try.csv")); double heldOutPercentage = 0.10; // for (int run = 0; run < 20; run++) {// w w w . j av a 2 s . c o m // Random random = RandomUtils.getRandom(); Collections.shuffle(calls); int cutoff = (int) (heldOutPercentage * calls.size()); List<Opioides> test = calls.subList(0, cutoff); List<Opioides> train = calls.subList(cutoff, calls.size()); OnlineLogisticRegression lr = new OnlineLogisticRegression(NUM_CATEGORIES, Opioides.FEATURES, new L1()) .learningRate(1).alpha(1).lambda(0.000001).stepOffset(10000).decayExponent(0.2); // for (int pass = 0; pass < 2 ; pass++) { System.err.println("pass"); for (Opioides observation : train) { lr.train(observation.getTarget(), observation.asVector()); } // if (pass % 2 == 0) { Auc eval = new Auc(0.5); for (Opioides testCall : test) { eval.add(testCall.getTarget(), lr.classifyScalar(testCall.asVector())); } System.out.printf("%d, %.4f, %.4f\n", 1, lr.currentLearningRate(), eval.auc()); } } } }