Example usage of org.apache.mahout.math.function.Functions.INV

List of usage examples for org.apache.mahout.math.function.Functions.INV

Introduction

On this page you can find example usages of org.apache.mahout.math.function.Functions.INV.

Prototype

DoubleFunction INV

To view the source code for org.apache.mahout.math.function.Functions.INV, click the Source Link.

Document

Function that returns 1.0 / a, i.e. the reciprocal of its argument a.

Usage

From source file: org.trustedanalytics.atk.giraph.algorithms.lda.CVB0LDAComputation.java

License: Apache License

/**
 * Update edge value according to vertex and messages
 *
 * @param vertex of the graph/*from w w  w .  j  a  va2 s  .  c o m*/
 * @param map of type HashMap
 */
private void updateEdge(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex,
        HashMap<LdaVertexId, Vector> map) {
    Vector vector = vertex.getValue().getLdaResult();

    double maxDelta = 0d;
    for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getMutableEdges()) {
        Vector gamma = edge.getValue().getVector();
        LdaVertexId id = edge.getTargetVertexId();
        if (map.containsKey(id)) {
            Vector otherVector = map.get(id);
            Vector newGamma = null;
            if (vertex.getId().isDocument()) {
                newGamma = vector.minus(gamma).plus(config.alpha())
                        .times(otherVector.minus(gamma).plus(config.beta()))
                        .times(nk.minus(gamma).plus(numWords * config.beta()).assign(Functions.INV));
            } else {
                newGamma = vector.minus(gamma).plus(config.beta())
                        .times(otherVector.minus(gamma).plus(config.alpha()))
                        .times(nk.minus(gamma).plus(numWords * config.beta()).assign(Functions.INV));
            }
            newGamma = newGamma.normalize(1d);
            double delta = gamma.minus(newGamma).norm(1d) / config.numTopics();
            if (delta > maxDelta) {
                maxDelta = delta;
            }
            // update edge vector
            edge.getValue().setVector(newGamma);
        } else {
            // this happens when you don't have your Vertex Id's being setup correctly
            throw new IllegalArgumentException(
                    String.format("Vertex ID %s: A message is mis-matched.", vertex.getId()));
        }
    }
    aggregate(MAX_DELTA, new DoubleWritable(maxDelta));
}

From source file: org.trustedanalytics.atk.giraph.algorithms.lda.CVB0LDAComputation.java

License: Apache License

/**
 * Replace the vertex's result vector with its smoothed, normalized form.
 *
 * <p>A document vertex gets {@code (result + alpha)} scaled to unit L1 norm. Any other vertex
 * gets {@code (result + beta)} divided element-wise by {@code (nk + numWords * beta)}.
 *
 * @param vertex of the graph; its result vector is overwritten
 */
private void normalizeVertex(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex) {
    LdaVertexData data = vertex.getValue();
    Vector counts = data.getLdaResult();
    Vector normalized;
    if (!vertex.getId().isDocument()) {
        // Element-wise reciprocal of the smoothed per-topic totals.
        Vector inverseTotals = nk.plus(numWords * config.beta()).assign(Functions.INV);
        normalized = counts.plus(config.beta()).times(inverseTotals);
    } else {
        normalized = counts.plus(config.alpha()).normalize(1d);
    }
    data.setLdaResult(normalized);
}

From source file: org.trustedanalytics.atk.giraph.algorithms.lda.CVB0LDAComputation.java

License: Apache License

/**
 * Evaluate cost according to vertex and messages
 *
 * @param vertex of the graph/*from   ww w  . j a  v  a 2s  .c  om*/
 * @param messages of type iterable
 * @param map of type HashMap
 */
private void evaluateCost(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex, Iterable<LdaMessage> messages,
        HashMap<LdaVertexId, Vector> map) {

    if (vertex.getId().isDocument()) {
        return;
    }
    Vector vector = vertex.getValue().getLdaResult();
    vector = vector.plus(config.beta()).times(nk.plus(numWords * config.beta()).assign(Functions.INV));

    double cost = 0d;
    for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getEdges()) {
        double weight = edge.getValue().getWordCount();
        LdaVertexId id = edge.getTargetVertexId();
        if (map.containsKey(id)) {
            Vector otherVector = map.get(id);
            otherVector = otherVector.plus(config.alpha()).normalize(1d);
            cost -= weight * Math.log(vector.dot(otherVector));
        } else {
            throw new IllegalArgumentException(
                    String.format("Vertex ID %s: A message is mis-matched", vertex.getId().getValue()));
        }
    }
    aggregate(SUM_COST, new DoubleWritable(cost));
}

From source file: org.trustedanalytics.atk.giraph.algorithms.lda.GiraphLdaComputation.java

License: Apache License

/**
 * Update vertex and outgoing edge values using current vertex values and messages
 *
 * @param vertex of the graph/*w  ww . ja  v  a2 s.  c  o m*/
 * @param map    Map of vertices
 */
private void updateVertex(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex,
        HashMap<LdaVertexId, Vector> map) {
    Vector vector = vertex.getValue().getLdaResult();
    Vector updatedVector = vertex.getValue().getLdaResult().clone().assign(0d);
    double maxDelta = 0d;
    for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getMutableEdges()) {
        Vector gamma = edge.getValue().getVector();
        LdaVertexId id = edge.getTargetVertexId();
        if (map.containsKey(id)) {
            Vector otherVector = map.get(id);
            Vector newGamma = null;
            if (vertex.getId().isDocument()) {
                newGamma = vector.minus(gamma).plus(config.alpha())
                        .times(otherVector.minus(gamma).plus(config.beta()))
                        .times(nk.minus(gamma).plus(numWords * config.beta()).assign(Functions.INV));
            } else {
                newGamma = vector.minus(gamma).plus(config.beta())
                        .times(otherVector.minus(gamma).plus(config.alpha()))
                        .times(nk.minus(gamma).plus(numWords * config.beta()).assign(Functions.INV));
            }
            newGamma = newGamma.normalize(1d);
            double delta = gamma.minus(newGamma).norm(1d) / config.numTopics();
            if (delta > maxDelta) {
                maxDelta = delta;
            }
            // update edge vector
            edge.getValue().setVector(newGamma);
        } else {
            // this happens when you don't have your Vertex Id's being setup correctly
            throw new IllegalArgumentException(
                    String.format("Vertex ID %s: A message is mis-matched.", vertex.getId()));
        }

        updatedVector = updateVector(updatedVector, edge);
    }

    vertex.getValue().setLdaResult(updatedVector);

    aggregateWord(vertex);
    aggregate(MAX_DELTA, new DoubleWritable(maxDelta));
}