Example usage for org.apache.mahout.math Vector setQuick

Introduction

This page collects example usages of the org.apache.mahout.math Vector method setQuick from open-source projects.

Prototype

void setQuick(int index, double value);

Document

Set the value at the given index, without checking bounds.
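
Before the full examples below, a minimal, self-contained sketch of the method in isolation (the class name SetQuickExample is ours; it assumes only that Mahout's math module is on the classpath). Unlike set(int, double), setQuick performs no bounds check, so the caller is responsible for passing a valid index.

import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public class SetQuickExample {
    public static void main(String[] args) {
        Vector v = new RandomAccessSparseVector(10); // cardinality 10, initially all zeros
        v.setQuick(3, 2.5); // no bounds check: the index must be < 10
        System.out.println(v.getQuick(3)); // prints 2.5
    }
}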

Usage

From source file: Vectors.java

License: Apache License

public static Vector maybeSample(Vector original, int sampleSize) {
    if (original.getNumNondefaultElements() <= sampleSize) {
        return original;
    }
    Vector sample = original.like();
    Iterator<Vector.Element> sampledElements = new FixedSizeSamplingIterator<Vector.Element>(sampleSize,
            original.iterateNonZero());
    while (sampledElements.hasNext()) {
        Vector.Element elem = sampledElements.next();
        sample.setQuick(elem.index(), elem.get());
    }
    return sample;
}

From source file: Vectors.java

License: Apache License

public static Vector topKElements(int k, Vector original) {
    if (original.getNumNondefaultElements() <= k) {
        return original;
    }
    TopK<Vector.Element> topKQueue = new TopK<Vector.Element>(k, BY_VALUE);
    Iterator<Vector.Element> nonZeroElements = original.iterateNonZero();
    while (nonZeroElements.hasNext()) {
        Vector.Element nonZeroElement = nonZeroElements.next();
        topKQueue.offer(new Vectors.TemporaryElement(nonZeroElement));
    }
    Vector topKSimilarities = original.like();
    for (Vector.Element topKSimilarity : topKQueue.retrieve()) {
        topKSimilarities.setQuick(topKSimilarity.index(), topKSimilarity.get());
    }
    return topKSimilarities;
}

From source file: Vectors.java

License: Apache License

public static Vector merge(Iterable<VectorWritable> partialVectors) {
    Iterator<VectorWritable> vectors = partialVectors.iterator();
    Vector accumulator = vectors.next().get();
    while (vectors.hasNext()) {
        VectorWritable v = vectors.next();
        if (v != null) {
            Iterator<Vector.Element> nonZeroElements = v.get().iterateNonZero();
            while (nonZeroElements.hasNext()) {
                Vector.Element nonZeroElement = nonZeroElements.next();
                accumulator.setQuick(nonZeroElement.index(), nonZeroElement.get());
            }
        }
    }
    return accumulator;
}
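
Because setQuick overwrites the value at an index rather than accumulating into it, this merge presumably assumes the partial vectors populate disjoint index ranges; where ranges overlap, the value from the last vector read wins.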

From source file: ClassifierHD.java

License: Apache License

public static void main(String[] args) throws Exception {
    if (args.length < 6) {
        System.out.println(
                "Arguments: [model] [label index] [dictionary] [document frequency] [postgres table] [hdfs dir] [job_id]");
        return;
    }
    String modelPath = args[0];
    String labelIndexPath = args[1];
    String dictionaryPath = args[2];
    String documentFrequencyPath = args[3];
    String tablename = args[4];
    String inputDir = args[5];

    Configuration configuration = new Configuration();

    // model is a matrix (wordId, labelId) => probability score
    NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), configuration);

    StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);

    // labels is a map label => classId
    Map<Integer, String> labels = BayesUtils.readLabelIndex(configuration, new Path(labelIndexPath));
    Map<String, Integer> dictionary = readDictionnary(configuration, new Path(dictionaryPath));
    Map<Integer, Long> documentFrequency = readDocumentFrequency(configuration,
            new Path(documentFrequencyPath));

    // analyzer used to extract word from tweet
    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_43);

    int labelCount = labels.size();
    int documentCount = documentFrequency.get(-1).intValue();

    System.out.println("Number of labels: " + labelCount);
    System.out.println("Number of documents in training set: " + documentCount);

    Connection conn = null;
    PreparedStatement pstmt = null;

    try {
        Class.forName("org.postgresql.Driver");
        conn = DriverManager.getConnection("jdbc:postgresql://192.168.50.170:5432/uzeni", "postgres",
                "dbwpsdkdl");
        conn.setAutoCommit(false);
        String sql = "INSERT INTO " + tablename
                + " (id,gtime,wtime,target,num,link,body,rep) VALUES (?,?,?,?,?,?,?,?);";
        pstmt = conn.prepareStatement(sql);

        FileSystem fs = FileSystem.get(configuration);
        FileStatus[] status = fs.listStatus(new Path(inputDir));
        BufferedWriter bw = new BufferedWriter(
                new OutputStreamWriter(fs.create(new Path(inputDir + "/rep.list"), true)));

        for (int i = 0; i < status.length; i++) {
            if (status[i].getPath().getName().equals("rep.list")) {
                continue; // skip the output file itself
            }
            BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(status[i].getPath())));
            int lv_HEAD = 1;
            int lv_cnt = 0;
            String lv_gtime = null;
            String lv_wtime = null;
            String lv_target = null;
            BigDecimal lv_num = null;
            String lv_link = null;
            String[] lv_args;
            String lv_line;
            StringBuilder lv_txt = new StringBuilder();
            while ((lv_line = br.readLine()) != null) {
                if (lv_cnt < lv_HEAD) {
                    lv_args = lv_line.split(",");
                    lv_gtime = lv_args[0];
                    lv_wtime = lv_args[1];
                    lv_target = lv_args[2];
                    lv_num = new BigDecimal(lv_args[3]);
                    lv_link = lv_args[4];
                } else {
                    lv_txt.append(lv_line + '\n');
                }
                lv_cnt++;
            }
            br.close();

            String id = status[i].getPath().getName();
            String message = lv_txt.toString();

            Multiset<String> words = ConcurrentHashMultiset.create();

            TokenStream ts = analyzer.tokenStream("text", new StringReader(message));
            CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
            ts.reset();
            int wordCount = 0;
            while (ts.incrementToken()) {
                if (termAtt.length() > 0) {
                    String word = ts.getAttribute(CharTermAttribute.class).toString();
                    Integer wordId = dictionary.get(word);
                    if (wordId != null) {
                        words.add(word);
                        wordCount++;
                    }
                }
            }

            ts.end();
            ts.close();

            Vector vector = new RandomAccessSparseVector(10000);
            TFIDF tfidf = new TFIDF();
            for (Multiset.Entry<String> entry : words.entrySet()) {
                String word = entry.getElement();
                int count = entry.getCount();
                Integer wordId = dictionary.get(word);
                Long freq = documentFrequency.get(wordId);
                double tfIdfValue = tfidf.calculate(count, freq.intValue(), wordCount, documentCount);
                vector.setQuick(wordId, tfIdfValue);
            }
            Vector resultVector = classifier.classifyFull(vector);
            double bestScore = -Double.MAX_VALUE;
            int bestCategoryId = -1;
            for (Element element : resultVector.all()) {
                int categoryId = element.index();
                double score = element.get();
                if (score > bestScore) {
                    bestScore = score;
                    bestCategoryId = categoryId;
                }
            }
            //System.out.println(message);
            //System.out.println(" => "+ lv_gtime + lv_wtime + lv_link + id + ":" + labels.get(bestCategoryId));
            pstmt.setString(1, id);
            pstmt.setString(2, lv_gtime);
            pstmt.setString(3, lv_wtime);
            pstmt.setString(4, lv_target);
            pstmt.setBigDecimal(5, lv_num);
            pstmt.setString(6, lv_link);
            pstmt.setString(7, message.substring(0, Math.min(50, message.length()))); // at most the first 50 chars of the body
            pstmt.setString(8, labels.get(bestCategoryId));
            pstmt.addBatch();
            bw.write(id + "\t" + labels.get(bestCategoryId) + "\n");
        }
        pstmt.executeBatch();
        //pstmt.clearParameters();
        pstmt.close();
        conn.commit();
        conn.close();
        bw.close();
    } catch (Exception e) {
        System.err.println(e.getClass().getName() + ": " + e.getMessage());
        System.exit(1); // signal failure with a nonzero exit code
    }
    analyzer.close();
}

From source file: PostgresClassifier.java

License: Apache License

public static void main(String[] args) throws Exception {
    if (args.length < 5) {
        System.out.println(
                "Arguments: [model] [label index] [dictionary] [document frequency] [input postgres table]");
        return;
    }
    String modelPath = args[0];
    String labelIndexPath = args[1];
    String dictionaryPath = args[2];
    String documentFrequencyPath = args[3];
    String tablename = args[4];

    Configuration configuration = new Configuration();

    // model is a matrix (wordId, labelId) => probability score
    NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), configuration);

    StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);

    // labels is a map label => classId
    Map<Integer, String> labels = BayesUtils.readLabelIndex(configuration, new Path(labelIndexPath));
    Map<String, Integer> dictionary = readDictionnary(configuration, new Path(dictionaryPath));
    Map<Integer, Long> documentFrequency = readDocumentFrequency(configuration,
            new Path(documentFrequencyPath));

    // analyzer used to extract word from tweet
    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_43);

    int labelCount = labels.size();
    int documentCount = documentFrequency.get(-1).intValue();

    System.out.println("Number of labels: " + labelCount);
    System.out.println("Number of documents in training set: " + documentCount);

    Connection c = null;
    Statement stmt = null;
    Statement stmtU = null;
    try {
        Class.forName("org.postgresql.Driver");
        c = DriverManager.getConnection("jdbc:postgresql://192.168.50.170:5432/uzeni", "postgres", "dbwpsdkdl");
        c.setAutoCommit(false);
        System.out.println("Opened database successfully");
        stmt = c.createStatement();
        stmtU = c.createStatement();
        ResultSet rs = stmt.executeQuery("SELECT * FROM " + tablename + " WHERE rep is null");

        while (rs.next()) {
            String seq = rs.getString("seq");
            //String rep = rs.getString("rep");
            String body = rs.getString("body");
            //String category = rep;
            String id = seq;
            String message = body;

            //System.out.println("Doc: " + id + "\t" + message);

            Multiset<String> words = ConcurrentHashMultiset.create();

            // extract words from tweet
            TokenStream ts = analyzer.tokenStream("text", new StringReader(message));
            CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
            ts.reset();
            int wordCount = 0;
            while (ts.incrementToken()) {
                if (termAtt.length() > 0) {
                    String word = ts.getAttribute(CharTermAttribute.class).toString();
                    Integer wordId = dictionary.get(word);
                    // if the word is not in the dictionary, skip it
                    if (wordId != null) {
                        words.add(word);
                        wordCount++;
                    }
                }
            }
            ts.end();
            ts.close();

            // create vector wordId => weight using tfidf
            Vector vector = new RandomAccessSparseVector(10000);
            TFIDF tfidf = new TFIDF();
            for (Multiset.Entry<String> entry : words.entrySet()) {
                String word = entry.getElement();
                int count = entry.getCount();
                Integer wordId = dictionary.get(word);
                Long freq = documentFrequency.get(wordId);
                double tfIdfValue = tfidf.calculate(count, freq.intValue(), wordCount, documentCount);
                vector.setQuick(wordId, tfIdfValue);
            }
            // With the classifier, we get one score for each label 
            // The label with the highest score is the one the tweet is more likely to
            // be associated to
            Vector resultVector = classifier.classifyFull(vector);
            double bestScore = -Double.MAX_VALUE;
            int bestCategoryId = -1;
            for (Element element : resultVector.all()) {
                int categoryId = element.index();
                double score = element.get();
                if (score > bestScore) {
                    bestScore = score;
                    bestCategoryId = categoryId;
                }
                //System.out.print("  " + labels.get(categoryId) + ": " + score);
            }
            //System.out.println(" => " + labels.get(bestCategoryId));
            //System.out.println("UPDATE " + tablename + " SET rep = '" + labels.get(bestCategoryId) + "' WHERE seq = " + id );
            stmtU.executeUpdate("UPDATE " + tablename + " SET rep = '" + labels.get(bestCategoryId)
                    + "' WHERE seq = " + id);
        }
        rs.close();
        stmt.close();
        stmtU.close();
        c.commit();
        c.close();
        analyzer.close();
    } catch (Exception e) {
        System.err.println(e.getClass().getName() + ": " + e.getMessage());
        System.exit(1); // signal failure with a nonzero exit code
    }
}

From source file: TrainLogistic.java

License: Apache License

static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (parseArgs(args)) {
        double logPEstimate = 0;
        int samples = 0;
        /*read files in dir of inputFile*/
        int fi = 0;//file ID
        File file = new File(inputFile);
        String[] fns = file.list(new FilenameFilter() {
            public boolean accept(File dir, String name) {
                return name.endsWith(".svm");
            }
        });

        String[] ss = new String[lmp.getNumFeatures() + 1];
        String[] iv = new String[2];
        OnlineLogisticRegression lr = lmp.createRegression();
        while (fi < fns.length) {
            for (int pass = 0; pass < passes; pass++) {
                BufferedReader in = open(inputFile + fns[fi]);
                System.out.println(pass + 1);
                try {
                    // read variable names

                    String line = in.readLine();
                    int lineCount = 1;
                    while (line != null) {
                        // for each new line, get target and predictors
                        Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                        ss = line.split(" ");
                        int targetValue = ss[0].startsWith("+") ? 1 : 0;
                        int k = 1;
                        while (k < ss.length) {
                            iv = ss[k].split(":");
                            input.setQuick(Integer.valueOf(iv[0]) - 1, Double.valueOf(iv[1]));
                            //System.out.printf("%d-----%d:%.4f====%d\n", k,Integer.valueOf(iv[0])-1,Double.valueOf(iv[1]),lineCount);
                            k++;
                        }
                        input.setQuick(lmp.getNumFeatures() - 1, 1); // constant bias term in the last feature slot
                        // check performance while this is still news
                        double logP = lr.logLikelihood(targetValue, input);
                        if (!Double.isInfinite(logP)) {
                            if (samples < 20) {
                                logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
                            } else {
                                logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
                            }
                            samples++;
                        }
                        double p = lr.classifyScalar(input);
                        if (scores) {
                            output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f\n", samples,
                                    targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
                        }
                        // now update model
                        lr.train(targetValue, input);
                        if (lineCount % 1000 == 0)
                            System.out.printf("%d\t", lineCount);
                        line = in.readLine();
                        lineCount++;
                    }
                } finally {
                    Closeables.closeQuietly(in);
                }
                System.out.println();
            }
            fi++;
        }

        FileOutputStream modelOutput = new FileOutputStream(outputFile);
        try {
            saveTo(modelOutput, lr);
        } finally {
            Closeables.closeQuietly(modelOutput);
        }
        /*
              output.printf(Locale.ENGLISH, "%d\n", lmp.getNumFeatures());
              output.printf(Locale.ENGLISH, "%s ~ ", lmp.getTargetVariable());
              String sep = "";
              for (String v : csv.getTraceDictionary().keySet()) {
                double weight = predictorWeight(lr, 0, csv, v);
                if (weight != 0) {
                  output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
                  sep = " + ";
                }
              }
              output.printf("\n");
              model = lr;
              for (int row = 0; row < lr.getBeta().numRows(); row++) {
                for (String key : csv.getTraceDictionary().keySet()) {
                  double weight = predictorWeight(lr, row, csv, key);
                  if (weight != 0) {
                    output.printf(Locale.ENGLISH, "%20s %.5f\n", key, weight);
                  }
                }
                for (int column = 0; column < lr.getBeta().numCols(); column++) {
                  output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
                }
                output.println();
              }*/
    }
}
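
The parser above assumes SVMlight-style input: each line begins with a label token (lines starting with "+" are treated as positive) followed by space-separated index:value pairs with 1-based feature indices, for example "+1 3:0.5 7:1.2".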

From source file: com.chimpler.example.bayes.Classifier.java

License: Apache License

public static void main(String[] args) throws Exception {
    if (args.length < 5) {
        System.out.println("Arguments: [model] [label index] [dictionnary] [document frequency] [tweet file]");
        return;//from   w  w w .  ja va  2  s. co  m
    }
    String modelPath = args[0];
    String labelIndexPath = args[1];
    String dictionaryPath = args[2];
    String documentFrequencyPath = args[3];
    String tweetsPath = args[4];

    Configuration configuration = new Configuration();

    // model is a matrix (wordId, labelId) => probability score
    NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), configuration);

    StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);

    // labels is a map label => classId
    Map<Integer, String> labels = BayesUtils.readLabelIndex(configuration, new Path(labelIndexPath));
    Map<String, Integer> dictionary = readDictionnary(configuration, new Path(dictionaryPath));
    Map<Integer, Long> documentFrequency = readDocumentFrequency(configuration,
            new Path(documentFrequencyPath));

    // analyzer used to extract word from tweet
    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_43);

    int labelCount = labels.size();
    int documentCount = documentFrequency.get(-1).intValue();

    System.out.println("Number of labels: " + labelCount);
    System.out.println("Number of documents in training set: " + documentCount);
    BufferedReader reader = new BufferedReader(new FileReader(tweetsPath));
    while (true) {
        String line = reader.readLine();
        if (line == null) {
            break;
        }

        String[] tokens = line.split("\t", 2);
        String tweetId = tokens[0];
        String tweet = tokens[1];

        System.out.println("Tweet: " + tweetId + "\t" + tweet);

        Multiset<String> words = ConcurrentHashMultiset.create();

        // extract words from tweet
        TokenStream ts = analyzer.tokenStream("text", new StringReader(tweet));
        CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
        ts.reset();
        int wordCount = 0;
        while (ts.incrementToken()) {
            if (termAtt.length() > 0) {
                String word = ts.getAttribute(CharTermAttribute.class).toString();
                Integer wordId = dictionary.get(word);
                // if the word is not in the dictionary, skip it
                if (wordId != null) {
                    words.add(word);
                    wordCount++;
                }
            }
        }
        ts.end();
        ts.close();

        // create vector wordId => weight using tfidf
        Vector vector = new RandomAccessSparseVector(10000);
        TFIDF tfidf = new TFIDF();
        for (Multiset.Entry<String> entry : words.entrySet()) {
            String word = entry.getElement();
            int count = entry.getCount();
            Integer wordId = dictionary.get(word);
            Long freq = documentFrequency.get(wordId);
            double tfIdfValue = tfidf.calculate(count, freq.intValue(), wordCount, documentCount);
            vector.setQuick(wordId, tfIdfValue);
        }
        // With the classifier, we get one score for each label 
        // The label with the highest score is the one the tweet is more likely to
        // be associated to
        Vector resultVector = classifier.classifyFull(vector);
        double bestScore = -Double.MAX_VALUE;
        int bestCategoryId = -1;
        for (Element element : resultVector.all()) {
            int categoryId = element.index();
            double score = element.get();
            if (score > bestScore) {
                bestScore = score;
                bestCategoryId = categoryId;
            }
            System.out.print("  " + labels.get(categoryId) + ": " + score);
        }
        System.out.println(" => " + labels.get(bestCategoryId));
    }
    analyzer.close();
    reader.close();
}

From source file: com.elex.dmp.core.TopicModel.java

License: Apache License

public void trainDocTopicModel(Vector original, Vector topics, Matrix docTopicModel) {
    // first calculate p(topic|term,document) for all terms in original, and all topics,
    // using p(term|topic) and p(topic|doc)
    pTopicGivenTerm(original, topics, docTopicModel);
    normalizeByTopic(docTopicModel);
    // now multiply, term-by-term, by the document, to get the weighted distribution of
    // term-topic pairs from this document.
    Iterator<Vector.Element> it = original.iterateNonZero();
    while (it.hasNext()) {
        Vector.Element e = it.next();
        for (int x = 0; x < numTopics; x++) {
            Vector docTopicModelRow = docTopicModel.viewRow(x);
            docTopicModelRow.setQuick(e.index(), docTopicModelRow.getQuick(e.index()) * e.get());
        }
    }
    // now recalculate p(topic|doc) by summing contributions from all of pTopicGivenTerm
    topics.assign(0.0);
    for (int x = 0; x < numTopics; x++) {
        topics.set(x, docTopicModel.viewRow(x).norm(1));
    }
    // now renormalize so that sum_x(p(x|doc)) = 1
    topics.assign(Functions.mult(1 / topics.norm(1)));
}
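
For reference (this note is not part of the quoted source), the loop above implements the per-document update described by the comments. Writing c_e for the count of term e in the document and gamma_{e,x} for the normalized docTopicModel entries:

\gamma_{e,x} = \frac{p(e \mid x)\, p(x \mid d)}{\sum_{x'} p(e \mid x')\, p(x' \mid d)}  % pTopicGivenTerm followed by normalizeByTopic
w_{e,x} = c_e \, \gamma_{e,x}                                                           % the term-by-term multiplication above
p(x \mid d) \leftarrow \frac{\sum_e w_{e,x}}{\sum_{x'} \sum_e w_{e,x'}}                 % row L1 norms, then the final renormalization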

From source file: com.elex.dmp.vectorizer.TFPartialVectorReducer.java

License: Apache License

@Override
protected void reduce(Text key, Iterable<StringTuple> values, Context context)
        throws IOException, InterruptedException {
    Iterator<StringTuple> it = values.iterator();
    if (!it.hasNext()) {
        return;
    }
    StringTuple value = it.next();

    Vector vector = new RandomAccessSparseVector(dimension, value.length()); // guess at initial size

    if (maxNGramSize >= 2) {
        ShingleFilter sf = new ShingleFilter(new IteratorTokenStream(value.getEntries().iterator()),
                maxNGramSize);
        try {
            sf.reset(); // required by the TokenStream contract before incrementToken()
            while (sf.incrementToken()) {
                String term = sf.getAttribute(CharTermAttribute.class).toString();
                if (!term.isEmpty() && dictionary.containsKey(term)) { // ngram
                    int termId = dictionary.get(term);
                    vector.setQuick(termId, vector.getQuick(termId) + 1);
                }
            }
            sf.end();
        } finally {
            Closeables.closeQuietly(sf);
        }
    } else {
        for (String term : value.getEntries()) {
            if (!term.isEmpty() && dictionary.containsKey(term)) { // unigram
                int termId = dictionary.get(term);
                vector.setQuick(termId, vector.getQuick(termId) + 1);
            }
        }
    }
    if (sequentialAccess) {
        vector = new SequentialAccessSparseVector(vector);
    }

    if (namedVector) {
        vector = new NamedVector(vector, key.toString());
    }

    // if the vector has no nonZero entries (nothing in the dictionary), let's not waste space sending it to disk.
    if (vector.getNumNondefaultElements() > 0) {
        VectorWritable vectorWritable = new VectorWritable(vector);
        context.write(key, vectorWritable);
    } else {
        context.getCounter("TFPartialVectorReducer", "emptyVectorCount").increment(1);
    }
}
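
setQuick is appropriate in this reducer because every termId comes from the dictionary, whose ids are presumably all smaller than the vector's dimension, so the skipped bounds check cannot be violated.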

From source file: com.netease.news.classifier.naivebayes.AbstractNaiveBayesClassifier.java

License: Apache License

@Override
public Vector classifyFull(Vector r, Vector instance) {
    for (int label = 0; label < model.numLabels(); label++) {
        r.setQuick(label, getScoreForLabelInstance(label, instance));
    }
    return r;
}
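
Here the caller must supply a result vector r whose cardinality is at least model.numLabels(); since setQuick skips the bounds check, passing a shorter vector is not caught at this point in the code.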