List of usage examples for org.apache.mahout.classifier.sgd CsvRecordFactory processLine
@Override public int processLine(String line, Vector featureVector)
From source file:com.ml.ira.algos.RunLogistic.java
License:Apache License
static void mainToOutput(String[] args, PrintWriter output) throws Exception { if (parseArgs(args)) { if (!showAuc && !showConfusion && !showScores) { showAuc = true;/*from w w w .j ava 2s . c o m*/ showConfusion = true; } Auc collector = new Auc(); LogisticModelParameters lmp; if (modelFile.startsWith("hdfs://")) { lmp = LogisticModelParameters.loadFrom(new Path(modelFile)); } else { lmp = LogisticModelParameters.loadFrom(new File(modelFile)); } CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); BufferedReader in = TrainLogistic.open(inputFile); //String line = in.readLine(); //csv.firstLine(line); String line; if (fieldNames != null && fieldNames.equalsIgnoreCase("internal")) { csv.firstLine(lmp.getFieldNames()); } else { csv.firstLine(in.readLine()); } line = in.readLine(); if (showScores) { output.println("\"target\",\"model-output\",\"log-likelihood\""); } while (line != null) { Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); int target = csv.processLine(line, v); double score = lr.classifyScalar(v); if (showScores) { output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v)); } collector.add(target, score); line = in.readLine(); } if (showAuc) { output.printf(Locale.ENGLISH, "AUC = %.2f%n", collector.auc()); } if (showConfusion) { Matrix m = collector.confusion(); output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); m = collector.entropy(); output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); } } }
From source file:com.ml.ira.algos.TrainLogistic.java
License:Apache License
static void mainToOutput(String[] args, PrintWriter output) throws Exception { if (parseArgs(args)) { double logPEstimate = 0; int samples = 0; System.out.println("fieldNames: " + fieldNames); long ts = System.currentTimeMillis(); CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); for (int pass = 0; pass < passes; pass++) { System.out.println("at Round: " + pass); BufferedReader in = open(inputFile); try { // read variable names String line;// w ww . j a v a2s .c o m if (fieldNames != null && fieldNames.length() > 0) { csv.firstLine(fieldNames); } else { csv.firstLine(in.readLine()); } line = in.readLine(); while (line != null) { // for each new line, get target and predictors Vector input = new RandomAccessSparseVector(lmp.getNumFeatures()); int targetValue = csv.processLine(line, input); // check performance while this is still news double logP = lr.logLikelihood(targetValue, input); if (!Double.isInfinite(logP)) { if (samples < 20) { logPEstimate = (samples * logPEstimate + logP) / (samples + 1); } else { logPEstimate = 0.95 * logPEstimate + 0.05 * logP; } samples++; } double p = lr.classifyScalar(input); if (scores) { output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", samples, targetValue, lr.currentLearningRate(), p, logP, logPEstimate); } // now update model lr.train(targetValue, input); line = in.readLine(); } } finally { Closeables.close(in, true); } output.println("duration: " + (System.currentTimeMillis() - ts)); } if (outputFile.startsWith("hdfs://")) { lmp.saveTo(new Path(outputFile)); } else { OutputStream modelOutput = new FileOutputStream(outputFile); try { lmp.saveTo(modelOutput); } finally { Closeables.close(modelOutput, false); } } output.println("duration: " + (System.currentTimeMillis() - ts)); output.println(lmp.getNumFeatures()); output.println(lmp.getTargetVariable() + " ~ "); String sep = ""; for (String v : csv.getTraceDictionary().keySet()) { double weight = 
predictorWeight(lr, 0, csv, v); if (weight != 0) { output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v); sep = " + "; } } output.printf("%n"); model = lr; for (int row = 0; row < lr.getBeta().numRows(); row++) { for (String key : csv.getTraceDictionary().keySet()) { double weight = predictorWeight(lr, row, csv, key); if (weight != 0) { output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight); } } for (int column = 0; column < lr.getBeta().numCols(); column++) { output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column)); } output.println(); } } }
From source file:edu.isi.karma.cleaning.features.RecordClassifier2.java
License:Apache License
/**
 * Trains a multi-class online logistic regression model from the given training
 * data by first materializing it as a CSV file and then streaming that file
 * through Mahout's CSV record factory for a fixed number of passes.
 *
 * Side effects: initializes the {@code lmp} and {@code labels} fields.
 *
 * @param traindata map from class label to the example strings of that class
 * @return the trained regression model
 * @throws Exception if the temporary CSV file cannot be written or read
 */
public OnlineLogisticRegression train(HashMap<String, Vector<String>> traindata) throws Exception {
    String csvTrainFile = "./target/tmp/csvtrain.csv";
    Data2Features.Traindata2CSV(traindata, csvTrainFile, rf);
    lmp = new LogisticModelParameters();
    lmp.setTargetVariable("label");
    lmp.setMaxTargetCategories(rf.labels.size());
    lmp.setNumFeatures(rf.getFeatureNames().size());
    List<String> typeList = Lists.newArrayList();
    typeList.add("numeric");
    List<String> predictorList = Lists.newArrayList();
    for (String attr : rf.getFeatureNames()) {
        // NOTE(review): "lable" looks like a typo for the target variable "label",
        // in which case this filter never excludes anything. Confirm against the
        // header written by Data2Features.Traindata2CSV before changing it.
        if (attr.compareTo("lable") != 0) {
            predictorList.add(attr);
        }
    }
    lmp.setTypeMap(predictorList, typeList);
    lmp.setLambda(1e-4);
    lmp.setLearningRate(50);
    int passes = 100;
    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    OnlineLogisticRegression lr = lmp.createRegression();
    for (int pass = 0; pass < passes; pass++) {
        BufferedReader in = new BufferedReader(new FileReader(new File(csvTrainFile)));
        try {
            // First line of the CSV holds the variable names.
            csv.firstLine(in.readLine());
            String line = in.readLine();
            while (line != null) {
                // For each record, extract the target and predictor vector, then train.
                RandomAccessSparseVector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                int targetValue = csv.processLine(line, input);
                lr.train(targetValue, input);
                line = in.readLine();
            }
        } finally {
            // Replaces deprecated Closeables.closeQuietly; close errors are swallowed (logged).
            Closeables.close(in, true);
        }
    }
    labels = csv.getTargetCategories();
    return lr;
}
From source file:edu.isi.karma.cleaning.features.RecordClassifier2.java
License:Apache License
public String Classify(String instance) { Collection<Feature> cfeat = rf.computeFeatures(instance, ""); Feature[] x = cfeat.toArray(new Feature[cfeat.size()]); // row.add(f.getName()); RandomAccessSparseVector row = new RandomAccessSparseVector(x.length); String line = ""; for (int k = 0; k < cfeat.size(); k++) { line += x[k].getScore() + ","; }// w w w .j av a2 s .c o m line += "label"; // dummy class label for testing CsvRecordFactory csv = lmp.getCsvRecordFactory(); csv.processLine(line, row); DenseVector dvec = (DenseVector) this.cf.classifyFull(row); String label = labels.get(dvec.maxValueIndex()); return label; }
From source file:guipart.view.GUIOverviewController.java
@FXML void handleClassifyModel(ActionEvent event) throws IOException { if (pathModel != null && pathCSV != null) { Auc collector = new Auc(); LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(pathModel)); CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); BufferedReader in = Utils.open(pathCSV); String line = in.readLine(); csv.firstLine(line);/*ww w . jav a 2s . com*/ line = in.readLine(); int correct = 0; int wrong = 0; Boolean booltemp; String gender; while (line != null) { Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); int target = csv.processLine(line, v); String[] split = line.split(","); double score = lr.classifyFull(v).maxValueIndex(); if (score == target) correct++; else wrong++; System.out.println("Target is: " + target + " Score: " + score); booltemp = score != 0; if (split[1].contentEquals("1")) gender = "male"; else gender = "female"; Person temp = new Person(Integer.parseInt(split[0]), Integer.parseInt(split[4]), Integer.parseInt(split[7]), booltemp, gender, Integer.parseInt(split[5]), Integer.parseInt(split[6]), Integer.parseInt(split[3])); guiPart.addPerson(temp); line = in.readLine(); collector.add(target, score); } double posto = ((double) wrong / (double) (correct + wrong)) * 100; System.out.println("Total: " + (correct + wrong) + " Correct: " + correct + " Wrong: " + wrong + " Wrong pct: " + posto + "%"); //PrintWriter output = null; Matrix m = collector.confusion(); //output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n",m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); System.out.println("Confusion:" + m.get(0, 0) + " " + m.get(1, 0) + "\n \t " + m.get(0, 1) + " " + m.get(1, 1) + " "); // m = collector.entropy(); //output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n",m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); textAnalyze2.setText("Confusion:" + m.get(0, 0) + " " + m.get(1, 0) + "\n \t \t " + 
m.get(0, 1) + " " + m.get(1, 1) + "\n" + "Total: " + (correct + wrong) + " Correct: " + correct + " Wrong: " + wrong + " Wrong pct: " + posto + "%"); } else { Dialogs.create().owner(guiPart.getPrimaryStage()).title("Error Dialog") .masthead("Look, an Error Dialog").message("One or more files aren't selected").showError(); } }
From source file:guipart.view.GUIOverviewController.java
@FXML void singlClassify(ActionEvent e) throws IOException { LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(pathModel)); CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); csv.firstLine("custID,gender,state,cardholder,balance,numTrans,numIntlTrans,creditLine,fraudRisk"); String line;//from w w w.j av a 2 s .c om line = scID.getText(); line = line.concat("," + scGender.getText()); line = line.concat("," + scState.getText()); line = line.concat("," + scCardholders.getText()); line = line.concat("," + scBalance.getText()); line = line.concat("," + scTrans.getText()); line = line.concat("," + scIntlTrans.getText()); line = line.concat("," + scCreditLine.getText()); line = line.concat(",0 \n"); Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); int target = csv.processLine(line, v); String[] split = line.split(","); double score = lr.classifyFull(v).maxValueIndex(); boolean booltemp = score != 0; String gender; if (split[1].contentEquals("1")) gender = "male"; else gender = "female"; Person temp = new Person(Integer.parseInt(split[0]), Integer.parseInt(split[4]), Integer.parseInt(split[7]), booltemp, gender, Integer.parseInt(split[5]), Integer.parseInt(split[6]), Integer.parseInt(split[3])); guiPart.addPerson(temp); }
From source file:haflow.component.mahout.logistic.RunLogistic.java
License:Apache License
static void mainToOutput(String[] args) throws Exception { if (parseArgs(args)) { if (!showAuc && !showConfusion && !showScores) { showAuc = true;//from ww w . j a v a2s . c o m showConfusion = true; } //PrintWriter output=new PrintWriter(new FileOutputStream(outputFile),true); PrintWriter output = new PrintWriter(HdfsUtil.writeHdfs(outputFile), true); PrintWriter acc_output = new PrintWriter(HdfsUtil.writeHdfs(accurateFile), true); Auc collector = new Auc(); LogisticModelParameters lmp = LogisticModelParameters.loadFrom(HdfsUtil.open(modelFile)); CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); BufferedReader in = new BufferedReader(new InputStreamReader(HdfsUtil.open(inputFile))); String line = in.readLine(); csv.firstLine(line); line = in.readLine(); if (showScores) { output.println("\"target\",\"model-output\",\"log-likelihood\""); } while (line != null) { Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); int target = csv.processLine(line, v); double score = lr.classifyScalar(v); if (showScores) { output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v)); } collector.add(target, score); line = in.readLine(); } if (showAuc) { acc_output.printf(Locale.ENGLISH, "AUC , %.2f%n", collector.auc()); } if (showConfusion) { Matrix m = collector.confusion(); acc_output.printf(Locale.ENGLISH, "confusion, [[%.1f %.1f], [%.1f %.1f]]%n", m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); m = collector.entropy(); acc_output.printf(Locale.ENGLISH, "entropy, [[%.1f %.1f], [%.1f %.1f]]%n", m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); } output.close(); acc_output.close(); } }
From source file:haflow.component.mahout.logistic.TrainLogistic.java
License:Apache License
static void mainToOutput(String[] args) throws Exception { if (parseArgs(args)) { double logPEstimate = 0; int samples = 0; OutputStream o = HdfsUtil.writeHdfs(inforFile); PrintWriter output = new PrintWriter(o, true); CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); for (int pass = 0; pass < passes; pass++) { BufferedReader in = new BufferedReader(new InputStreamReader(HdfsUtil.open(inputFile))); try { // read variable names csv.firstLine(in.readLine()); String line = in.readLine(); while (line != null) { // for each new line, get target and predictors Vector input = new RandomAccessSparseVector(lmp.getNumFeatures()); int targetValue = csv.processLine(line, input); // check performance while this is still news double logP = lr.logLikelihood(targetValue, input); if (!Double.isInfinite(logP)) { if (samples < 20) { logPEstimate = (samples * logPEstimate + logP) / (samples + 1); } else { logPEstimate = 0.95 * logPEstimate + 0.05 * logP; }//w ww . 
ja va 2 s .c om samples++; } double p = lr.classifyScalar(input); if (scores) { output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", samples, targetValue, lr.currentLearningRate(), p, logP, logPEstimate); } // now update model lr.train(targetValue, input); line = in.readLine(); } } finally { Closeables.close(in, true); } } //OutputStream modelOutput = new FileOutputStream(outputFile); OutputStream modelOutput = HdfsUtil.writeHdfs(outputFile); try { lmp.saveTo(modelOutput); } finally { Closeables.close(modelOutput, false); } output.println(lmp.getNumFeatures()); output.println(lmp.getTargetVariable() + " ~ "); String sep = ""; for (String v : csv.getTraceDictionary().keySet()) { double weight = predictorWeight(lr, 0, csv, v); if (weight != 0) { output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v); sep = " + "; } } output.printf("%n"); model = lr; for (int row = 0; row < lr.getBeta().numRows(); row++) { for (String key : csv.getTraceDictionary().keySet()) { double weight = predictorWeight(lr, row, csv, key); if (weight != 0) { output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight); } } for (int column = 0; column < lr.getBeta().numCols(); column++) { output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column)); } output.println(); } output.close(); } }
From source file:javaapplication3.RunLogistic.java
public static void main(String[] args) throws IOException { // TODO code application logic here Auc collector = new Auc(); LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(modelFile)); CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); BufferedReader in = open(inputFile); String line = in.readLine();/* w ww . j a v a 2 s . c om*/ csv.firstLine(line); line = in.readLine(); int correct = 0; int wrong = 0; while (line != null) { Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); int target = csv.processLine(line, v); System.out.println(line); String[] split = line.split(","); double score = lr.classifyFull(v).maxValueIndex(); if (score == target) correct++; else wrong++; System.out.println("Target is: " + target + " Score: " + score); line = in.readLine(); collector.add(target, score); } double posto = ((double) wrong / (double) (correct + wrong)) * 100; System.out.println("Total: " + (correct + wrong) + " Correct: " + correct + " Wrong: " + wrong + " Wrong pct: " + posto + "%"); //PrintWriter output = null; Matrix m = collector.confusion(); //output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n",m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); System.out.println("Confusion:" + m.get(0, 0) + " " + m.get(1, 0) + "\n \t " + m.get(0, 1) + " " + m.get(1, 1) + " "); // m = collector.entropy(); //output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n",m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); }