List of usage examples for the org.apache.mahout.math.RandomAccessSparseVector constructor
public RandomAccessSparseVector(Vector other)
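A minimal standalone sketch of this constructor (class name and values here are illustrative, assuming mahout-math is on the classpath): copying any Vector into a RandomAccessSparseVector keeps the same cardinality but stores only the non-default (non-zero) entries in a hash-backed map.

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public class CopyConstructorSketch {
    public static void main(String[] args) {
        // dense source with two non-zero entries
        Vector dense = new DenseVector(new double[] { 0.0, 1.5, 0.0, 3.0 });
        // sparse copy: same size and values, but only non-zeros are stored
        Vector sparse = new RandomAccessSparseVector(dense);
        System.out.println(sparse.size());                     // 4
        System.out.println(sparse.getNumNondefaultElements()); // 2
        System.out.println(sparse.get(3));                     // 3.0
    }
}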
From source file:ClassifierHD.java
License:Apache License
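Builds a TF-IDF RandomAccessSparseVector for each document read from an HDFS directory, classifies it with a Mahout naive Bayes model, and writes the predicted label to a PostgreSQL table and a rep.list file.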
public static void main(String[] args) throws Exception {
    if (args.length < 6) {
        System.out.println(
                "Arguments: [model] [label index] [dictionary] [document frequency] [postgres table] [hdfs dir] [job_id]");
        return;
    }
    String modelPath = args[0];
    String labelIndexPath = args[1];
    String dictionaryPath = args[2];
    String documentFrequencyPath = args[3];
    String tablename = args[4];
    String inputDir = args[5];

    Configuration configuration = new Configuration();

    // model is a matrix (wordId, labelId) => probability score
    NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), configuration);
    StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);

    // labels is a map label => classId
    Map<Integer, String> labels = BayesUtils.readLabelIndex(configuration, new Path(labelIndexPath));
    Map<String, Integer> dictionary = readDictionnary(configuration, new Path(dictionaryPath));
    Map<Integer, Long> documentFrequency = readDocumentFrequency(configuration,
            new Path(documentFrequencyPath));

    // analyzer used to extract words from each document
    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_43);

    int labelCount = labels.size();
    int documentCount = documentFrequency.get(-1).intValue();

    System.out.println("Number of labels: " + labelCount);
    System.out.println("Number of documents in training set: " + documentCount);

    Connection conn = null;
    PreparedStatement pstmt = null;

    try {
        Class.forName("org.postgresql.Driver");
        conn = DriverManager.getConnection("jdbc:postgresql://192.168.50.170:5432/uzeni", "postgres",
                "dbwpsdkdl");
        conn.setAutoCommit(false);
        String sql = "INSERT INTO " + tablename
                + " (id,gtime,wtime,target,num,link,body,rep) VALUES (?,?,?,?,?,?,?,?);";
        pstmt = conn.prepareStatement(sql);

        FileSystem fs = FileSystem.get(configuration);
        FileStatus[] status = fs.listStatus(new Path(inputDir));
        BufferedWriter bw = new BufferedWriter(
                new OutputStreamWriter(fs.create(new Path(inputDir + "/rep.list"), true)));

        for (int i = 0; i < status.length; i++) {
            // skip the output file itself
            if (status[i].getPath().getName().equals("rep.list")) {
                continue;
            }
            BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(status[i].getPath())));

            // the first line of each file is a CSV header: gtime,wtime,target,num,link
            int lv_HEAD = 1;
            int lv_cnt = 0;
            String lv_gtime = null;
            String lv_wtime = null;
            String lv_target = null;
            BigDecimal lv_num = null;
            String lv_link = null;
            String[] lv_args;
            String lv_line;
            StringBuilder lv_txt = new StringBuilder();
            while ((lv_line = br.readLine()) != null) {
                if (lv_cnt < lv_HEAD) {
                    lv_args = lv_line.split(",");
                    lv_gtime = lv_args[0];
                    lv_wtime = lv_args[1];
                    lv_target = lv_args[2];
                    lv_num = new BigDecimal(lv_args[3]);
                    lv_link = lv_args[4];
                } else {
                    lv_txt.append(lv_line + '\n');
                }
                lv_cnt++;
            }
            br.close();

            String id = status[i].getPath().getName();
            String message = lv_txt.toString();

            Multiset<String> words = ConcurrentHashMultiset.create();

            // extract words from the document body
            TokenStream ts = analyzer.tokenStream("text", new StringReader(message));
            CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
            ts.reset();
            int wordCount = 0;
            while (ts.incrementToken()) {
                if (termAtt.length() > 0) {
                    String word = ts.getAttribute(CharTermAttribute.class).toString();
                    Integer wordId = dictionary.get(word);
                    // if the word is not in the dictionary, skip it
                    if (wordId != null) {
                        words.add(word);
                        wordCount++;
                    }
                }
            }
            ts.end();
            ts.close();

            // create vector wordId => weight using tfidf
            Vector vector = new RandomAccessSparseVector(10000);
            TFIDF tfidf = new TFIDF();
            for (Multiset.Entry<String> entry : words.entrySet()) {
                String word = entry.getElement();
                int count = entry.getCount();
                Integer wordId = dictionary.get(word);
                Long freq = documentFrequency.get(wordId);
                double tfIdfValue = tfidf.calculate(count, freq.intValue(), wordCount, documentCount);
                vector.setQuick(wordId, tfIdfValue);
            }

            // the classifier returns one score per label; the highest wins
            Vector resultVector = classifier.classifyFull(vector);
            double bestScore = -Double.MAX_VALUE;
            int bestCategoryId = -1;
            for (Element element : resultVector.all()) {
                int categoryId = element.index();
                double score = element.get();
                if (score > bestScore) {
                    bestScore = score;
                    bestCategoryId = categoryId;
                }
            }

            pstmt.setString(1, id);
            pstmt.setString(2, lv_gtime);
            pstmt.setString(3, lv_wtime);
            pstmt.setString(4, lv_target);
            pstmt.setBigDecimal(5, lv_num);
            pstmt.setString(6, lv_link);
            // store at most the first 50 characters of the body
            pstmt.setString(7, message.substring(0, Math.min(50, message.length())));
            pstmt.setString(8, labels.get(bestCategoryId));
            pstmt.addBatch();
            bw.write(id + "\t" + labels.get(bestCategoryId) + "\n");
        }
        pstmt.executeBatch();
        pstmt.close();
        conn.commit();
        conn.close();
        bw.close();
    } catch (Exception e) {
        System.err.println(e.getClass().getName() + ": " + e.getMessage());
        System.exit(0);
    }
    analyzer.close();
}
From source file:PostgresClassifier.java
License:Apache License
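The same TF-IDF vectorization as above, but the documents come from rows of a PostgreSQL table (those with a null rep column); the predicted label is written back with an UPDATE per row.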
public static void main(String[] args) throws Exception {
    if (args.length < 5) {
        System.out.println(
                "Arguments: [model] [label index] [dictionary] [document frequency] [input postgres table]");
        return;
    }
    String modelPath = args[0];
    String labelIndexPath = args[1];
    String dictionaryPath = args[2];
    String documentFrequencyPath = args[3];
    String tablename = args[4];

    Configuration configuration = new Configuration();

    // model is a matrix (wordId, labelId) => probability score
    NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), configuration);
    StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);

    // labels is a map label => classId
    Map<Integer, String> labels = BayesUtils.readLabelIndex(configuration, new Path(labelIndexPath));
    Map<String, Integer> dictionary = readDictionnary(configuration, new Path(dictionaryPath));
    Map<Integer, Long> documentFrequency = readDocumentFrequency(configuration,
            new Path(documentFrequencyPath));

    // analyzer used to extract words from the text
    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_43);

    int labelCount = labels.size();
    int documentCount = documentFrequency.get(-1).intValue();

    System.out.println("Number of labels: " + labelCount);
    System.out.println("Number of documents in training set: " + documentCount);

    Connection c = null;
    Statement stmt = null;
    Statement stmtU = null;

    try {
        Class.forName("org.postgresql.Driver");
        c = DriverManager.getConnection("jdbc:postgresql://192.168.50.170:5432/uzeni", "postgres",
                "dbwpsdkdl");
        c.setAutoCommit(false);
        System.out.println("Opened database successfully");

        stmt = c.createStatement();
        stmtU = c.createStatement();
        ResultSet rs = stmt.executeQuery("SELECT * FROM " + tablename + " WHERE rep is null");

        while (rs.next()) {
            String seq = rs.getString("seq");
            String body = rs.getString("body");
            String id = seq;
            String message = body;

            Multiset<String> words = ConcurrentHashMultiset.create();

            // extract words from the text
            TokenStream ts = analyzer.tokenStream("text", new StringReader(message));
            CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
            ts.reset();
            int wordCount = 0;
            while (ts.incrementToken()) {
                if (termAtt.length() > 0) {
                    String word = ts.getAttribute(CharTermAttribute.class).toString();
                    Integer wordId = dictionary.get(word);
                    // if the word is not in the dictionary, skip it
                    if (wordId != null) {
                        words.add(word);
                        wordCount++;
                    }
                }
            }
            ts.end();
            ts.close();

            // create vector wordId => weight using tfidf
            Vector vector = new RandomAccessSparseVector(10000);
            TFIDF tfidf = new TFIDF();
            for (Multiset.Entry<String> entry : words.entrySet()) {
                String word = entry.getElement();
                int count = entry.getCount();
                Integer wordId = dictionary.get(word);
                Long freq = documentFrequency.get(wordId);
                double tfIdfValue = tfidf.calculate(count, freq.intValue(), wordCount, documentCount);
                vector.setQuick(wordId, tfIdfValue);
            }

            // With the classifier, we get one score for each label.
            // The label with the highest score is the one the text is most
            // likely to be associated with.
            Vector resultVector = classifier.classifyFull(vector);
            double bestScore = -Double.MAX_VALUE;
            int bestCategoryId = -1;
            for (Element element : resultVector.all()) {
                int categoryId = element.index();
                double score = element.get();
                if (score > bestScore) {
                    bestScore = score;
                    bestCategoryId = categoryId;
                }
            }

            // note: tablename and id are concatenated into the SQL; a
            // PreparedStatement would be safer against injection
            stmtU.executeUpdate("UPDATE " + tablename + " SET rep = '" + labels.get(bestCategoryId)
                    + "' WHERE seq = " + id);
        }
        rs.close();
        stmt.close();
        stmtU.close();
        c.commit();
        c.close();
        analyzer.close();
    } catch (Exception e) {
        System.err.println(e.getClass().getName() + ": " + e.getMessage());
        System.exit(0);
    }
}
From source file:TrainLogistic.java
License:Apache License
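Here each RandomAccessSparseVector holds one training example parsed from an SVMlight-style line ("label index:value ..."), which is then scored and used to train Mahout's OnlineLogisticRegression.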
static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (parseArgs(args)) {
        double logPEstimate = 0;
        int samples = 0;

        // read all *.svm files in the directory given by inputFile
        int fi = 0; // file index
        File file = new File(inputFile);
        String[] fns = file.list(new FilenameFilter() {
            public boolean accept(File dir, String name) {
                return name.endsWith(".svm");
            }
        });

        String[] ss = new String[lmp.getNumFeatures() + 1];
        String[] iv = new String[2];
        OnlineLogisticRegression lr = lmp.createRegression();

        while (fi < fns.length) {
            for (int pass = 0; pass < passes; pass++) {
                BufferedReader in = open(inputFile + fns[fi]);
                System.out.println(pass + 1);
                try {
                    String line = in.readLine();
                    int lineCount = 1;
                    while (line != null) {
                        // for each new line, get target and predictors
                        Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                        ss = line.split(" ");
                        int targetValue;
                        if (ss[0].startsWith("+"))
                            targetValue = 1;
                        else
                            targetValue = 0;
                        int k = 1;
                        while (k < ss.length) {
                            iv = ss[k].split(":");
                            // SVMlight indices are 1-based; shift to 0-based
                            input.setQuick(Integer.valueOf(iv[0]) - 1, Double.valueOf(iv[1]));
                            k++;
                        }
                        // constant bias term in the last slot
                        input.setQuick(lmp.getNumFeatures() - 1, 1);

                        // check performance while this is still news
                        double logP = lr.logLikelihood(targetValue, input);
                        if (!Double.isInfinite(logP)) {
                            if (samples < 20) {
                                logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
                            } else {
                                logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
                            }
                            samples++;
                        }
                        double p = lr.classifyScalar(input);
                        if (scores) {
                            output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f\n", samples,
                                    targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
                        }

                        // now update model
                        lr.train(targetValue, input);
                        if (lineCount % 1000 == 0)
                            System.out.printf("%d\t", lineCount);
                        line = in.readLine();
                        lineCount++;
                    }
                } finally {
                    Closeables.closeQuietly(in);
                }
                System.out.println();
            }
            fi++;
        }

        FileOutputStream modelOutput = new FileOutputStream(outputFile);
        try {
            saveTo(modelOutput, lr);
        } finally {
            Closeables.closeQuietly(modelOutput);
        }
    }
}
From source file:at.illecker.hama.rootbeer.examples.matrixmultiplication.compositeinput.cpu.MatrixMultiplicationBSPCpu.java
License:Apache License
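This BSP cleanup step uses the RandomAccessSparseVector(Vector) copy constructor from the header: the first message for a row seeds a sparse accumulator, and subsequent messages for the same row are added into it.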
@Override
public void cleanup(BSPPeer<IntWritable, TupleWritable, IntWritable, VectorWritable, MatrixRowMessage> peer)
        throws IOException {
    // MasterTask accumulates result
    if (peer.getPeerName().equals(masterTask)) {
        // SortedMap because the final matrix rows should be in order
        SortedMap<Integer, Vector> accumulatedRows = new TreeMap<Integer, Vector>();
        MatrixRowMessage currentMatrixRowMessage = null;

        // collect messages
        while ((currentMatrixRowMessage = peer.getCurrentMessage()) != null) {
            int rowIndex = currentMatrixRowMessage.getRowIndex();
            Vector rowValues = currentMatrixRowMessage.getRowValues().get();

            if (isDebuggingEnabled) {
                logger.writeChars("bsp,gotMsg,key=" + rowIndex + ",value=" + rowValues.toString() + "\n");
            }

            if (accumulatedRows.containsKey(rowIndex)) {
                accumulatedRows.get(rowIndex).assign(rowValues, Functions.PLUS);
            } else {
                accumulatedRows.put(rowIndex, new RandomAccessSparseVector(rowValues));
            }
        }

        // write accumulated results
        for (Map.Entry<Integer, Vector> row : accumulatedRows.entrySet()) {
            if (isDebuggingEnabled) {
                logger.writeChars(
                        "bsp,write,key=" + row.getKey() + ",value=" + row.getValue().toString() + "\n");
            }
            peer.write(new IntWritable(row.getKey()), new VectorWritable(row.getValue()));
        }
    }
}
From source file:ca.uwaterloo.cpami.mahout.matrix.utils.GramSchmidt.java
License:Apache License
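Copies each matrix column into a RandomAccessSparseVector so only non-zero entries are stored. Note the debug prints and the System.exit(1) left in the loop, which abort the program after the first column.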
public static void storeSparseColumns(Matrix mat) {
    int numCols = mat.numCols();
    int numRows = mat.numRows();
    for (int i = 0; i < numCols; i++) {
        Vector sparseVect = new RandomAccessSparseVector(numRows);
        Vector col = mat.viewColumn(i);
        Iterator<Vector.Element> itr = col.iterateNonZero();
        while (itr.hasNext()) {
            Element elem = itr.next();
            // guard against iterators that may still yield zero-valued entries
            if (elem.get() != 0) {
                System.out.println(elem.get());
                sparseVect.set(elem.index(), elem.get());
            }
        }
        System.out.println(sparseVect.getNumNondefaultElements());
        mat.assignColumn(i, sparseVect);
        System.out.println(mat.viewColumn(i).getNumNondefaultElements());
        // debugging leftover: aborts after processing the first column
        System.exit(1);
    }
}
From source file:cc.recommenders.mining.calls.clustering.VectorBuilder.java
License:Open Source License
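Each usage (a list of features) becomes a RandomAccessSparseVector sized to the dictionary, with the feature's weight stored at its dictionary id; features missing from the dictionary are skipped.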
public List<Vector> build(List<List<Feature>> usages, Dictionary<Feature> dictionary) {
    List<Vector> vectors = Lists.newArrayList();
    for (List<Feature> usage : usages) {
        final Vector vector = new RandomAccessSparseVector(dictionary.size());
        for (Feature f : usage) {
            int index = dictionary.getId(f);
            boolean isValidFeature = index >= 0;
            if (isValidFeature) {
                double value = weighter.getWeight(f);
                vector.set(index, value);
            }
        }
        vectors.add(vector);
    }
    return vectors;
}
From source file:cc.recommenders.mining.calls.clustering.VectorBuilderTest.java
License:Open Source License
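A small test helper: builds a sparse vector of fixed cardinality 4 from varargs.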
private Vector createVector(double... values) {
    Vector v = new RandomAccessSparseVector(4);
    for (int i = 0; i < values.length; i++) {
        v.set(i, values[i]);
    }
    return v;
}
From source file:com.chimpler.example.bayes.Classifier.java
License:Apache License
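The same tweet-classification flow as the examples above: read tab-separated tweets from a file, TF-IDF-vectorize each into a RandomAccessSparseVector, and print every label's naive Bayes score along with the best one.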
public static void main(String[] args) throws Exception {
    if (args.length < 5) {
        System.out.println(
                "Arguments: [model] [label index] [dictionary] [document frequency] [tweet file]");
        return;
    }
    String modelPath = args[0];
    String labelIndexPath = args[1];
    String dictionaryPath = args[2];
    String documentFrequencyPath = args[3];
    String tweetsPath = args[4];

    Configuration configuration = new Configuration();

    // model is a matrix (wordId, labelId) => probability score
    NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), configuration);
    StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);

    // labels is a map label => classId
    Map<Integer, String> labels = BayesUtils.readLabelIndex(configuration, new Path(labelIndexPath));
    Map<String, Integer> dictionary = readDictionnary(configuration, new Path(dictionaryPath));
    Map<Integer, Long> documentFrequency = readDocumentFrequency(configuration,
            new Path(documentFrequencyPath));

    // analyzer used to extract words from tweets
    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_43);

    int labelCount = labels.size();
    int documentCount = documentFrequency.get(-1).intValue();

    System.out.println("Number of labels: " + labelCount);
    System.out.println("Number of documents in training set: " + documentCount);

    BufferedReader reader = new BufferedReader(new FileReader(tweetsPath));
    while (true) {
        String line = reader.readLine();
        if (line == null) {
            break;
        }

        String[] tokens = line.split("\t", 2);
        String tweetId = tokens[0];
        String tweet = tokens[1];

        System.out.println("Tweet: " + tweetId + "\t" + tweet);

        Multiset<String> words = ConcurrentHashMultiset.create();

        // extract words from tweet
        TokenStream ts = analyzer.tokenStream("text", new StringReader(tweet));
        CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
        ts.reset();
        int wordCount = 0;
        while (ts.incrementToken()) {
            if (termAtt.length() > 0) {
                String word = ts.getAttribute(CharTermAttribute.class).toString();
                Integer wordId = dictionary.get(word);
                // if the word is not in the dictionary, skip it
                if (wordId != null) {
                    words.add(word);
                    wordCount++;
                }
            }
        }

        // create vector wordId => weight using tfidf
        Vector vector = new RandomAccessSparseVector(10000);
        TFIDF tfidf = new TFIDF();
        for (Multiset.Entry<String> entry : words.entrySet()) {
            String word = entry.getElement();
            int count = entry.getCount();
            Integer wordId = dictionary.get(word);
            Long freq = documentFrequency.get(wordId);
            double tfIdfValue = tfidf.calculate(count, freq.intValue(), wordCount, documentCount);
            vector.setQuick(wordId, tfIdfValue);
        }

        // With the classifier, we get one score for each label.
        // The label with the highest score is the one the tweet is most
        // likely to be associated with.
        Vector resultVector = classifier.classifyFull(vector);
        double bestScore = -Double.MAX_VALUE;
        int bestCategoryId = -1;
        for (Element element : resultVector.all()) {
            int categoryId = element.index();
            double score = element.get();
            if (score > bestScore) {
                bestScore = score;
                bestCategoryId = categoryId;
            }
            System.out.print("  " + labels.get(categoryId) + ": " + score);
        }
        System.out.println(" => " + labels.get(bestCategoryId));
    }
    analyzer.close();
    reader.close();
}
From source file:com.cloudera.knittingboar.metrics.POLRModelTester.java
License:Apache License
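Evaluation loop: each test record is vectorized into a RandomAccessSparseVector of size FeatureVectorSize and scored against a parallel online logistic regression model, tracking average log-likelihood and percent correct.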
/**
 * Runs the next batch of test records through the model and updates the
 * accuracy metrics.
 *
 * TODO: need to provide stats, group measurements into struct
 *
 * @throws Exception
 * @throws IOException
 */
public void RunThroughTestRecords() throws IOException, Exception {
    Text value = new Text();
    long batch_vec_factory_time = 0;
    k = 0;
    int num_correct = 0;

    for (int x = 0; x < this.BatchSize; x++) {
        if (this.input_split.next(value)) {
            long startTime = System.currentTimeMillis();
            Vector v = new RandomAccessSparseVector(this.FeatureVectorSize);
            int actual = this.VectorFactory.processLine(value.toString(), v);
            long endTime = System.currentTimeMillis();
            batch_vec_factory_time += (endTime - startTime);

            String ng = this.VectorFactory.GetClassnameByID(actual);

            // calc stats ---------
            double mu = Math.min(k + 1, 200);
            double ll = this.polr.logLikelihood(actual, v);
            // skip NaN log-likelihoods
            if (!Double.isNaN(ll)) {
                metrics.AvgLogLikelihood = metrics.AvgLogLikelihood
                        + (ll - metrics.AvgLogLikelihood) / mu;
            }

            // 20 output classes (one per newsgroup)
            Vector p = new DenseVector(20);
            this.polr.classifyFull(p, v);
            int estimated = p.maxValueIndex();

            int correct = (estimated == actual ? 1 : 0);
            if (estimated == actual) {
                num_correct++;
            }
            metrics.AvgCorrect = metrics.AvgCorrect + (correct - metrics.AvgCorrect) / mu;

            k++;

            // if (x == this.BatchSize - 1) {
            int bump = bumps[(int) Math.floor(step) % bumps.length];
            int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
            if (k % (bump * scale) == 0) {
                step += 0.25;
                System.out.printf(
                        "Worker %s:\t Trained Recs: %10d, numCorrect: %d, AvgLL: %10.3f, Percent Correct: %10.2f, VF: %d\n",
                        this.internalID, k, num_correct, metrics.AvgLogLikelihood,
                        metrics.AvgCorrect * 100, batch_vec_factory_time);
            }
            this.polr.close();
        } else {
            // nothing else to process in split!
            break;
        }
    } // for the number of records in the batch
}
From source file:com.cloudera.knittingboar.records.RCV1RecordFactory.java
License:Apache License
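A scanning utility for RCV1-style data: each line's index:value features are reduced modulo FEATURES into a RandomAccessSparseVector, and per-class and per-namespace counts are printed for inspection.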
public static void ScanFile(String file, int debug_break_cnt) throws IOException {
    ConstantValueEncoder encoder_test = new ConstantValueEncoder("test");

    BufferedReader reader = null;
    int line_count = 0;

    Multiset<String> class_count = ConcurrentHashMultiset.create();
    Multiset<String> namespaces = ConcurrentHashMultiset.create();

    try {
        reader = new BufferedReader(new FileReader(file));

        String line = reader.readLine();
        while (line != null && line.length() > 0) {
            // format: <class> <namespace> <index:value> <index:value> ...
            String[] parts = line.split(" ");

            class_count.add(parts[0]);
            namespaces.add(parts[1]);

            line = reader.readLine();
            line_count++;

            Vector v = new RandomAccessSparseVector(FEATURES);
            for (int x = 2; x < parts.length; x++) {
                String[] feature = parts[x].split(":");
                int index = Integer.parseInt(feature[0]) % FEATURES;
                double val = Double.parseDouble(feature[1]);
                if (index < FEATURES) {
                    v.set(index, val);
                } else {
                    // unreachable: index is already reduced modulo FEATURES above
                    System.out.println("Could Hash: " + index + " to " + (index % FEATURES));
                }
            }

            Utils.PrintVectorSectionNonZero(v, 10);
            System.out.println("###");

            if (line_count > debug_break_cnt) {
                break;
            }
        }

        System.out.println("Total Rec Count: " + line_count);
        System.out.println("--------------------");
        System.out.println("Classes");
        for (String word : class_count.elementSet()) {
            System.out.println("Class " + word + ": " + class_count.count(word));
        }
        System.out.println("--------------------");
        System.out.println("NameSpaces:");
        for (String word : namespaces.elementSet()) {
            System.out.println("Namespace " + word + ": " + namespaces.count(word));
        }
    } finally {
        reader.close();
    }
}