Example usage for org.apache.mahout.math Vector normalize

List of usage examples for org.apache.mahout.math Vector normalize

Introduction

On this page you can find example usages of org.apache.mahout.math Vector normalize.

Prototype

Vector normalize(double power);

Source Link

Document

Returns a new Vector containing the values of the recipient divided by its L_power norm (the p-norm with p = power).

Usage

From source file:org.trustedanalytics.atk.giraph.algorithms.lbp.LoopyBeliefPropagationComputation.java

License:Apache License

/**
 * Performs one loopy-belief-propagation superstep for {@code vertex}.
 *
 * <p>Superstep 0 only initializes the vertex. On later supersteps the log-scale prior (replaced by
 * a uniform log prior for non-training vertices unless {@code ignoreVertexType} is set) is added to
 * every incoming log-scale message. Unless the vertex is an anchor (the maximum of that prior is at
 * least {@code anchorThreshold}), the exponentiated, L1-normalized sum becomes the new posterior and
 * the L1 change of the posterior is aggregated per vertex type. While {@code step < maxSupersteps},
 * training vertices send each neighbor a log-scale belief that leaves out the message received from
 * that same neighbor. After the last superstep the stored prior is converted back to the regular
 * scale and the vertex votes to halt.
 *
 * @param vertex   the vertex to compute
 * @param messages incoming messages, each carrying the sender id and a log-scale vector
 * @throws IOException per the Giraph compute contract
 * @throws IllegalArgumentException if the effective vertex type is not TRAIN, VALIDATE or TEST
 */
@Override
public void compute(Vertex<LongWritable, VertexData4LBPWritable, DoubleWritable> vertex,
        Iterable<IdWithVectorMessage> messages) throws IOException {
    long step = getSuperstep();
    if (step == 0) {
        initializeVertex(vertex);
        return;
    }

    // collect messages sent to this vertex, keyed by sender id, so the message received from a
    // neighbor can be removed again when computing the message sent back to that neighbor
    HashMap<Long, Vector> map = new HashMap<Long, Vector>();
    for (IdWithVectorMessage message : messages) {
        map.put(message.getData(), message.getVector());
    }

    // update posterior according to prior and messages
    VertexData4LBPWritable vertexValue = vertex.getValue();
    VertexType vt = vertexValue.getType();
    vt = ignoreVertexType ? VertexType.TRAIN : vt;
    Vector prior = vertexValue.getPriorVector();
    double nStates = prior.size();
    if (vt != VertexType.TRAIN) {
        // assign a uniform (log-scale) prior for validate/test vertex
        prior = prior.clone().assign(Math.log(1.0 / nStates));
    }
    // sum of prior and messages, all in log scale
    Vector sumPosterior = prior;
    for (IdWithVectorMessage message : messages) {
        sumPosterior = sumPosterior.plus(message.getVector());
    }
    // shift so the largest log value is 0 before anything is exponentiated
    sumPosterior = sumPosterior.plus(-sumPosterior.maxValue());
    // update posterior if this isn't an anchor vertex
    if (prior.maxValue() < anchorThreshold) {
        // back to regular scale, then normalize to a probability distribution
        Vector posterior = sumPosterior.clone().assign(Functions.EXP);
        posterior = posterior.normalize(1d);
        Vector oldPosterior = vertexValue.getPosteriorVector();
        double delta = posterior.minus(oldPosterior).norm(1d);
        // aggregate deltas
        switch (vt) {
        case TRAIN:
            aggregate(SUM_TRAIN_DELTA, new DoubleWritable(delta));
            break;
        case VALIDATE:
            aggregate(SUM_VALIDATE_DELTA, new DoubleWritable(delta));
            break;
        case TEST:
            aggregate(SUM_TEST_DELTA, new DoubleWritable(delta));
            break;
        default:
            throw new IllegalArgumentException("Unknown vertex type: " + vt.toString());
        }
        // update posterior
        vertexValue.setPosteriorVector(posterior);
    }

    if (step < maxSupersteps) {
        // if it's not a training vertex, don't send out messages
        if (vt != VertexType.TRAIN) {
            return;
        }
        // NOTE(review): one message object is reused for every edge; this assumes sendMessage
        // serializes/copies it immediately - confirm for the message store in use.
        IdWithVectorMessage newMessage = new IdWithVectorMessage();
        newMessage.setData(vertex.getId().get());
        // scratch buffer for the per-edge log belief; it is never handed to a message, so refilling
        // it for the next edge cannot modify a vector that was already sent
        Vector belief = prior.clone();
        for (Edge<LongWritable, DoubleWritable> edge : vertex.getEdges()) {
            double weight = edge.getValue().get();
            long id = edge.getTargetVertexId().get();
            // leave out the message this neighbor sent us
            Vector tempVector = sumPosterior;
            if (map.containsKey(id)) {
                tempVector = sumPosterior.minus(map.get(id));
            }
            for (int i = 0; i < nStates; i++) {
                double sum = 0d;
                for (int j = 0; j < nStates; j++) {
                    // normalized state distance in [0, 1]; a single state would otherwise give 0/0 = NaN
                    double distance = nStates > 1 ? Math.abs(i - j) / (nStates - 1) : 0d;
                    double msg = Math.exp(tempVector.getQuick(j) + edgePotential(distance, weight));
                    if (maxProduct) {
                        sum = sum > msg ? sum : msg;
                    } else {
                        sum += msg;
                    }
                }
                // Clamp BEFORE taking the log: log(0) is -Infinity, so an underflowed sum is floored
                // at the smallest positive double (log ~ -744.4, i.e. "almost impossible").
                // Storing Double.MIN_VALUE itself would be log-probability ~0, i.e. probability ~1.
                belief.setQuick(i, Math.log(sum > 0d ? sum : Double.MIN_VALUE));
            }
            // shift so the largest log value is 0; plus() returns a fresh vector for the message
            Vector outgoing = belief.plus(-belief.maxValue());
            newMessage.setVector(outgoing);
            sendMessage(edge.getTargetVertexId(), newMessage);
        }
    } else {
        // convert prior back to regular scale before output
        prior = vertexValue.getPriorVector();
        prior = prior.assign(Functions.EXP);
        vertexValue.setPriorVector(prior);
        vertex.voteToHalt();
    }
}

From source file:org.trustedanalytics.atk.giraph.algorithms.lda.CVB0LDAComputation.java

License:Apache License

/**
 * Initialize vertex/edges, collect graph statistics and send out messages
 *
 * @param vertex of the graph/*w  ww  .  j a  v a 2  s .  c om*/
 */
private void initialize(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex) {

    // initialize vertex vector, i.e., the theta for doc and phi for word in LDA
    double[] vertexValues = new double[config.numTopics()];
    vertex.getValue().setLdaResult(new DenseVector(vertexValues));

    // initialize edge vector, i.e., the gamma in LDA
    Random rand1 = new Random(vertex.getId().seed());
    long seed1 = rand1.nextInt();
    double maxDelta = 0d;
    double sumWeights = 0d;
    for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getMutableEdges()) {
        double weight = edge.getValue().getWordCount();

        // generate the random seed for this edge
        Random rand2 = new Random(edge.getTargetVertexId().seed());
        long seed2 = rand2.nextInt();
        long seed = seed1 + seed2;
        Random rand = new Random(seed);
        double[] edgeValues = new double[config.numTopics()];
        for (int i = 0; i < config.numTopics(); i++) {
            edgeValues[i] = rand.nextDouble();
        }
        Vector vector = new DenseVector(edgeValues);
        vector = vector.normalize(1d);
        edge.getValue().setVector(vector);
        // find the max delta among all edges
        double delta = vector.norm(1d) / config.numTopics();
        if (delta > maxDelta) {
            maxDelta = delta;
        }
        // the sum of weights from all edges
        sumWeights += weight;
    }
    // update vertex value
    updateVertex(vertex);
    // aggregate max delta value
    aggregate(MAX_DELTA, new DoubleWritable(maxDelta));

    // collect graph statistics
    if (vertex.getId().isDocument()) {
        aggregate(SUM_DOC_VERTEX_COUNT, new LongWritable(1));
    } else {
        aggregate(SUM_OCCURRENCE_COUNT, new DoubleWritable(sumWeights));
        aggregate(SUM_WORD_VERTEX_COUNT, new LongWritable(1));
    }

    // send out messages
    LdaMessage newMessage = new LdaMessage(vertex.getId().copy(), vertex.getValue().getLdaResult());
    sendMessageToAllEdges(vertex, newMessage);
}

From source file:org.trustedanalytics.atk.giraph.algorithms.lda.CVB0LDAComputation.java

License:Apache License

/**
 * Update edge value according to vertex and messages
 *
 * @param vertex of the graph/*from ww  w.  j ava2 s  . c om*/
 * @param map of type HashMap
 */
private void updateEdge(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex,
        HashMap<LdaVertexId, Vector> map) {
    Vector vector = vertex.getValue().getLdaResult();

    double maxDelta = 0d;
    for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getMutableEdges()) {
        Vector gamma = edge.getValue().getVector();
        LdaVertexId id = edge.getTargetVertexId();
        if (map.containsKey(id)) {
            Vector otherVector = map.get(id);
            Vector newGamma = null;
            if (vertex.getId().isDocument()) {
                newGamma = vector.minus(gamma).plus(config.alpha())
                        .times(otherVector.minus(gamma).plus(config.beta()))
                        .times(nk.minus(gamma).plus(numWords * config.beta()).assign(Functions.INV));
            } else {
                newGamma = vector.minus(gamma).plus(config.beta())
                        .times(otherVector.minus(gamma).plus(config.alpha()))
                        .times(nk.minus(gamma).plus(numWords * config.beta()).assign(Functions.INV));
            }
            newGamma = newGamma.normalize(1d);
            double delta = gamma.minus(newGamma).norm(1d) / config.numTopics();
            if (delta > maxDelta) {
                maxDelta = delta;
            }
            // update edge vector
            edge.getValue().setVector(newGamma);
        } else {
            // this happens when you don't have your Vertex Id's being setup correctly
            throw new IllegalArgumentException(
                    String.format("Vertex ID %s: A message is mis-matched.", vertex.getId()));
        }
    }
    aggregate(MAX_DELTA, new DoubleWritable(maxDelta));
}

From source file:org.trustedanalytics.atk.giraph.algorithms.lda.GiraphLdaComputation.java

License:Apache License

/**
 * Initialize vertex/edges, collect graph statistics and send out messages
 *
 * @param vertex of the graph/*from   w ww. j  ava  2  s . c om*/
 */
private void initialize(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex) {

    // initialize vertex vector, i.e., the theta for doc and phi for word in LDA
    double[] vertexValues = new double[config.numTopics()];
    vertex.getValue().setLdaResult(new DenseVector(vertexValues));
    Vector updatedVector = vertex.getValue().getLdaResult().clone().assign(0d);
    // initialize edge vector, i.e., the gamma in LDA
    Random rand1 = new Random(vertex.getId().seed());
    long seed1 = rand1.nextInt();
    double maxDelta = 0d;
    double sumWeights = 0d;
    for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getMutableEdges()) {
        double weight = edge.getValue().getWordCount();

        // generate the random seed for this edge
        Random rand2 = new Random(edge.getTargetVertexId().seed());
        long seed2 = rand2.nextInt();
        long seed = seed1 + seed2;
        Random rand = new Random(seed);
        double[] edgeValues = new double[config.numTopics()];
        for (int i = 0; i < config.numTopics(); i++) {
            edgeValues[i] = rand.nextDouble();
        }
        Vector vector = new DenseVector(edgeValues);
        vector = vector.normalize(1d);
        edge.getValue().setVector(vector);
        // find the max delta among all edges
        double delta = vector.norm(1d) / config.numTopics();
        if (delta > maxDelta) {
            maxDelta = delta;
        }
        // the sum of weights from all edges
        sumWeights += weight;
        updatedVector = updateVector(updatedVector, edge);
    }
    // update vertex value
    vertex.getValue().setLdaResult(updatedVector);
    ;
    // aggregate max delta value
    aggregateWord(vertex);
    aggregate(MAX_DELTA, new DoubleWritable(maxDelta));

    // collect graph statistics
    if (vertex.getId().isDocument()) {
        aggregate(SUM_DOC_VERTEX_COUNT, new LongWritable(1));
    } else {
        aggregate(SUM_OCCURRENCE_COUNT, new DoubleWritable(sumWeights));
        aggregate(SUM_WORD_VERTEX_COUNT, new LongWritable(1));
    }

    // send out messages
    LdaMessage newMessage = new LdaMessage(vertex.getId().copy(), vertex.getValue().getLdaResult());
    sendMessageToAllEdges(vertex, newMessage);
}

From source file:org.trustedanalytics.atk.giraph.algorithms.lda.GiraphLdaComputation.java

License:Apache License

/**
 * Update vertex and outgoing edge values using current vertex values and messages
 *
 * @param vertex of the graph/*w w  w . j a  va2 s.  co m*/
 * @param map    Map of vertices
 */
private void updateVertex(Vertex<LdaVertexId, LdaVertexData, LdaEdgeData> vertex,
        HashMap<LdaVertexId, Vector> map) {
    Vector vector = vertex.getValue().getLdaResult();
    Vector updatedVector = vertex.getValue().getLdaResult().clone().assign(0d);
    double maxDelta = 0d;
    for (Edge<LdaVertexId, LdaEdgeData> edge : vertex.getMutableEdges()) {
        Vector gamma = edge.getValue().getVector();
        LdaVertexId id = edge.getTargetVertexId();
        if (map.containsKey(id)) {
            Vector otherVector = map.get(id);
            Vector newGamma = null;
            if (vertex.getId().isDocument()) {
                newGamma = vector.minus(gamma).plus(config.alpha())
                        .times(otherVector.minus(gamma).plus(config.beta()))
                        .times(nk.minus(gamma).plus(numWords * config.beta()).assign(Functions.INV));
            } else {
                newGamma = vector.minus(gamma).plus(config.beta())
                        .times(otherVector.minus(gamma).plus(config.alpha()))
                        .times(nk.minus(gamma).plus(numWords * config.beta()).assign(Functions.INV));
            }
            newGamma = newGamma.normalize(1d);
            double delta = gamma.minus(newGamma).norm(1d) / config.numTopics();
            if (delta > maxDelta) {
                maxDelta = delta;
            }
            // update edge vector
            edge.getValue().setVector(newGamma);
        } else {
            // this happens when you don't have your Vertex Id's being setup correctly
            throw new IllegalArgumentException(
                    String.format("Vertex ID %s: A message is mis-matched.", vertex.getId()));
        }

        updatedVector = updateVector(updatedVector, edge);
    }

    vertex.getValue().setLdaResult(updatedVector);

    aggregateWord(vertex);
    aggregate(MAX_DELTA, new DoubleWritable(maxDelta));
}

From source file:org.trustedanalytics.atk.giraph.algorithms.lp.LabelPropagationComputation.java

License:Apache License

/**
 * Initializes the prior and posterior of a vertex and its edges, then sends the prior to all
 * neighbors.
 *
 * <p>A vertex with a prior vector gets that vector L1-normalized. The normalized vector is also
 * stored in {@code initialVectorValues}. A vertex without a prior receives its own copy of the vector
 * most recently stored in {@code initialVectorValues` and is marked as unlabeled. The posterior
 * always starts as a separate copy of the prior.
 *
 * @param vertex a graph vertex
 * @throws RuntimeException if the vertex has no prior and no fallback vector has been stored yet
 */
private void initializeVertexEdges(Vertex<LongWritable, VertexData4LPWritable, DoubleWritable> vertex) {

    // normalize prior and initialize posterior
    VertexData4LPWritable vertexValue = vertex.getValue();
    Vector priorValues = vertexValue.getPriorVector();
    if (null != priorValues) {
        priorValues = priorValues.normalize(1d);
        initialVectorValues = priorValues;
    } else if (initialVectorValues != null) {
        // Copy instead of sharing: otherwise every unlabeled vertex (and the labeled vertex the
        // vector came from) would hold the same mutable Vector, so an in-place update of one
        // vertex's prior would silently change the others.
        priorValues = initialVectorValues.clone();
        vertexValue.setLabeledStatus(false);
    } else {
        throw new RuntimeException("Vector labels missing from input data for vertex " + vertex.getId()
                + ". Add edge with vertex as first column.");
    }
    vertexValue.setPriorVector(priorValues);
    vertexValue.setPosteriorVector(priorValues.clone());
    vertexValue.setDegree(initializeEdge(vertex));

    // send out messages
    IdWithVectorMessage newMessage = new IdWithVectorMessage(vertex.getId().get(), priorValues);
    sendMessageToAllEdges(vertex, newMessage);
}