List of usage examples for org.apache.mahout.classifier.sgd.OnlineLogisticRegression.getBeta()
public Matrix getBeta()
From source file:TrainLogistic.java
License:Apache License
private static void saveTo(FileOutputStream modelOutput, OnlineLogisticRegression lr) { PrintWriter w = new PrintWriter(new OutputStreamWriter(modelOutput)); String str = new String(" "); //System.out.printf("%d columns\n",lr.getBeta().numCols()); System.out.printf("Now, writing file...\n"); for (int column = 0; column < lr.getBeta().numCols(); column++) { //System.out.printf("%f, ", lr.getBeta().get(0, column)); str = java.lang.String.format("%f\n", lr.getBeta().get(0, column)); w.write(str);//w w w . j ava 2 s. c o m w.flush(); } w.close(); }
From source file:TrainLogistic.java
License:Apache License
/**
 * Adds up the beta coefficients of one row over all columns that the record factory's trace
 * dictionary lists for the given predictor.
 *
 * @param lr        model supplying the beta matrix
 * @param row       row of the beta matrix to read
 * @param csv       record factory whose trace dictionary is keyed by predictor name
 * @param predictor name of the predictor to look up in the trace dictionary
 * @return the sum of {@code beta(row, column)} over the predictor's columns
 */
private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
    Iterable<Integer> predictorColumns = csv.getTraceDictionary().get(predictor);
    double total = 0;
    for (Integer columnIndex : predictorColumns) {
        total += lr.getBeta().get(row, columnIndex);
    }
    return total;
}
From source file:com.cloudera.knittingboar.sgd.olr.TestBaseOLR_Train20Newsgroups.java
License:Apache License
public void testTrainNewsGroups() throws IOException { File base = new File("/Users/jpatterson/Downloads/datasets/20news-bydate/20news-bydate-train/"); overallCounts = HashMultiset.create(); long startTime = System.currentTimeMillis(); // p.269 --------------------------------------------------------- Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>(); // encodes the text content in both the subject and the body of the email FeatureVectorEncoder encoder = new StaticWordValueEncoder("body"); encoder.setProbes(2);/*from w ww .java 2 s .c o m*/ encoder.setTraceDictionary(traceDictionary); // provides a constant offset that the model can use to encode the average frequency // of each class FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept"); bias.setTraceDictionary(traceDictionary); // used to encode the number of lines in a message FeatureVectorEncoder lines = new ConstantValueEncoder("Lines"); lines.setTraceDictionary(traceDictionary); FeatureVectorEncoder logLines = new ConstantValueEncoder("LogLines"); logLines.setTraceDictionary(traceDictionary); Dictionary newsGroups = new Dictionary(); // matches the OLR setup on p.269 --------------- // stepOffset, decay, and alpha --- describe how the learning rate decreases // lambda: amount of regularization // learningRate: amount of initial learning rate OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1()).alpha(1) .stepOffset(1000).decayExponent(0.9).lambda(3.0e-5).learningRate(20); // bottom of p.269 ------------------------------ // because OLR expects to get integer class IDs for the target variable during training // we need a dictionary to convert the target variable (the newsgroup name) // to an integer, which is the newsGroup object List<File> files = new ArrayList<File>(); for (File newsgroup : base.listFiles()) { newsGroups.intern(newsgroup.getName()); files.addAll(Arrays.asList(newsgroup.listFiles())); } // mix up the files, 
helps training in OLR Collections.shuffle(files); System.out.printf("%d training files\n", files.size()); // p.270 ----- metrics to track lucene's parsing mechanics, progress, performance of OLR ------------ double averageLL = 0.0; double averageCorrect = 0.0; double averageLineCount = 0.0; int k = 0; double step = 0.0; int[] bumps = new int[] { 1, 2, 5 }; double lineCount = 0; // last line on p.269 Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31); Splitter onColon = Splitter.on(":").trimResults(); int input_file_count = 0; // ----- p.270 ------------ "reading and tokenzing the data" --------- for (File file : files) { BufferedReader reader = new BufferedReader(new FileReader(file)); input_file_count++; // identify newsgroup ---------------- // convert newsgroup name to unique id // ----------------------------------- String ng = file.getParentFile().getName(); int actual = newsGroups.intern(ng); Multiset<String> words = ConcurrentHashMultiset.create(); // check for line count header ------- String line = reader.readLine(); while (line != null && line.length() > 0) { // if this is a line that has a line count, let's pull that value out ------ if (line.startsWith("Lines:")) { String count = Iterables.get(onColon.split(line), 1); try { lineCount = Integer.parseInt(count); averageLineCount += (lineCount - averageLineCount) / Math.min(k + 1, 1000); } catch (NumberFormatException e) { // if anything goes wrong in parse: just use the avg count lineCount = averageLineCount; } } boolean countHeader = (line.startsWith("From:") || line.startsWith("Subject:") || line.startsWith("Keywords:") || line.startsWith("Summary:")); // loop through the lines in the file, while the line starts with: " " do { // get a reader for this specific string ------ StringReader in = new StringReader(line); // ---- count words in header --------- if (countHeader) { countWords(analyzer, words, in); } // iterate to the next string ---- line = reader.readLine(); } while (line.startsWith(" 
")); } // while (lines in header) { // -------- count words in body ---------- countWords(analyzer, words, reader); reader.close(); // ----- p.271 ----------- Vector v = new RandomAccessSparseVector(FEATURES); // original value does nothing in a ContantValueEncoder bias.addToVector("", 1, v); // original value does nothing in a ContantValueEncoder lines.addToVector("", lineCount / 30, v); // original value does nothing in a ContantValueEncoder logLines.addToVector("", Math.log(lineCount + 1), v); // now scan through all the words and add them for (String word : words.elementSet()) { encoder.addToVector(word, Math.log(1 + words.count(word)), v); } //Utils.PrintVectorNonZero(v); // calc stats --------- double mu = Math.min(k + 1, 200); double ll = learningAlgorithm.logLikelihood(actual, v); averageLL = averageLL + (ll - averageLL) / mu; Vector p = new DenseVector(20); learningAlgorithm.classifyFull(p, v); int estimated = p.maxValueIndex(); int correct = (estimated == actual ? 1 : 0); averageCorrect = averageCorrect + (correct - averageCorrect) / mu; learningAlgorithm.train(actual, v); k++; int bump = bumps[(int) Math.floor(step) % bumps.length]; int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); if (k % (bump * scale) == 0) { step += 0.25; System.out.printf("%10d %10.3f %10.3f %10.2f %s %s\n", k, ll, averageLL, averageCorrect * 100, ng, newsGroups.values().get(estimated)); } learningAlgorithm.close(); /* if (k>4) { break; } */ } Utils.PrintVectorSection(learningAlgorithm.getBeta().viewRow(0), 3); long endTime = System.currentTimeMillis(); //System.out.println("That took " + (endTime - startTime) + " milliseconds"); long duration = (endTime - startTime); System.out.println("Processed Input Files: " + input_file_count + ", time: " + duration + "ms"); ModelSerializer.writeBinary("/tmp/olr-news-group.model", learningAlgorithm); // learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); }
From source file:com.memonews.mahout.sentiment.SGDHelper.java
License:Apache License
static void analyzeState(final SGDInfo info, final int leakType, final int k, final State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best) throws IOException { final int bump = info.getBumps()[(int) Math.floor(info.getStep()) % info.getBumps().length]; final int scale = (int) Math.pow(10, Math.floor(info.getStep() / info.getBumps().length)); double maxBeta; double nonZeros; double positive; double norm;/* www . j ava2s . c o m*/ double lambda = 0; double mu = 0; if (best != null) { final CrossFoldLearner state = best.getPayload().getLearner(); info.setAverageCorrect(state.percentCorrect()); info.setAverageLL(state.logLikelihood()); final OnlineLogisticRegression model = state.getModels().get(0); // finish off pending regularization model.close(); final Matrix beta = model.getBeta(); maxBeta = beta.aggregate(Functions.MAX, Functions.ABS); nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() { @Override public double apply(final double v) { return Math.abs(v) > 1.0e-6 ? 1 : 0; } }); positive = beta.aggregate(Functions.PLUS, new DoubleFunction() { @Override public double apply(final double v) { return v > 0 ? 1 : 0; } }); norm = beta.aggregate(Functions.PLUS, Functions.ABS); lambda = best.getMappedParams()[0]; mu = best.getMappedParams()[1]; } else { maxBeta = 0; nonZeros = 0; positive = 0; norm = 0; } if (k % (bump * scale) == 0) { if (best != null) { ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model", best.getPayload().getLearner().getModels().get(0)); } info.setStep(info.getStep() + 0.25); System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu); System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, info.getAverageLL(), info.getAverageCorrect() * 100, LEAK_LABELS[leakType % 3]); } }
From source file:com.ml.ira.algos.TrainLogistic.java
License:Apache License
static void mainToOutput(String[] args, PrintWriter output) throws Exception { if (parseArgs(args)) { double logPEstimate = 0; int samples = 0; System.out.println("fieldNames: " + fieldNames); long ts = System.currentTimeMillis(); CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); for (int pass = 0; pass < passes; pass++) { System.out.println("at Round: " + pass); BufferedReader in = open(inputFile); try { // read variable names String line;/*from w w w. j a v a2 s.c o m*/ if (fieldNames != null && fieldNames.length() > 0) { csv.firstLine(fieldNames); } else { csv.firstLine(in.readLine()); } line = in.readLine(); while (line != null) { // for each new line, get target and predictors Vector input = new RandomAccessSparseVector(lmp.getNumFeatures()); int targetValue = csv.processLine(line, input); // check performance while this is still news double logP = lr.logLikelihood(targetValue, input); if (!Double.isInfinite(logP)) { if (samples < 20) { logPEstimate = (samples * logPEstimate + logP) / (samples + 1); } else { logPEstimate = 0.95 * logPEstimate + 0.05 * logP; } samples++; } double p = lr.classifyScalar(input); if (scores) { output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", samples, targetValue, lr.currentLearningRate(), p, logP, logPEstimate); } // now update model lr.train(targetValue, input); line = in.readLine(); } } finally { Closeables.close(in, true); } output.println("duration: " + (System.currentTimeMillis() - ts)); } if (outputFile.startsWith("hdfs://")) { lmp.saveTo(new Path(outputFile)); } else { OutputStream modelOutput = new FileOutputStream(outputFile); try { lmp.saveTo(modelOutput); } finally { Closeables.close(modelOutput, false); } } output.println("duration: " + (System.currentTimeMillis() - ts)); output.println(lmp.getNumFeatures()); output.println(lmp.getTargetVariable() + " ~ "); String sep = ""; for (String v : csv.getTraceDictionary().keySet()) { double weight = 
predictorWeight(lr, 0, csv, v); if (weight != 0) { output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v); sep = " + "; } } output.printf("%n"); model = lr; for (int row = 0; row < lr.getBeta().numRows(); row++) { for (String key : csv.getTraceDictionary().keySet()) { double weight = predictorWeight(lr, row, csv, key); if (weight != 0) { output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight); } } for (int column = 0; column < lr.getBeta().numCols(); column++) { output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column)); } output.println(); } } }
From source file:com.tdunning.ch16.train.TrainNewsGroups.java
License:Apache License
public static void main(String[] args) throws IOException { File base = new File(args[0]); int leakType = 0; if (args.length > 1) { leakType = Integer.parseInt(args[1]); }/* w ww.j a v a 2s .c o m*/ Dictionary newsGroups = new Dictionary(); encoder.setProbes(2); AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1()); learningAlgorithm.setInterval(800); learningAlgorithm.setAveragingWindow(500); List<File> files = Lists.newArrayList(); File[] directories = base.listFiles(); Arrays.sort(directories, Ordering.usingToString()); for (File newsgroup : directories) { if (newsgroup.isDirectory()) { newsGroups.intern(newsgroup.getName()); files.addAll(Arrays.asList(newsgroup.listFiles())); } } Collections.shuffle(files); System.out.printf("%d training files\n", files.size()); System.out.printf("%s\n", Arrays.asList(directories)); double averageLL = 0; double averageCorrect = 0; int k = 0; double step = 0; int[] bumps = { 1, 2, 5 }; for (File file : files) { String ng = file.getParentFile().getName(); int actual = newsGroups.intern(ng); Vector v = encodeFeatureVector(file); learningAlgorithm.train(actual, v); k++; int bump = bumps[(int) Math.floor(step) % bumps.length]; int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest(); double maxBeta; double nonZeros; double positive; double norm; double lambda = 0; double mu = 0; if (best != null) { CrossFoldLearner state = best.getPayload().getLearner(); averageCorrect = state.percentCorrect(); averageLL = state.logLikelihood(); OnlineLogisticRegression model = state.getModels().get(0); // finish off pending regularization model.close(); Matrix beta = model.getBeta(); maxBeta = beta.aggregate(Functions.MAX, Functions.ABS); nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() { @Override public double apply(double v) { return Math.abs(v) > 1.0e-6 ? 
1 : 0; } }); positive = beta.aggregate(Functions.PLUS, new DoubleFunction() { @Override public double apply(double v) { return v > 0 ? 1 : 0; } }); norm = beta.aggregate(Functions.PLUS, Functions.ABS); lambda = learningAlgorithm.getBest().getMappedParams()[0]; mu = learningAlgorithm.getBest().getMappedParams()[1]; } else { maxBeta = 0; nonZeros = 0; positive = 0; norm = 0; } if (k % (bump * scale) == 0) { if (learningAlgorithm.getBest() != null) { ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model", learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); } step += 0.25; System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu); System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, averageLL, averageCorrect * 100, LEAK_LABELS[leakType % 3]); } } learningAlgorithm.close(); dissect(newsGroups, learningAlgorithm, files); System.out.println("exiting main"); ModelSerializer.writeBinary("/tmp/news-group.model", learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); }
From source file:haflow.component.mahout.logistic.TrainLogistic.java
License:Apache License
static void mainToOutput(String[] args) throws Exception { if (parseArgs(args)) { double logPEstimate = 0; int samples = 0; OutputStream o = HdfsUtil.writeHdfs(inforFile); PrintWriter output = new PrintWriter(o, true); CsvRecordFactory csv = lmp.getCsvRecordFactory(); OnlineLogisticRegression lr = lmp.createRegression(); for (int pass = 0; pass < passes; pass++) { BufferedReader in = new BufferedReader(new InputStreamReader(HdfsUtil.open(inputFile))); try { // read variable names csv.firstLine(in.readLine()); String line = in.readLine(); while (line != null) { // for each new line, get target and predictors Vector input = new RandomAccessSparseVector(lmp.getNumFeatures()); int targetValue = csv.processLine(line, input); // check performance while this is still news double logP = lr.logLikelihood(targetValue, input); if (!Double.isInfinite(logP)) { if (samples < 20) { logPEstimate = (samples * logPEstimate + logP) / (samples + 1); } else { logPEstimate = 0.95 * logPEstimate + 0.05 * logP; }/* w ww . ja va 2 s . 
c om*/ samples++; } double p = lr.classifyScalar(input); if (scores) { output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", samples, targetValue, lr.currentLearningRate(), p, logP, logPEstimate); } // now update model lr.train(targetValue, input); line = in.readLine(); } } finally { Closeables.close(in, true); } } //OutputStream modelOutput = new FileOutputStream(outputFile); OutputStream modelOutput = HdfsUtil.writeHdfs(outputFile); try { lmp.saveTo(modelOutput); } finally { Closeables.close(modelOutput, false); } output.println(lmp.getNumFeatures()); output.println(lmp.getTargetVariable() + " ~ "); String sep = ""; for (String v : csv.getTraceDictionary().keySet()) { double weight = predictorWeight(lr, 0, csv, v); if (weight != 0) { output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v); sep = " + "; } } output.printf("%n"); model = lr; for (int row = 0; row < lr.getBeta().numRows(); row++) { for (String key : csv.getTraceDictionary().keySet()) { double weight = predictorWeight(lr, row, csv, key); if (weight != 0) { output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight); } } for (int column = 0; column < lr.getBeta().numCols(); column++) { output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column)); } output.println(); } output.close(); } }