Example usage for org.apache.mahout.math.function Functions EXP

List of usage examples for org.apache.mahout.math.function Functions EXP

Introduction

On this page you can find example usage of org.apache.mahout.math.function Functions EXP.

Prototype

DoubleFunction EXP

To view the source code for org.apache.mahout.math.function Functions EXP, click Source Link.

Document

Function that returns Math.exp(a).

Usage

From source file: org.trustedanalytics.atk.giraph.algorithms.lbp.LoopyBeliefPropagationComputation.java

License: Apache License

/**
 * Natural log of {@link Double#MIN_VALUE} (about -744.4). Used as the log-scale value of a
 * belief entry whose summed probability underflowed to exactly 0: it is (close to) the lowest
 * finite log value, so the entry is treated as least likely without introducing -infinity,
 * which could turn into NaN once several such messages are added and max-shifted.
 */
private static final double LOG_OF_SMALLEST_POSITIVE_DOUBLE = Math.log(Double.MIN_VALUE);

/**
 * Runs one superstep of loopy belief propagation for a single vertex.
 *
 * <p>Superstep 0 only initializes the vertex. On later supersteps the vertex adds its log-scale
 * prior (replaced by a uniform log prior for non-training vertices) to all incoming messages.
 * Unless the prior's maximum is at least {@code anchorThreshold}, it refreshes the posterior and
 * aggregates the L1 change of the posterior under the aggregator for its vertex type. While
 * {@code step < maxSupersteps}, a training vertex then sends a log-scale belief message along
 * every out-edge, excluding the message it received from that edge's target. On the final
 * superstep the prior is converted back from log scale and the vertex votes to halt.
 *
 * @param vertex   the vertex being computed
 * @param messages messages received in the previous superstep; {@code getData()} carries the
 *                 sender's vertex id (set that way when this method sends messages)
 * @throws IOException declared by the overridden {@code compute} signature
 */
@Override
public void compute(Vertex<LongWritable, VertexData4LBPWritable, DoubleWritable> vertex,
        Iterable<IdWithVectorMessage> messages) throws IOException {
    long step = getSuperstep();
    if (step == 0) {
        initializeVertex(vertex);
        return;
    }

    VertexData4LBPWritable vertexValue = vertex.getValue();
    VertexType vt = vertexValue.getType();
    if (ignoreVertexType) {
        vt = VertexType.TRAIN;
    }
    Vector prior = vertexValue.getPriorVector();
    double nStates = prior.size();
    if (vt != VertexType.TRAIN) {
        // assign a uniform (log-scale) prior for validate/test vertices
        prior = prior.clone().assign(Math.log(1.0 / nStates));
    }

    // Single pass over the messages: remember each sender's vector (needed later to exclude it
    // from the reply sent back to that sender) and accumulate prior + messages in log scale.
    HashMap<Long, Vector> map = new HashMap<Long, Vector>();
    Vector sumPosterior = prior;
    for (IdWithVectorMessage message : messages) {
        Vector messageVector = message.getVector();
        map.put(message.getData(), messageVector);
        sumPosterior = sumPosterior.plus(messageVector);
    }
    // shift so the largest log value is 0; every exp() below is then <= 1
    sumPosterior = sumPosterior.plus(-sumPosterior.maxValue());

    // update posterior only if this isn't an anchor vertex
    if (prior.maxValue() < anchorThreshold) {
        // exponentiate (clone first: sumPosterior is reused below) and normalize to sum 1
        Vector posterior = sumPosterior.clone().assign(Functions.EXP).normalize(1d);
        double delta = posterior.minus(vertexValue.getPosteriorVector()).norm(1d);
        aggregateDelta(vt, delta);
        vertexValue.setPosteriorVector(posterior);
    }

    if (step < maxSupersteps) {
        // only training vertices send out messages
        if (vt != VertexType.TRAIN) {
            return;
        }
        long selfId = vertex.getId().get();
        int stateCount = prior.size();
        for (Edge<LongWritable, DoubleWritable> edge : vertex.getEdges()) {
            Vector fromTarget = map.get(edge.getTargetVertexId().get());
            // leave out what the target itself told us
            Vector tempVector = (fromTarget == null) ? sumPosterior : sumPosterior.minus(fromTarget);
            Vector belief = messageBelief(prior, tempVector, edge.getValue().get(), stateCount);

            // fresh message per edge so nothing already handed to sendMessage is mutated later
            IdWithVectorMessage newMessage = new IdWithVectorMessage();
            newMessage.setData(selfId);
            newMessage.setVector(belief);
            sendMessage(edge.getTargetVertexId(), newMessage);
        }
    } else {
        // convert prior back to regular scale before output (assign works in place)
        Vector regularPrior = vertexValue.getPriorVector().assign(Functions.EXP);
        vertexValue.setPriorVector(regularPrior);
        vertex.voteToHalt();
    }
}

/**
 * Adds {@code delta} to the delta-sum aggregator that matches the vertex type.
 *
 * @throws IllegalArgumentException if {@code vt} is not TRAIN, VALIDATE or TEST
 */
private void aggregateDelta(VertexType vt, double delta) {
    switch (vt) {
    case TRAIN:
        aggregate(SUM_TRAIN_DELTA, new DoubleWritable(delta));
        break;
    case VALIDATE:
        aggregate(SUM_VALIDATE_DELTA, new DoubleWritable(delta));
        break;
    case TEST:
        aggregate(SUM_TEST_DELTA, new DoubleWritable(delta));
        break;
    default:
        throw new IllegalArgumentException("Unknown vertex type: " + vt.toString());
    }
}

/**
 * Builds the log-scale belief message for one neighbour.
 *
 * <p>For each state {@code i}, combines
 * {@code exp(tempVector[j] + edgePotential(|i - j| / (stateCount - 1), weight))} over all
 * states {@code j}. Terms are combined by maximum when {@code maxProduct} is set and by sum
 * otherwise. The log of the result is stored, and the vector is shifted so its largest entry is 0.
 *
 * @param template   vector cloned only to obtain a vector of the right type and size; all of its
 *                   values are overwritten in the clone
 * @param tempVector log-scale posterior with the target's own message removed
 * @param weight     weight of the edge the message travels along
 * @param stateCount number of states
 * @return a new log-scale belief vector whose maximum entry is 0
 */
private Vector messageBelief(Vector template, Vector tempVector, double weight, int stateCount) {
    Vector belief = template.clone();
    double maxDistance = stateCount - 1;
    for (int i = 0; i < stateCount; i++) {
        double sum = 0d;
        for (int j = 0; j < stateCount; j++) {
            // normalized state distance in [0, 1]; a single-state model would otherwise compute 0/0
            double distance = maxDistance > 0d ? Math.abs(i - j) / maxDistance : 0d;
            double msg = Math.exp(tempVector.getQuick(j) + edgePotential(distance, weight));
            if (maxProduct) {
                sum = sum > msg ? sum : msg;
            } else {
                sum += msg;
            }
        }
        // sum == 0 means every term underflowed: use the lowest finite log value, not
        // Double.MIN_VALUE (a tiny *positive* number whose "log-probability" would be ~0,
        // i.e. the most likely state after the max-shift below)
        belief.setQuick(i, sum > 0d ? Math.log(sum) : LOG_OF_SMALLEST_POSITIVE_DOUBLE);
    }
    return belief.plus(-belief.maxValue());
}