Usage examples for the org.apache.mahout.vectorizer.encoders.ConstantValueEncoder constructor:
public ConstantValueEncoder(String name)
From source file:SimpleCsvExamples.java
License:Apache License
public static void main(String[] args) throws IOException { FeatureVectorEncoder[] encoder = new FeatureVectorEncoder[FIELDS]; for (int i = 0; i < FIELDS; i++) { encoder[i] = new ConstantValueEncoder("v" + 1); }//from www . j av a 2 s . c o m OnlineSummarizer[] s = new OnlineSummarizer[FIELDS]; for (int i = 0; i < FIELDS; i++) { s[i] = new OnlineSummarizer(); } long t0 = System.currentTimeMillis(); Vector v = new DenseVector(1000); if ("--generate".equals(args[0])) { PrintWriter out = new PrintWriter( new OutputStreamWriter(new FileOutputStream(new File(args[2])), Charsets.UTF_8)); try { int n = Integer.parseInt(args[1]); for (int i = 0; i < n; i++) { Line x = Line.generate(); out.println(x); } } finally { Closeables.close(out, false); } } else if ("--parse".equals(args[0])) { BufferedReader in = Files.newReader(new File(args[1]), Charsets.UTF_8); double total = 0; try { String line = in.readLine(); while (line != null) { v.assign(0); Line x = new Line(line); for (int i = 0; i < FIELDS; i++) { double z = x.getDouble(i); total += z; //s[i].add(x.getDouble(i)); encoder[i].addToVector(x.get(i), v); } line = in.readLine(); } } finally { Closeables.close(in, true); } // String separator = ""; // for (int i = 0; i < FIELDS; i++) { // System.out.printf("%s%.3f", separator, s[i].getMean()); // separator = ","; // } System.out.println("total: " + total); } else if ("--fast".equals(args[0])) { FastLineReader in = new FastLineReader(new FileInputStream(args[1])); double total = 0; try { FastLine line = in.read(); while (line != null) { v.assign(0); for (int i = 0; i < FIELDS; i++) { double z = line.getDouble(i); total += z; //s[i].add(z); encoder[i].addToVector((byte[]) null, z, v); } line = in.read(); } } finally { Closeables.close(in, true); } // String separator = ""; // for (int i = 0; i < FIELDS; i++) { // System.out.printf("%s%.3f", separator, s[i].getMean()); // separator = ","; // } System.out.println("total: " + total); } System.out.printf("\nElapsed time = %.3f%n", 
(System.currentTimeMillis() - t0) / 1000.0); }
From source file:com.cloudera.knittingboar.records.RCV1RecordFactory.java
License:Apache License
/**
 * Creates a record factory whose body-feature weights are all emitted
 * through a single constant-value encoder named "body_values".
 */
public RCV1RecordFactory() {
    encoder = new ConstantValueEncoder("body_values");
}
From source file:com.cloudera.knittingboar.records.RCV1RecordFactory.java
License:Apache License
public static void ScanFile(String file, int debug_break_cnt) throws IOException { ConstantValueEncoder encoder_test = new ConstantValueEncoder("test"); BufferedReader reader = null; // Collection<String> words int line_count = 0; Multiset<String> class_count = ConcurrentHashMultiset.create(); Multiset<String> namespaces = ConcurrentHashMultiset.create(); try {/*from w ww . j ava 2 s. c o m*/ // System.out.println( newsgroup ); reader = new BufferedReader(new FileReader(file)); String line = reader.readLine(); while (line != null && line.length() > 0) { // shard_writer.write(line + "\n"); // out += line; String[] parts = line.split(" "); // System.out.println( "Class: " + parts[0] ); class_count.add(parts[0]); namespaces.add(parts[1]); line = reader.readLine(); line_count++; Vector v = new RandomAccessSparseVector(FEATURES); for (int x = 2; x < parts.length; x++) { // encoder_test.addToVector(parts[x], v); // System.out.println( parts[x] ); String[] feature = parts[x].split(":"); int index = Integer.parseInt(feature[0]) % FEATURES; double val = Double.parseDouble(feature[1]); // System.out.println( feature[1] + " = " + val ); if (index < FEATURES) { v.set(index, val); } else { System.out.println("Could Hash: " + index + " to " + (index % FEATURES)); } } Utils.PrintVectorSectionNonZero(v, 10); System.out.println("###"); if (line_count > debug_break_cnt) { break; } } System.out.println("Total Rec Count: " + line_count); System.out.println("-------------------- "); System.out.println("Classes"); for (String word : class_count.elementSet()) { System.out.println("Class " + word + ": " + class_count.count(word) + " "); } System.out.println("-------------------- "); System.out.println("NameSpaces:"); for (String word : namespaces.elementSet()) { System.out.println("Namespace " + word + ": " + namespaces.count(word) + " "); } /* * TokenStream ts = analyzer.tokenStream("text", reader); * ts.addAttribute(CharTermAttribute.class); * * // for each word in the stream, minus 
non-word stuff, add word to * collection while (ts.incrementToken()) { String s = * ts.getAttribute(CharTermAttribute.class).toString(); * //System.out.print( " " + s ); //words.add(s); out += s + " "; } */ } finally { reader.close(); } // return out + "\n"; }
From source file:com.cloudera.knittingboar.records.Test20NewsgroupsBookParsing.java
License:Apache License
public void test20NewsgroupsFileScan() throws IOException { // p.270 ----- metrics to track lucene's parsing mechanics, progress, performance of OLR ------------ double averageLL = 0.0; double averageCorrect = 0.0; double averageLineCount = 0.0; int k = 0;// ww w.j a va2s . c o m double step = 0.0; int[] bumps = new int[] { 1, 2, 5 }; double lineCount = 0; Splitter onColon = Splitter.on(":").trimResults(); // last line on p.269 Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31); File base = new File("/Users/jpatterson/Downloads/datasets/20news-bydate/20-debug/"); overallCounts = HashMultiset.create(); // p.269 --------------------------------------------------------- Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>(); // encodes the text content in both the subject and the body of the email FeatureVectorEncoder encoder = new StaticWordValueEncoder("body"); encoder.setProbes(2); encoder.setTraceDictionary(traceDictionary); // provides a constant offset that the model can use to encode the average frequency // of each class FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept"); bias.setTraceDictionary(traceDictionary); // used to encode the number of lines in a message FeatureVectorEncoder lines = new ConstantValueEncoder("Lines"); lines.setTraceDictionary(traceDictionary); Dictionary newsGroups = new Dictionary(); // bottom of p.269 ------------------------------ // because OLR expects to get integer class IDs for the target variable during training // we need a dictionary to convert the target variable (the newsgroup name) // to an integer, which is the newsGroup object List<File> files = new ArrayList<File>(); for (File newsgroup : base.listFiles()) { newsGroups.intern(newsgroup.getName()); System.out.println(">> " + newsgroup.getName()); files.addAll(Arrays.asList(newsgroup.listFiles())); } // mix up the files, helps training in OLR Collections.shuffle(files); System.out.printf("%d training files\n", 
files.size()); // ----- p.270 ------------ "reading and tokenzing the data" --------- for (File file : files) { BufferedReader reader = new BufferedReader(new FileReader(file)); // identify newsgroup ---------------- // convert newsgroup name to unique id // ----------------------------------- String ng = file.getParentFile().getName(); int actual = newsGroups.intern(ng); Multiset<String> words = ConcurrentHashMultiset.create(); // check for line count header ------- String line = reader.readLine(); while (line != null && line.length() > 0) { // if this is a line that has a line count, let's pull that value out ------ if (line.startsWith("Lines:")) { String count = Iterables.get(onColon.split(line), 1); try { lineCount = Integer.parseInt(count); averageLineCount += (lineCount - averageLineCount) / Math.min(k + 1, 1000); } catch (NumberFormatException e) { // if anything goes wrong in parse: just use the avg count lineCount = averageLineCount; } } // which header words to actually count boolean countHeader = (line.startsWith("From:") || line.startsWith("Subject:") || line.startsWith("Keywords:") || line.startsWith("Summary:")); // we're still looking at the header at this point // loop through the lines in the file, while the line starts with: " " do { // get a reader for this specific string ------ StringReader in = new StringReader(line); // ---- count words in header --------- if (countHeader) { //System.out.println( "#### countHeader ################*************" ); countWords(analyzer, words, in); } // iterate to the next string ---- line = reader.readLine(); } while (line.startsWith(" ")); //System.out.println("[break]"); } // now we're done with the header //System.out.println("[break-header]"); // -------- count words in body ---------- countWords(analyzer, words, reader); reader.close(); /* for (String word : words.elementSet()) { //encoder.addToVector(word, Math.log(1 + words.count(word)), v); System.out.println( "> " + word + ", " + words.count(word) ); 
} */ } }
From source file:com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory.java
License:Apache License
/** * Processes single line of input into: - target variable - Feature vector * /*ww w .j a va 2 s . c o m*/ * @throws Exception */ public int processLine(String line, Vector v) throws Exception { String[] parts = line.split(this.class_id_split_string); if (parts.length < 2) { throw new Exception("wtf: line not formed well."); } String newsgroup_name = parts[0]; String msg = parts[1]; // p.269 --------------------------------------------------------- Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>(); // encodes the text content in both the subject and the body of the email FeatureVectorEncoder encoder = new StaticWordValueEncoder("body"); encoder.setProbes(2); encoder.setTraceDictionary(traceDictionary); // provides a constant offset that the model can use to encode the average // frequency // of each class FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept"); bias.setTraceDictionary(traceDictionary); int actual = newsGroups.intern(newsgroup_name); // newsGroups.values().contains(arg0) // System.out.println( "> newsgroup name: " + newsgroup_name ); // System.out.println( "> newsgroup id: " + actual ); Multiset<String> words = ConcurrentHashMultiset.create(); /* * // System.out.println("record: "); for ( int x = 1; x < parts.length; x++ * ) { //String s = ts.getAttribute(CharTermAttribute.class).toString(); // * System.out.print( " " + parts[x] ); String foo = parts[x].trim(); * System.out.print( " " + foo ); words.add( foo ); * * } // System.out.println("\nEOR"); System.out.println( "\nwords found: " + * (parts.length - 1) ); System.out.println( "words in set: " + words.size() * + ", " + words.toString() ); */ StringReader in = new StringReader(msg); countWords(analyzer, words, in); // ----- p.271 ----------- // Vector v = new RandomAccessSparseVector(FEATURES); // original value does nothing in a ContantValueEncoder bias.addToVector("", 1, v); // original value does nothing in a ContantValueEncoder // 
lines.addToVector("", lineCount / 30, v); // original value does nothing in a ContantValueEncoder // logLines.addToVector("", Math.log(lineCount + 1), v); // now scan through all the words and add them // System.out.println( "############### " + words.toArray().length); for (String word : words.elementSet()) { encoder.addToVector(word, Math.log(1 + words.count(word)), v); // System.out.print( words.count(word) + " " ); } // System.out.println("\nEOL\n"); return actual; }
From source file:com.cloudera.knittingboar.sgd.olr.TestBaseOLR_Train20Newsgroups.java
License:Apache License
public void testTrainNewsGroups() throws IOException { File base = new File("/Users/jpatterson/Downloads/datasets/20news-bydate/20news-bydate-train/"); overallCounts = HashMultiset.create(); long startTime = System.currentTimeMillis(); // p.269 --------------------------------------------------------- Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>(); // encodes the text content in both the subject and the body of the email FeatureVectorEncoder encoder = new StaticWordValueEncoder("body"); encoder.setProbes(2);/* w w w . ja va2 s. com*/ encoder.setTraceDictionary(traceDictionary); // provides a constant offset that the model can use to encode the average frequency // of each class FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept"); bias.setTraceDictionary(traceDictionary); // used to encode the number of lines in a message FeatureVectorEncoder lines = new ConstantValueEncoder("Lines"); lines.setTraceDictionary(traceDictionary); FeatureVectorEncoder logLines = new ConstantValueEncoder("LogLines"); logLines.setTraceDictionary(traceDictionary); Dictionary newsGroups = new Dictionary(); // matches the OLR setup on p.269 --------------- // stepOffset, decay, and alpha --- describe how the learning rate decreases // lambda: amount of regularization // learningRate: amount of initial learning rate OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1()).alpha(1) .stepOffset(1000).decayExponent(0.9).lambda(3.0e-5).learningRate(20); // bottom of p.269 ------------------------------ // because OLR expects to get integer class IDs for the target variable during training // we need a dictionary to convert the target variable (the newsgroup name) // to an integer, which is the newsGroup object List<File> files = new ArrayList<File>(); for (File newsgroup : base.listFiles()) { newsGroups.intern(newsgroup.getName()); files.addAll(Arrays.asList(newsgroup.listFiles())); } // mix up the files, 
helps training in OLR Collections.shuffle(files); System.out.printf("%d training files\n", files.size()); // p.270 ----- metrics to track lucene's parsing mechanics, progress, performance of OLR ------------ double averageLL = 0.0; double averageCorrect = 0.0; double averageLineCount = 0.0; int k = 0; double step = 0.0; int[] bumps = new int[] { 1, 2, 5 }; double lineCount = 0; // last line on p.269 Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31); Splitter onColon = Splitter.on(":").trimResults(); int input_file_count = 0; // ----- p.270 ------------ "reading and tokenzing the data" --------- for (File file : files) { BufferedReader reader = new BufferedReader(new FileReader(file)); input_file_count++; // identify newsgroup ---------------- // convert newsgroup name to unique id // ----------------------------------- String ng = file.getParentFile().getName(); int actual = newsGroups.intern(ng); Multiset<String> words = ConcurrentHashMultiset.create(); // check for line count header ------- String line = reader.readLine(); while (line != null && line.length() > 0) { // if this is a line that has a line count, let's pull that value out ------ if (line.startsWith("Lines:")) { String count = Iterables.get(onColon.split(line), 1); try { lineCount = Integer.parseInt(count); averageLineCount += (lineCount - averageLineCount) / Math.min(k + 1, 1000); } catch (NumberFormatException e) { // if anything goes wrong in parse: just use the avg count lineCount = averageLineCount; } } boolean countHeader = (line.startsWith("From:") || line.startsWith("Subject:") || line.startsWith("Keywords:") || line.startsWith("Summary:")); // loop through the lines in the file, while the line starts with: " " do { // get a reader for this specific string ------ StringReader in = new StringReader(line); // ---- count words in header --------- if (countHeader) { countWords(analyzer, words, in); } // iterate to the next string ---- line = reader.readLine(); } while (line.startsWith(" 
")); } // while (lines in header) { // -------- count words in body ---------- countWords(analyzer, words, reader); reader.close(); // ----- p.271 ----------- Vector v = new RandomAccessSparseVector(FEATURES); // original value does nothing in a ContantValueEncoder bias.addToVector("", 1, v); // original value does nothing in a ContantValueEncoder lines.addToVector("", lineCount / 30, v); // original value does nothing in a ContantValueEncoder logLines.addToVector("", Math.log(lineCount + 1), v); // now scan through all the words and add them for (String word : words.elementSet()) { encoder.addToVector(word, Math.log(1 + words.count(word)), v); } //Utils.PrintVectorNonZero(v); // calc stats --------- double mu = Math.min(k + 1, 200); double ll = learningAlgorithm.logLikelihood(actual, v); averageLL = averageLL + (ll - averageLL) / mu; Vector p = new DenseVector(20); learningAlgorithm.classifyFull(p, v); int estimated = p.maxValueIndex(); int correct = (estimated == actual ? 1 : 0); averageCorrect = averageCorrect + (correct - averageCorrect) / mu; learningAlgorithm.train(actual, v); k++; int bump = bumps[(int) Math.floor(step) % bumps.length]; int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); if (k % (bump * scale) == 0) { step += 0.25; System.out.printf("%10d %10.3f %10.3f %10.2f %s %s\n", k, ll, averageLL, averageCorrect * 100, ng, newsGroups.values().get(estimated)); } learningAlgorithm.close(); /* if (k>4) { break; } */ } Utils.PrintVectorSection(learningAlgorithm.getBeta().viewRow(0), 3); long endTime = System.currentTimeMillis(); //System.out.println("That took " + (endTime - startTime) + " milliseconds"); long duration = (endTime - startTime); System.out.println("Processed Input Files: " + input_file_count + ", time: " + duration + "ms"); ModelSerializer.writeBinary("/tmp/olr-news-group.model", learningAlgorithm); // learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); }
From source file:org.deidentifier.arx.aggregates.classification.MultiClassLogisticRegression.java
License:Apache License
/** * Creates a new instance//from ww w .j a v a2 s . c o m * @param specification * @param config */ public MultiClassLogisticRegression(ClassificationDataSpecification specification, ARXLogisticRegressionConfiguration config) { // Store this.config = config; this.specification = specification; // Prepare classifier PriorFunction prior = null; switch (config.getPriorFunction()) { case ELASTIC_BAND: prior = new ElasticBandPrior(); break; case L1: prior = new L1(); break; case L2: prior = new L2(); break; case UNIFORM: prior = new UniformPrior(); break; default: throw new IllegalArgumentException("Unknown prior function"); } this.lr = new OnlineLogisticRegression(this.specification.classMap.size(), config.getVectorLength(), prior); // Configure this.lr.learningRate(config.getLearningRate()); this.lr.alpha(config.getAlpha()); this.lr.lambda(config.getLambda()); this.lr.stepOffset(config.getStepOffset()); this.lr.decayExponent(config.getDecayExponent()); // Prepare encoders this.interceptEncoder = new ConstantValueEncoder("intercept"); this.wordEncoder = new StaticWordValueEncoder("feature"); // Configure this.lr.learningRate(1); this.lr.alpha(1); this.lr.lambda(0.000001); this.lr.stepOffset(10000); this.lr.decayExponent(0.2); }
From source file:org.deidentifier.arx.aggregates.classification.MultiClassNaiveBayes.java
License:Apache License
/** * Creates a new instance/* w w w . ja va 2s . c om*/ * @param interrupt * @param specification * @param config * @param inputHandle */ public MultiClassNaiveBayes(WrappedBoolean interrupt, ClassificationDataSpecification specification, ClassificationConfigurationNaiveBayes config, DataHandleInternal inputHandle) { super(interrupt); // Store this.config = config; this.specification = specification; this.inputHandle = inputHandle; // Prepare classifier this.nb = new NaiveBayes(config.getType() == Type.BERNOULLI ? Model.BERNOULLI : Model.MULTINOMIAL, this.specification.classMap.size(), config.getVectorLength(), config.getSigma(), null); // Prepare encoders this.interceptEncoder = new ConstantValueEncoder("intercept"); this.wordEncoder = new StaticWordValueEncoder("feature"); }