List of usage examples for org.apache.mahout.classifier.sgd.L1
From source file:br.com.sitedoph.mahout_examples.BankMarketingClassificationMain.java
License:Apache License
/**
 * Trains an OnlineLogisticRegression classifier on the bank-marketing data set
 * ("bank-full.csv") and reports AUC on a 10% held-out slice. Runs 20 shuffled
 * rounds of 20 training passes each, also scoring one synthetic "unknown" call.
 */
public static void main(String[] args) throws Exception {
    List<TelephoneCall> calls = Lists.newArrayList(new TelephoneCallParser("bank-full.csv"));

    double heldOutPercentage = 0.10;
    double biggestScore = 0.0;

    for (int run = 0; run < 20; run++) {
        // Reshuffle every round so a different 10% slice is held out each time.
        Collections.shuffle(calls);
        int cutoff = (int) (heldOutPercentage * calls.size());
        List<TelephoneCall> testAccuracyData = calls.subList(0, cutoff);
        List<TelephoneCall> trainData = calls.subList(cutoff, calls.size());

        List<TelephoneCall> testUnknownData = new ArrayList<>();
        testUnknownData.add(getUnknownTelephoneCall(trainData));

        OnlineLogisticRegression lr = new OnlineLogisticRegression(NUM_CATEGORIES, TelephoneCall.FEATURES,
                new L1()).learningRate(1).alpha(1).lambda(0.000001).stepOffset(10000).decayExponent(0.2);

        for (int pass = 0; pass < 20; pass++) {
            // One full sweep over the training slice.
            for (TelephoneCall observation : trainData) {
                lr.train(observation.getTarget(), observation.asVector());
            }

            // Evaluate AUC on the held-out slice after this pass.
            Auc eval = new Auc(0.5);
            for (TelephoneCall testCall : testAccuracyData) {
                biggestScore = evaluateTheCallAndGetBiggestScore(biggestScore, lr, eval, testCall);
            }
            System.out.printf("run: %-5d pass: %-5d current learning rate: %-5.4f \teval auc %-5.4f\n", run,
                    pass, lr.currentLearningRate(), eval.auc());

            // Score the synthetic unknown call to eyeball model output.
            for (TelephoneCall testCall : testUnknownData) {
                final double score = lr.classifyScalar(testCall.asVector());
                System.out.println(" score: " + score + " accuracy " + eval.auc() + " call fields: "
                        + testCall.getFields());
            }
        }
    }
}
From source file:chapter4.src.logistic.LogisticModelParametersPredict.java
License:Apache License
/**
 * Lazily creates the logistic regression trainer configured from the
 * parameters collected here; subsequent calls return the cached instance.
 *
 * @return The newly allocated (or previously created) OnlineLogisticRegression object
 */
public OnlineLogisticRegression createRegression() {
    if (lr != null) {
        return lr;
    }
    lr = new OnlineLogisticRegression(getMaxTargetCategories(), getNumFeatures(), new L1())
            .lambda(getLambda()).learningRate(getLearningRate()).alpha(1 - 1.0e-3);
    return lr;
}
From source file:com.cloudera.knittingboar.records.TestTwentyNewsgroupsCustomRecordParseOLRRun.java
License:Apache License
@Test public void testRecordFactoryOnDatasetShard() throws Exception { // TODO a test with assertions is not a test // p.270 ----- metrics to track lucene's parsing mechanics, progress, // performance of OLR ------------ double averageLL = 0.0; double averageCorrect = 0.0; int k = 0;//from ww w .j av a 2 s.co m double step = 0.0; int[] bumps = new int[] { 1, 2, 5 }; TwentyNewsgroupsRecordFactory rec_factory = new TwentyNewsgroupsRecordFactory("\t"); // rec_factory.setClassSplitString("\t"); JobConf job = new JobConf(defaultConf); long block_size = localFs.getDefaultBlockSize(workDir); LOG.info("default block size: " + (block_size / 1024 / 1024) + "MB"); // matches the OLR setup on p.269 --------------- // stepOffset, decay, and alpha --- describe how the learning rate decreases // lambda: amount of regularization // learningRate: amount of initial learning rate @SuppressWarnings("resource") OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1()).alpha(1) .stepOffset(1000).decayExponent(0.9).lambda(3.0e-5).learningRate(20); FileInputFormat.setInputPaths(job, workDir); // try splitting the file in a variety of sizes TextInputFormat format = new TextInputFormat(); format.configure(job); Text value = new Text(); int numSplits = 1; InputSplit[] splits = format.getSplits(job, numSplits); LOG.info("requested " + numSplits + " splits, splitting: got = " + splits.length); LOG.info("---- debug splits --------- "); rec_factory.Debug(); int total_read = 0; for (int x = 0; x < splits.length; x++) { LOG.info("> Split [" + x + "]: " + splits[x].getLength()); int count = 0; InputRecordsSplit custom_reader = new InputRecordsSplit(job, splits[x]); while (custom_reader.next(value)) { Vector v = new RandomAccessSparseVector(TwentyNewsgroupsRecordFactory.FEATURES); int actual = rec_factory.processLine(value.toString(), v); String ng = rec_factory.GetNewsgroupNameByID(actual); // calc stats --------- double mu = Math.min(k + 1, 200); double ll = 
learningAlgorithm.logLikelihood(actual, v); averageLL = averageLL + (ll - averageLL) / mu; Vector p = new DenseVector(20); learningAlgorithm.classifyFull(p, v); int estimated = p.maxValueIndex(); int correct = (estimated == actual ? 1 : 0); averageCorrect = averageCorrect + (correct - averageCorrect) / mu; learningAlgorithm.train(actual, v); k++; int bump = bumps[(int) Math.floor(step) % bumps.length]; int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); if (k % (bump * scale) == 0) { step += 0.25; LOG.info(String.format("%10d %10.3f %10.3f %10.2f %s %s", k, ll, averageLL, averageCorrect * 100, ng, rec_factory.GetNewsgroupNameByID(estimated))); } learningAlgorithm.close(); count++; } LOG.info("read: " + count + " records for split " + x); total_read += count; } // for each split LOG.info("total read across all splits: " + total_read); rec_factory.Debug(); }
From source file:com.cloudera.knittingboar.sgd.olr.TestBaseOLR_Train20Newsgroups.java
License:Apache License
public void testTrainNewsGroups() throws IOException { File base = new File("/Users/jpatterson/Downloads/datasets/20news-bydate/20news-bydate-train/"); overallCounts = HashMultiset.create(); long startTime = System.currentTimeMillis(); // p.269 --------------------------------------------------------- Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>(); // encodes the text content in both the subject and the body of the email FeatureVectorEncoder encoder = new StaticWordValueEncoder("body"); encoder.setProbes(2);/*from www. j ava 2 s. co m*/ encoder.setTraceDictionary(traceDictionary); // provides a constant offset that the model can use to encode the average frequency // of each class FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept"); bias.setTraceDictionary(traceDictionary); // used to encode the number of lines in a message FeatureVectorEncoder lines = new ConstantValueEncoder("Lines"); lines.setTraceDictionary(traceDictionary); FeatureVectorEncoder logLines = new ConstantValueEncoder("LogLines"); logLines.setTraceDictionary(traceDictionary); Dictionary newsGroups = new Dictionary(); // matches the OLR setup on p.269 --------------- // stepOffset, decay, and alpha --- describe how the learning rate decreases // lambda: amount of regularization // learningRate: amount of initial learning rate OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1()).alpha(1) .stepOffset(1000).decayExponent(0.9).lambda(3.0e-5).learningRate(20); // bottom of p.269 ------------------------------ // because OLR expects to get integer class IDs for the target variable during training // we need a dictionary to convert the target variable (the newsgroup name) // to an integer, which is the newsGroup object List<File> files = new ArrayList<File>(); for (File newsgroup : base.listFiles()) { newsGroups.intern(newsgroup.getName()); files.addAll(Arrays.asList(newsgroup.listFiles())); } // mix up the files, 
helps training in OLR Collections.shuffle(files); System.out.printf("%d training files\n", files.size()); // p.270 ----- metrics to track lucene's parsing mechanics, progress, performance of OLR ------------ double averageLL = 0.0; double averageCorrect = 0.0; double averageLineCount = 0.0; int k = 0; double step = 0.0; int[] bumps = new int[] { 1, 2, 5 }; double lineCount = 0; // last line on p.269 Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31); Splitter onColon = Splitter.on(":").trimResults(); int input_file_count = 0; // ----- p.270 ------------ "reading and tokenzing the data" --------- for (File file : files) { BufferedReader reader = new BufferedReader(new FileReader(file)); input_file_count++; // identify newsgroup ---------------- // convert newsgroup name to unique id // ----------------------------------- String ng = file.getParentFile().getName(); int actual = newsGroups.intern(ng); Multiset<String> words = ConcurrentHashMultiset.create(); // check for line count header ------- String line = reader.readLine(); while (line != null && line.length() > 0) { // if this is a line that has a line count, let's pull that value out ------ if (line.startsWith("Lines:")) { String count = Iterables.get(onColon.split(line), 1); try { lineCount = Integer.parseInt(count); averageLineCount += (lineCount - averageLineCount) / Math.min(k + 1, 1000); } catch (NumberFormatException e) { // if anything goes wrong in parse: just use the avg count lineCount = averageLineCount; } } boolean countHeader = (line.startsWith("From:") || line.startsWith("Subject:") || line.startsWith("Keywords:") || line.startsWith("Summary:")); // loop through the lines in the file, while the line starts with: " " do { // get a reader for this specific string ------ StringReader in = new StringReader(line); // ---- count words in header --------- if (countHeader) { countWords(analyzer, words, in); } // iterate to the next string ---- line = reader.readLine(); } while (line.startsWith(" 
")); } // while (lines in header) { // -------- count words in body ---------- countWords(analyzer, words, reader); reader.close(); // ----- p.271 ----------- Vector v = new RandomAccessSparseVector(FEATURES); // original value does nothing in a ContantValueEncoder bias.addToVector("", 1, v); // original value does nothing in a ContantValueEncoder lines.addToVector("", lineCount / 30, v); // original value does nothing in a ContantValueEncoder logLines.addToVector("", Math.log(lineCount + 1), v); // now scan through all the words and add them for (String word : words.elementSet()) { encoder.addToVector(word, Math.log(1 + words.count(word)), v); } //Utils.PrintVectorNonZero(v); // calc stats --------- double mu = Math.min(k + 1, 200); double ll = learningAlgorithm.logLikelihood(actual, v); averageLL = averageLL + (ll - averageLL) / mu; Vector p = new DenseVector(20); learningAlgorithm.classifyFull(p, v); int estimated = p.maxValueIndex(); int correct = (estimated == actual ? 1 : 0); averageCorrect = averageCorrect + (correct - averageCorrect) / mu; learningAlgorithm.train(actual, v); k++; int bump = bumps[(int) Math.floor(step) % bumps.length]; int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); if (k % (bump * scale) == 0) { step += 0.25; System.out.printf("%10d %10.3f %10.3f %10.2f %s %s\n", k, ll, averageLL, averageCorrect * 100, ng, newsGroups.values().get(estimated)); } learningAlgorithm.close(); /* if (k>4) { break; } */ } Utils.PrintVectorSection(learningAlgorithm.getBeta().viewRow(0), 3); long endTime = System.currentTimeMillis(); //System.out.println("That took " + (endTime - startTime) + " milliseconds"); long duration = (endTime - startTime); System.out.println("Processed Input Files: " + input_file_count + ", time: " + duration + "ms"); ModelSerializer.writeBinary("/tmp/olr-news-group.model", learningAlgorithm); // learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); }
From source file:com.cloudera.knittingboar.sgd.TestParallelOnlineLogisticRegression.java
License:Apache License
/**
 * Verifies that the lambda passed through the fluent builder is retained by
 * the constructed ParallelOnlineLogisticRegression.
 */
public void testCreateLR() {
    int categories = 2;
    int numFeatures = 5;
    double lambda = 1.0e-4;
    double learning_rate = 50;

    ParallelOnlineLogisticRegression plr = new ParallelOnlineLogisticRegression(categories, numFeatures,
            new L1()).lambda(lambda).learningRate(learning_rate).alpha(1 - 1.0e-3);

    // BUG FIX: assertEquals(double, double) without a delta is deprecated and
    // the (actual, expected) argument order was swapped; use expected-first
    // with an explicit delta (0.0 keeps the original exact-match intent).
    assertEquals(1.0e-4, plr.getLambda(), 0.0);
}
From source file:com.cloudera.knittingboar.sgd.TestParallelOnlineLogisticRegression.java
License:Apache License
/**
 * Smoke test: repeated train() calls on the same sparse example must complete
 * without throwing.
 */
public void testTrainMechanics() {
    int categories = 2;
    int numFeatures = 5;
    double lambda = 1.0e-4;
    double learning_rate = 10;

    ParallelOnlineLogisticRegression plr = new ParallelOnlineLogisticRegression(categories, numFeatures,
            new L1()).lambda(lambda).learningRate(learning_rate).alpha(1 - 1.0e-3);

    // Build a single example where feature i carries value i (0,1,2,3,4).
    Vector input = new RandomAccessSparseVector(numFeatures);
    for (int i = 0; i < numFeatures; i++) {
        input.set(i, i);
    }

    // Train on the same example several times to exercise the update path.
    plr.train(0, input);
    plr.train(0, input);
    plr.train(0, input);
}
From source file:com.cloudera.knittingboar.sgd.TestParallelOnlineLogisticRegression.java
License:Apache License
public void testPOLRInternalBuffers() { System.out.println("testPOLRInternalBuffers --------------"); int categories = 2; int numFeatures = 5; double lambda = 1.0e-4; double learning_rate = 10; ArrayList<Vector> trainingSet_0 = new ArrayList<Vector>(); for (int s = 0; s < 1; s++) { Vector input = new RandomAccessSparseVector(numFeatures); for (int x = 0; x < numFeatures; x++) { input.set(x, x);//ww w . ja v a2 s . c om } trainingSet_0.add(input); } // for ParallelOnlineLogisticRegression plr_agent_0 = new ParallelOnlineLogisticRegression(categories, numFeatures, new L1()).lambda(lambda).learningRate(learning_rate).alpha(1 - 1.0e-3); System.out.println("Beta: "); //Utils.PrintVectorNonZero(plr_agent_0.getBeta().getRow(0)); Utils.PrintVectorNonZero(plr_agent_0.getBeta().viewRow(0)); System.out.println("\nGamma: "); //Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().getRow(0)); Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().viewRow(0)); plr_agent_0.train(0, trainingSet_0.get(0)); System.out.println("Beta: "); //Utils.PrintVectorNonZero(plr_agent_0.noReallyGetBeta().getRow(0)); Utils.PrintVectorNonZero(plr_agent_0.noReallyGetBeta().viewRow(0)); System.out.println("\nGamma: "); //Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().getRow(0)); Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().viewRow(0)); }
From source file:com.cloudera.knittingboar.sgd.TestParallelOnlineLogisticRegression.java
License:Apache License
public void testLocalGradientFlush() { System.out.println("\n\n\ntestLocalGradientFlush --------------"); int categories = 2; int numFeatures = 5; double lambda = 1.0e-4; double learning_rate = 10; ArrayList<Vector> trainingSet_0 = new ArrayList<Vector>(); for (int s = 0; s < 1; s++) { Vector input = new RandomAccessSparseVector(numFeatures); for (int x = 0; x < numFeatures; x++) { input.set(x, x);//from w ww .j a v a 2 s. c o m } trainingSet_0.add(input); } // for ParallelOnlineLogisticRegression plr_agent_0 = new ParallelOnlineLogisticRegression(categories, numFeatures, new L1()).lambda(lambda).learningRate(learning_rate).alpha(1 - 1.0e-3); plr_agent_0.train(0, trainingSet_0.get(0)); System.out.println("\nGamma: "); Utils.PrintVectorNonZero(plr_agent_0.gamma.getMatrix().viewRow(0)); plr_agent_0.FlushGamma(); System.out.println("Flushing Gamma ...... "); System.out.println("\nGamma: "); Utils.PrintVector(plr_agent_0.gamma.getMatrix().viewRow(0)); for (int x = 0; x < numFeatures; x++) { assertEquals(plr_agent_0.gamma.getMatrix().get(0, x), 0.0); } }
From source file:com.memonews.mahout.sentiment.SentimentModelTrainer.java
License:Apache License
public static void main(final String[] args) throws IOException { final File base = new File(args[0]); final String modelPath = args.length > 1 ? args[1] : "target/model"; final Multiset<String> overallCounts = HashMultiset.create(); final Dictionary newsGroups = new Dictionary(); final SentimentModelHelper helper = new SentimentModelHelper(); helper.getEncoder().setProbes(2);// w w w .j a v a2 s . c om final AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(2, SentimentModelHelper.FEATURES, new L1()); learningAlgorithm.setInterval(800); learningAlgorithm.setAveragingWindow(500); final List<File> files = Lists.newArrayList(); for (final File newsgroup : base.listFiles()) { if (newsgroup.isDirectory()) { newsGroups.intern(newsgroup.getName()); files.addAll(Arrays.asList(newsgroup.listFiles())); } } Collections.shuffle(files); System.out.printf("%d training files\n", files.size()); final SGDInfo info = new SGDInfo(); int k = 0; for (final File file : files) { final String ng = file.getParentFile().getName(); final int actual = newsGroups.intern(ng); final Vector v = helper.encodeFeatureVector(file, overallCounts); learningAlgorithm.train(actual, v); k++; final State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest(); SGDHelper.analyzeState(info, 0, k, best); } learningAlgorithm.close(); SGDHelper.dissect(0, newsGroups, learningAlgorithm, files, overallCounts); System.out.println("exiting main"); ModelSerializer.writeBinary(modelPath, learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); final List<Integer> counts = Lists.newArrayList(); System.out.printf("Word counts\n"); for (final String count : overallCounts.elementSet()) { counts.add(overallCounts.count(count)); } Collections.sort(counts, Ordering.natural().reverse()); k = 0; for (final Integer count : counts) { System.out.printf("%d\t%d\n", k, count); k++; if (k > 1000) { break; } } }
From source file:com.ml.ira.algos.AdaptiveLogisticModelParameters.java
License:Apache License
private static PriorFunction createPrior(String cmd, double priorOption) { if (cmd == null) { return null; }//from w w w. j a v a2 s .co m if ("L1".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { return new L1(); } if ("L2".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { return new L2(); } if ("UP".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { return new UniformPrior(); } if ("TP".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { return new TPrior(priorOption); } if ("EBP".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { return new ElasticBandPrior(priorOption); } return null; }