Example usage for org.apache.mahout.math Vector get

Introduction

This page collects usage examples for org.apache.mahout.math Vector.get.

Prototype

double get(int index);

Document

Return the value at the given index.
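
A minimal sketch of the call, assuming only the Mahout math module on the classpath (the class name and values are made up for illustration):

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class VectorGetExample {
    public static void main(String[] args) {
        // Build a small dense vector and read entries back by index.
        Vector v = new DenseVector(new double[] { 1.5, 0.0, 3.0 });
        System.out.println(v.get(0)); // 1.5
        System.out.println(v.get(2)); // 3.0
        // get(int) validates the index; getQuick(int) skips the bounds check.
    }
}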

Usage

From source file: DisplayClustering.java

License: Apache License

protected static void plotSampleParameters(Graphics2D g2) {
    Vector v = new DenseVector(2);
    Vector dv = new DenseVector(2);
    g2.setColor(Color.RED);
    for (Vector param : SAMPLE_PARAMS) {
        v.set(0, param.get(0));
        v.set(1, param.get(1));
        dv.set(0, param.get(2) * 3);
        dv.set(1, param.get(3) * 3);
        plotEllipse(g2, v, dv);
    }
}

From source file: DisplayClustering.java

License: Apache License

/**
 * Identical to plotRectangle(), but with the option of setting the color of
 * the rectangle's stroke.
 *
 * NOTE: This should probably be refactored with plotRectangle() since most of
 * the code here is direct copy/paste from that method.
 *
 * @param g2
 *          A Graphics2D context.
 * @param v
 *          A vector for the rectangle's center.
 * @param dv
 *          A vector for the rectangle's dimensions.
 * @param color
 *          The color of the rectangle's stroke.
 */
protected static void plotClusteredRectangle(Graphics2D g2, Vector v, Vector dv, Color color) {
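    // The (1, -1) flip mirrors the y-axis so model coordinates map onto
    // Java2D screen coordinates, which grow downward; v2 then becomes the
    // rectangle's upper-left corner.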
    double[] flip = { 1, -1 };
    Vector v2 = v.times(new DenseVector(flip));
    v2 = v2.minus(dv.divide(2));
    int h = SIZE / 2;
    double x = v2.get(0) + h;
    double y = v2.get(1) + h;

    g2.setStroke(new BasicStroke(1));
    g2.setColor(color);
    g2.draw(new Rectangle2D.Double(x * DS, y * DS, dv.get(0) * DS, dv.get(1) * DS));
}

From source file: DisplayClustering.java

License: Apache License

/**
 * Draw a rectangle on the graphics context
 *
 * @param g2
 *          a Graphics2D context
 * @param v
 *          a Vector of rectangle center
 * @param dv
 *          a Vector of rectangle dimensions
 */
protected static void plotRectangle(Graphics2D g2, Vector v, Vector dv) {
    double[] flip = { 1, -1 };
    Vector v2 = v.times(new DenseVector(flip));
    v2 = v2.minus(dv.divide(2));
    int h = SIZE / 2;
    double x = v2.get(0) + h;
    double y = v2.get(1) + h;
    g2.draw(new Rectangle2D.Double(x * DS, y * DS, dv.get(0) * DS, dv.get(1) * DS));
}

From source file: DisplayClustering.java

License: Apache License

/**
 * Draw an ellipse on the graphics context
 *
 * @param g2
 *          a Graphics2D context
 * @param v
 *          a Vector of ellipse center
 * @param dv
 *          a Vector of ellipse dimensions
 */
protected static void plotEllipse(Graphics2D g2, Vector v, Vector dv) {
    double[] flip = { 1, -1 };
    Vector v2 = v.times(new DenseVector(flip));
    v2 = v2.minus(dv.divide(2));
    int h = SIZE / 2;
    double x = v2.get(0) + h;
    double y = v2.get(1) + h;
    g2.draw(new Ellipse2D.Double(x * DS, y * DS, dv.get(0) * DS, dv.get(1) * DS));
}

From source file: at.illecker.hama.rootbeer.examples.matrixmultiplication.compositeinput.gpu.MatrixMultiplicationBSPGpu.java

License: Apache License

@Override
public void bsp(BSPPeer<IntWritable, TupleWritable, IntWritable, VectorWritable, MatrixRowMessage> peer)
        throws IOException, SyncException, InterruptedException {

    IntWritable key = new IntWritable();
    TupleWritable value = new TupleWritable();
    while (peer.readNext(key, value)) {

        // Logging
        if (isDebuggingEnabled) {
            for (int i = 0; i < value.size(); i++) {
                Vector vector = ((VectorWritable) value.get(i)).get();
                logger.writeChars("bsp,input,key=" + key + ",value=" + vector.toString() + "\n");
            }
        }

        Vector firstVector = ((VectorWritable) value.get(0)).get();
        Vector secondVector = ((VectorWritable) value.get(1)).get();

        // outCardinality is resulting column size n
        // (l x m) * (m x n) = (l x n)
        boolean firstIsOutFrag = secondVector.size() == outCardinality;

        // outFrag is Matrix which has the resulting column cardinality
        // (matrixB)
        Vector outFrag = firstIsOutFrag ? secondVector : firstVector;

        // multiplier is Matrix which has the resulting row count
        // (transposed matrixA)
        Vector multiplier = firstIsOutFrag ? firstVector : secondVector;

        if (isDebuggingEnabled) {
            logger.writeChars("bsp,firstIsOutFrag=" + firstIsOutFrag + "\n");
            logger.writeChars("bsp,outFrag=" + outFrag + "\n");
            logger.writeChars("bsp,multiplier=" + multiplier + "\n");
        }

        // outFrag to double[]
        double[] outFragArray = new double[outFrag.size()];
        int i = 0;
        for (Vector.Element e : outFrag.all()) {
            outFragArray[i] = e.get();
            i++;
        }

        // One map task consists of multiple kernels within one block
        // Each kernel computes a scalar multiplication
        blockSize = multiplier.size();
        gridSize++;

        for (int j = 0; j < blockSize; j++) {
            kernels.add(new MatrixMultiplicationBSPKernel(j, multiplier.get(j), outFragArray));
        }

        // Run GPU Kernels
        Rootbeer rootbeer = new Rootbeer();
        Context context = rootbeer.createDefaultContext();
        Stopwatch watch = new Stopwatch();
        watch.start();
        // blockSize = rows of Matrix A (multiplier)
        // gridSize = cols of Matrix B (for each row a scalar multiplication
        // has to be made)
        rootbeer.run(kernels, new ThreadConfig(blockSize, gridSize, kernels.size()), context);
        watch.stop();

        List<StatsRow> stats = context.getStats();
        for (StatsRow row : stats) {
            System.out.println("  StatsRow:");
            System.out.println("    serial time: " + row.getSerializationTime());
            System.out.println("    exec time: " + row.getExecutionTime());
            System.out.println("    deserial time: " + row.getDeserializationTime());
            System.out.println("    num blocks: " + row.getNumBlocks());
            System.out.println("    num threads: " + row.getNumThreads());
        }

        if (isDebuggingEnabled) {
            logger.writeChars(
                    "bsp,KernelCount=" + kernels.size() + ",GPUTime=" + watch.elapsedTimeMillis() + "ms\n");
            logger.writeChars("bsp,blockSize=" + blockSize + ",gridSize=" + gridSize + "\n");
            logger.flush();
        }

        // Collect results of GPU kernels
        for (Kernel kernel : kernels) {
            MatrixMultiplicationBSPKernel bspKernel = (MatrixMultiplicationBSPKernel) kernel;

            if (isDebuggingEnabled) {
                logger.writeChars("bsp,thread_idxx=" + bspKernel.thread_idxx + ",multiplier="
                        + bspKernel.multiplierVal + ",vector=" + Arrays.toString(bspKernel.vectorVal) + "\n");
            }

            peer.send(masterTask, new MatrixRowMessage(bspKernel.row,
                    new VectorWritable(new DenseVector(bspKernel.results))));

            if (isDebuggingEnabled) {
                logger.writeChars("bsp,send,key=" + bspKernel.row + ",value="
                        + Arrays.toString(bspKernel.results) + "\n");
            }
        }
    }
    peer.sync();
}

From source file: cn.edu.bjtu.cit.recommender.Recommender.java

License: Apache License

@SuppressWarnings("unchecked")
public int run(String[] args) throws Exception {
    if (args.length < 2) {
        System.err.println();
        System.err.println("Usage: " + this.getClass().getName()
                + " [generic options] input output [profiling] [estimation] [clustersize]");
        System.err.println();
        printUsage();
        GenericOptionsParser.printGenericCommandUsage(System.err);

        return 1;
    }
    OptionParser parser = new OptionParser(args);

    Pipeline pipeline = new MRPipeline(Recommender.class, getConf());

    if (parser.hasOption(CLUSTER_SIZE)) {
        pipeline.getConfiguration().setInt(ClusterOracle.CLUSTER_SIZE,
                Integer.parseInt(parser.getOption(CLUSTER_SIZE).getValue()));
    }

    if (parser.hasOption(PROFILING)) {
        pipeline.getConfiguration().setBoolean(Profiler.IS_PROFILE, true);
        this.profileFilePath = parser.getOption(PROFILING).getValue();

    }

    if (parser.hasOption(ESTIMATION)) {
        estFile = parser.getOption(ESTIMATION).getValue();
        est = new Estimator(estFile, clusterSize);
    }

    if (parser.hasOption(OPT_REDUCE)) {
        pipeline.getConfiguration().setBoolean(OPT_REDUCE, true);
    }

    if (parser.hasOption(OPT_MSCR)) {
        pipeline.getConfiguration().setBoolean(OPT_MSCR, true);
    }

    if (parser.hasOption(ACTIVE_THRESHOLD)) {
        threshold = Integer.parseInt(parser.getOption("at").getValue());
    }

    if (parser.hasOption(TOP)) {
        top = Integer.parseInt(parser.getOption("top").getValue());
    }

    profiler = new Profiler(pipeline);
    /*
     * input node
     */
    PCollection<String> lines = pipeline.readTextFile(args[0]);

    if (profiler.isProfiling() && lines.getSize() > 10 * 1024 * 1024) {
        lines = lines.sample(0.1);
    }

    /*
     * S0 + GBK
     */
    PGroupedTable<Long, Long> userWithPrefs = lines.parallelDo(new MapFn<String, Pair<Long, Long>>() {

        @Override
        public Pair<Long, Long> map(String input) {
            String[] split = input.split(Estimator.DELM);
            long userID = Long.parseLong(split[0]);
            long itemID = Long.parseLong(split[1]);
            return Pair.of(userID, itemID);
        }

        @Override
        public float scaleFactor() {
            return est.getScaleFactor("S0").sizeFactor;
        }

        @Override
        public float scaleFactorByRecord() {
            return est.getScaleFactor("S0").recsFactor;
        }
    }, Writables.tableOf(Writables.longs(), Writables.longs())).groupByKey(est.getClusterSize());

    /*
     * S1
     */
    PTable<Long, Vector> userVector = userWithPrefs
            .parallelDo(new MapFn<Pair<Long, Iterable<Long>>, Pair<Long, Vector>>() {
                @Override
                public Pair<Long, Vector> map(Pair<Long, Iterable<Long>> input) {
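                    // One sparse vector per user: a 1.0 at the index of each preferred item.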
                    Vector userVector = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
                    for (long itemPref : input.second()) {
                        userVector.set((int) itemPref, 1.0f);
                    }
                    return Pair.of(input.first(), userVector);
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S1").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S1").recsFactor;
                }
            }, Writables.tableOf(Writables.longs(), Writables.vectors()));

    userVector = profiler.profile("S0-S1", pipeline, userVector, ProfileConverter.long_vector(),
            Writables.tableOf(Writables.longs(), Writables.vectors()));

    /*
     * S2
     */
    PTable<Long, Vector> filteredUserVector = userVector
            .parallelDo(new DoFn<Pair<Long, Vector>, Pair<Long, Vector>>() {

                @Override
                public void process(Pair<Long, Vector> input, Emitter<Pair<Long, Vector>> emitter) {
                    if (input.second().getNumNondefaultElements() > threshold) {
                        emitter.emit(input);
                    }
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S2").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S2").recsFactor;
                }

            }, Writables.tableOf(Writables.longs(), Writables.vectors()));

    filteredUserVector = profiler.profile("S2", pipeline, filteredUserVector, ProfileConverter.long_vector(),
            Writables.tableOf(Writables.longs(), Writables.vectors()));

    /*
     * S3 + GBK
     */
    PGroupedTable<Integer, Integer> coOccurrencePairs = filteredUserVector
            .parallelDo(new DoFn<Pair<Long, Vector>, Pair<Integer, Integer>>() {
                @Override
                public void process(Pair<Long, Vector> input, Emitter<Pair<Integer, Integer>> emitter) {
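                    // Emit an (item, item) pair for every ordered pair of non-zero
                    // items in this user's vector, including self-pairs.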
                    Iterator<Vector.Element> it = input.second().iterateNonZero();
                    while (it.hasNext()) {
                        int index1 = it.next().index();
                        Iterator<Vector.Element> it2 = input.second().iterateNonZero();
                        while (it2.hasNext()) {
                            int index2 = it2.next().index();
                            emitter.emit(Pair.of(index1, index2));
                        }
                    }
                }

                @Override
                public float scaleFactor() {
                    float size = est.getScaleFactor("S3").sizeFactor;
                    return size;
                }

                @Override
                public float scaleFactorByRecord() {
                    float recs = est.getScaleFactor("S3").recsFactor;
                    return recs;
                }
            }, Writables.tableOf(Writables.ints(), Writables.ints())).groupByKey(est.getClusterSize());

    /*
     * S4
     */
    PTable<Integer, Vector> coOccurrenceVector = coOccurrencePairs
            .parallelDo(new MapFn<Pair<Integer, Iterable<Integer>>, Pair<Integer, Vector>>() {
                @Override
                public Pair<Integer, Vector> map(Pair<Integer, Iterable<Integer>> input) {
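                    // Count how often each other item co-occurs with the key item.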
                    Vector cooccurrenceRow = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
                    for (int itemIndex2 : input.second()) {
                        cooccurrenceRow.set(itemIndex2, cooccurrenceRow.get(itemIndex2) + 1.0);
                    }
                    return Pair.of(input.first(), cooccurrenceRow);
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S4").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S4").recsFactor;
                }
            }, Writables.tableOf(Writables.ints(), Writables.vectors()));

    coOccurrenceVector = profiler.profile("S3-S4", pipeline, coOccurrenceVector, ProfileConverter.int_vector(),
            Writables.tableOf(Writables.ints(), Writables.vectors()));

    /*
     * S5 Wrapping co-occurrence columns
     */
    PTable<Integer, VectorOrPref> wrappedCooccurrence = coOccurrenceVector
            .parallelDo(new MapFn<Pair<Integer, Vector>, Pair<Integer, VectorOrPref>>() {

                @Override
                public Pair<Integer, VectorOrPref> map(Pair<Integer, Vector> input) {
                    return Pair.of(input.first(), new VectorOrPref(input.second()));
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S5").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S5").recsFactor;
                }

            }, Writables.tableOf(Writables.ints(), VectorOrPref.vectorOrPrefs()));

    wrappedCooccurrence = profiler.profile("S5", pipeline, wrappedCooccurrence, ProfileConverter.int_vopv(),
            Writables.tableOf(Writables.ints(), VectorOrPref.vectorOrPrefs()));

    /*
     * S6 Splitting user vectors
     */
    PTable<Integer, VectorOrPref> userVectorSplit = filteredUserVector
            .parallelDo(new DoFn<Pair<Long, Vector>, Pair<Integer, VectorOrPref>>() {

                @Override
                public void process(Pair<Long, Vector> input, Emitter<Pair<Integer, VectorOrPref>> emitter) {
                    long userID = input.first();
                    Vector userVector = input.second();
                    Iterator<Vector.Element> it = userVector.iterateNonZero();
                    while (it.hasNext()) {
                        Vector.Element e = it.next();
                        int itemIndex = e.index();
                        float preferenceValue = (float) e.get();
                        emitter.emit(Pair.of(itemIndex, new VectorOrPref(userID, preferenceValue)));
                    }
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S6").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S6").recsFactor;
                }
            }, Writables.tableOf(Writables.ints(), VectorOrPref.vectorOrPrefs()));

    userVectorSplit = profiler.profile("S6", pipeline, userVectorSplit, ProfileConverter.int_vopp(),
            Writables.tableOf(Writables.ints(), VectorOrPref.vectorOrPrefs()));

    /*
     * S7 Combine VectorOrPrefs
     */
    PTable<Integer, VectorAndPrefs> combinedVectorOrPref = wrappedCooccurrence.union(userVectorSplit)
            .groupByKey(est.getClusterSize())
            .parallelDo(new DoFn<Pair<Integer, Iterable<VectorOrPref>>, Pair<Integer, VectorAndPrefs>>() {

                @Override
                public void process(Pair<Integer, Iterable<VectorOrPref>> input,
                        Emitter<Pair<Integer, VectorAndPrefs>> emitter) {
                    Vector vector = null;
                    List<Long> userIDs = Lists.newArrayList();
                    List<Float> values = Lists.newArrayList();
                    for (VectorOrPref vop : input.second()) {
                        if (vector == null) {
                            vector = vop.getVector();
                        }
                        long userID = vop.getUserID();
                        if (userID != Long.MIN_VALUE) {
                            userIDs.add(vop.getUserID());
                        }
                        float value = vop.getValue();
                        if (!Float.isNaN(value)) {
                            values.add(vop.getValue());
                        }
                    }
                    emitter.emit(Pair.of(input.first(), new VectorAndPrefs(vector, userIDs, values)));
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S7").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S7").recsFactor;
                }
            }, Writables.tableOf(Writables.ints(), VectorAndPrefs.vectorAndPrefs()));

    combinedVectorOrPref = profiler.profile("S5+S6-S7", pipeline, combinedVectorOrPref,
            ProfileConverter.int_vap(), Writables.tableOf(Writables.ints(), VectorAndPrefs.vectorAndPrefs()));
    /*
     * S8 Computing partial recommendation vectors
     */
    PTable<Long, Vector> partialMultiply = combinedVectorOrPref
            .parallelDo(new DoFn<Pair<Integer, VectorAndPrefs>, Pair<Long, Vector>>() {
                @Override
                public void process(Pair<Integer, VectorAndPrefs> input, Emitter<Pair<Long, Vector>> emitter) {
                    Vector cooccurrenceColumn = input.second().getVector();
                    List<Long> userIDs = input.second().getUserIDs();
                    List<Float> prefValues = input.second().getValues();
                    for (int i = 0; i < userIDs.size(); i++) {
                        long userID = userIDs.get(i);
                        if (userID != Long.MIN_VALUE) {
                            float prefValue = prefValues.get(i);
                            Vector partialProduct = cooccurrenceColumn.times(prefValue);
                            emitter.emit(Pair.of(userID, partialProduct));
                        }
                    }
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S8").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S8").recsFactor;
                }

            }, Writables.tableOf(Writables.longs(), Writables.vectors())).groupByKey(est.getClusterSize())
            .combineValues(new CombineFn<Long, Vector>() {

                @Override
                public void process(Pair<Long, Iterable<Vector>> input, Emitter<Pair<Long, Vector>> emitter) {
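                    // Sum this user's partial product vectors into a single vector.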
                    Vector partial = null;
                    for (Vector vector : input.second()) {
                        partial = partial == null ? vector : partial.plus(vector);
                    }
                    emitter.emit(Pair.of(input.first(), partial));
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("combine").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("combine").recsFactor;
                }
            });

    partialMultiply = profiler.profile("S8-combine", pipeline, partialMultiply, ProfileConverter.long_vector(),
            Writables.tableOf(Writables.longs(), Writables.vectors()));

    /*
     * S9 Producing recommendations from vectors
     */
    PTable<Long, RecommendedItems> recommendedItems = partialMultiply
            .parallelDo(new DoFn<Pair<Long, Vector>, Pair<Long, RecommendedItems>>() {

                @Override
                public void process(Pair<Long, Vector> input, Emitter<Pair<Long, RecommendedItems>> emitter) {
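                    // Bounded priority queue: keeps the top-N highest-preference
                    // items, evicting the lowest when a better one arrives.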
                    Queue<RecommendedItem> topItems = new PriorityQueue<RecommendedItem>(11,
                            Collections.reverseOrder(BY_PREFERENCE_VALUE));
                    Iterator<Vector.Element> recommendationVectorIterator = input.second().iterateNonZero();
                    while (recommendationVectorIterator.hasNext()) {
                        Vector.Element element = recommendationVectorIterator.next();
                        int index = element.index();
                        float value = (float) element.get();
                        if (topItems.size() < top) {
                            topItems.add(new GenericRecommendedItem(index, value));
                        } else if (value > topItems.peek().getValue()) {
                            topItems.add(new GenericRecommendedItem(index, value));
                            topItems.poll();
                        }
                    }
                    List<RecommendedItem> recommendations = new ArrayList<RecommendedItem>(topItems.size());
                    recommendations.addAll(topItems);
                    Collections.sort(recommendations, BY_PREFERENCE_VALUE);
                    emitter.emit(Pair.of(input.first(), new RecommendedItems(recommendations)));
                }

                @Override
                public float scaleFactor() {
                    return est.getScaleFactor("S9").sizeFactor;
                }

                @Override
                public float scaleFactorByRecord() {
                    return est.getScaleFactor("S9").recsFactor;
                }

            }, Writables.tableOf(Writables.longs(), RecommendedItems.recommendedItems()));

    recommendedItems = profiler.profile("S9", pipeline, recommendedItems, ProfileConverter.long_ri(),
            Writables.tableOf(Writables.longs(), RecommendedItems.recommendedItems()));

    /*
     * Profiling
     */
    if (profiler.isProfiling()) {
        profiler.writeResultToFile(profileFilePath);
        profiler.cleanup(pipeline.getConfiguration());
        return 0;
    }
    /*
     * asText
     */
    pipeline.writeTextFile(recommendedItems, args[1]);
    PipelineResult result = pipeline.done();
    return result.succeeded() ? 0 : 1;
}

From source file: com.cloudera.knittingboar.records.TestRCV1RecordFactory.java

License: Apache License

public void testParse() throws Exception {

    RCV1RecordFactory factory = new RCV1RecordFactory();

    Vector v = new RandomAccessSparseVector(RCV1RecordFactory.FEATURES);

    int actual = factory.processLine(training_rec_0, v);

    assertEquals(0, actual);
    assertEquals(.043696374, v.get(7), 1e-9);

    Vector v2 = new RandomAccessSparseVector(RCV1RecordFactory.FEATURES);

    int actual2 = factory.processLine(training_rec_1, v2);

    assertEquals(1, actual2);
    assertEquals(.030852484, v2.get(69), 1e-9);

}

From source file: com.cloudera.knittingboar.sgd.ParallelOnlineLogisticRegression.java

License: Apache License

/**
 * Custom training for POLR based around accumulating gradient to send to the
 * master process.
 */
@Override
public void train(long trackingKey, String groupKey, int actual, Vector instance) {
    unseal();
    double learningRate = currentLearningRate();

    // push coefficients back to zero based on the prior
    regularize(instance);

    // basically this only gets the results for each classification
    // update each row of coefficients according to result
    Vector gradient = this.default_gradient.apply(groupKey, actual, instance, this);
    for (int i = 0; i < numCategories - 1; i++) {

        double gradientBase = gradient.get(i);

        // we're only going to look at the non-zero elements of the vector
        // then we apply the gradientBase to the resulting element.
        Iterator<Vector.Element> nonZeros = instance.iterateNonZero();

        while (nonZeros.hasNext()) {
            Vector.Element updateLocation = nonZeros.next();
            int j = updateLocation.index();

            double gradientToAdd = gradientBase * learningRate * perTermLearningRate(j) * instance.get(j);

            double newValue = beta.getQuick(i, j) + gradientToAdd;
            beta.setQuick(i, j, newValue);

            // now update gamma --- we only want the gradient since the last time
            double oldGamma = gamma.getCell(i, j);
            double newGamma = oldGamma + gradientToAdd;
            gamma.setCell(i, j, newGamma);

        }
    }

    // remember that these elements got updated
    Iterator<Vector.Element> i = instance.iterateNonZero();
    while (i.hasNext()) {
        Vector.Element element = i.next();
        int j = element.index();
        updateSteps.setQuick(j, getStep());
        updateCounts.setQuick(j, updateCounts.getQuick(j) + 1);
    }
    nextStep();

}

From source file: com.elex.dmp.core.TopicModel.java

License: Apache License

public Vector infer(Vector original, Vector docTopics) {
    Vector pTerm = original.like();
    Iterator<Vector.Element> it = original.iterateNonZero();
    while (it.hasNext()) {
        Vector.Element e = it.next();
        int term = e.index();
        // p(a) = sum_x (p(a|x) * p(x|i))
        double pA = 0;
        for (int x = 0; x < numTopics; x++) {
            pA += (topicTermCounts.viewRow(x).get(term) / topicSums.get(x)) * docTopics.get(x);
        }
        pTerm.set(term, pA);
    }
    return pTerm;
}

From source file: com.elex.dmp.core.TopicModel.java

License: Apache License

public void update(int termId, Vector topicCounts) {
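    // Fold this term's per-topic counts into each topic's term-count row.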
    for (int x = 0; x < numTopics; x++) {
        Vector v = topicTermCounts.viewRow(x);
        v.set(termId, v.get(termId) + topicCounts.get(x));
    }
    topicSums.assign(topicCounts, Functions.PLUS);
}