Example usage for org.apache.mahout.math Matrix plus

List of usage examples for org.apache.mahout.math Matrix plus

Introduction

On this page you can find example usages of the plus method of org.apache.mahout.math.Matrix.

Prototype

Matrix plus(Matrix x);

Source Link

Document

Returns a new matrix containing the element-by-element sum of the recipient and the argument.

Usage

From source file:org.qcri.pca.SPCADriver.java

/**
 * Runs PPCA sequentially on a small input matrix Y that fits into memory. This
 * can also be used on data sampled from a distributed matrix.
 *
 * Note: this implementation ignores NaN values by replacing them with 0. The
 * input matrix is modified in place: NaN entries are zeroed, then every column
 * is centered on its mean and divided by its span (max - min, or 1 when the
 * span is 0).
 *
 * @param conf
 *          the configuration (not read by this method)
 * @param centralY
 *          the input matrix; overwritten with its normalized form
 * @param initVal
 *          the initial values; only C is read, and on return both C and ss
 *          hold the final estimates
 * @param MAX_ROUNDS
 *          maximum number of EM iterations
 * @return the ratio norm1Err / norm1Y, where each term is the maximum over
 *         columns of the VectorNorm1 aggregate (of the reconstruction error
 *         and of the normalized input, respectively)
 */
double runSequential_JacobVersion(Configuration conf, Matrix centralY, InitialValues initVal,
        final int MAX_ROUNDS) {
    // The current implementation doesn't use the initial ss of initVal.
    Matrix centralC = initVal.C;
    final int nRows = centralY.numRows();
    final int nCols = centralY.numCols();
    final int nPCs = centralC.numCols();
    final float threshold = 0.00001f;
    // N*D is computed in double so it cannot overflow int for large inputs.
    final double numEntries = (double) nRows * nCols;
    // Element-wise square, shared by the initial and per-iteration error sums.
    final DoubleFunction square = new DoubleFunction() {
        @Override
        public double apply(double arg1) {
            return arg1 * arg1;
        }
    };

    log.info("tracec= " + PCACommon.trace(centralC));
    // Y = Y - mean(Ye)
    // Also normalize the matrix
    for (int r = 0; r < nRows; r++) {
        for (int c = 0; c < nCols; c++) {
            if (Double.isNaN(centralY.getQuick(r, c))) {
                centralY.setQuick(r, c, 0);
            }
        }
    }
    Vector mean = centralY.aggregateColumns(new VectorFunction() {
        @Override
        public double apply(Vector v) {
            return v.zSum() / nRows;
        }
    });
    Vector spanVector = new DenseVector(nCols);
    for (int c = 0; c < nCols; c++) {
        Vector col = centralY.viewColumn(c);
        spanVector.setQuick(c, col.maxValue() - col.minValue());
    }
    for (int r = 0; r < nRows; r++) {
        for (int c = 0; c < nCols; c++) {
            double span = spanVector.getQuick(c);
            // Constant columns (span 0) are only centered, not scaled.
            centralY.set(r, c, (centralY.get(r, c) - mean.get(c)) / (span != 0 ? span : 1));
        }
    }

    // -------------------------- initialization
    // CtC = C'*C;
    Matrix centralCtC = centralC.transpose().times(centralC);
    log.info("tracectc= " + PCACommon.trace(centralCtC));
    // The inverse is computed once and reused for both the log line and X.
    Matrix invCtC = inv(centralCtC);
    log.info("traceinvctc= " + PCACommon.trace(invCtC));
    log.info("traceye= " + PCACommon.trace(centralY));
    // X = Ye * C * inv(CtC);
    Matrix centralX = centralY.times(centralC).times(invCtC);
    log.info("tracex= " + PCACommon.trace(centralX));
    // recon = X * C';
    Matrix recon = centralX.times(centralC.transpose());
    log.info("tracerec= " + PCACommon.trace(recon));
    // ss = sum(sum((recon-Ye).^2)) / (N*D-missing);
    double ss = recon.minus(centralY).assign(square).zSum() / numEntries;
    log.info("SSSSSSSSSSSSSSSSSSSSSSSSSSSS " + ss);

    // count doubles as the convergence flag: it is reset to 0 once the
    // relative change drops below the threshold after at least 5 rounds.
    int count = 1;
    // old = Inf;
    double old = Double.MAX_VALUE;
    // -------------------------- EM Iterations
    int round = 0;
    while (round < MAX_ROUNDS && count > 0) {
        round++;
        // ------------------ E-step, (co)variances
        // Sx = inv( eye(d) + CtC/ss );
        Matrix centralSx = eye(nPCs).plus(centralCtC.divide(ss));
        centralSx = inv(centralSx);
        // ------------------ E-step expected value
        // X = Ye*C*(Sx/ss);
        centralX = centralY.times(centralC).times(centralSx.divide(ss));
        // ------------------ M-step
        // SumXtX = X'*X;
        Matrix centralSumXtX = centralX.transpose().times(centralX);
        // C = (Ye'*X) / (SumXtX + N*Sx );
        Matrix tmpInv = inv(centralSumXtX.plus(centralSx.times(nRows)));
        centralC = centralY.transpose().times(centralX).times(tmpInv);
        // CtC = C'*C;
        centralCtC = centralC.transpose().times(centralC);
        // ss = ( sum(sum( (X*C'-Ye).^2 )) + N*sum(sum(CtC.*Sx)) +
        // missing*ss_old ) /(N*D);
        recon = centralX.times(centralC.transpose());
        double error = recon.minus(centralY).assign(square).zSum();
        // dot is given a clone of CtC, so centralCtC itself is left untouched here.
        ss = error + nRows * dot(centralCtC.clone(), centralSx).zSum();
        ss /= numEntries;

        log.info("SSSSSSSSSSSSSSSSSSSSSSSSSSSS " + ss);
        double traceSx = PCACommon.trace(centralSx);
        double traceX = PCACommon.trace(centralX);
        double traceSumXtX = PCACommon.trace(centralSumXtX);
        double traceC = PCACommon.trace(centralC);
        double traceCtC = PCACommon.trace(centralCtC);
        log.info("TTTTTTTTTTTTTTTTT " + traceSx + " " + traceX + " " + traceSumXtX + " " + traceC + " "
                + traceCtC + " " + 0);

        // objective = N*D + N*(D*log(ss) +PCACommon.trace(Sx)-log(det(Sx)) )
        // +PCACommon.trace(SumXtX) -missing*log(ss_old);
        double objective = numEntries
                + nRows * (nCols * Math.log(ss) + traceSx - Math.log(centralSx.determinant()))
                + traceSumXtX;
        double rel_ch = Math.abs(1 - objective / old);
        old = objective;
        count++;
        if (rel_ch < threshold && count > 5) {
            count = 0;
        }
        System.out.printf("Objective:  %.6f    relative change: %.6f \n", objective, rel_ch);
    }

    double norm1Y = centralY.aggregateColumns(new VectorNorm1()).maxValue();
    log.info("Norm1 of Y is: " + norm1Y);
    Matrix newYerror = centralY.minus(centralX.times(centralC.transpose()));
    double norm1Err = newYerror.aggregateColumns(new VectorNorm1()).maxValue();
    log.info("Norm1 of the reconstruction error is: " + norm1Err);

    // Hand the final estimates back to the caller through initVal.
    initVal.C = centralC;
    initVal.ss = ss;
    return norm1Err / norm1Y;
}

From source file:org.trustedanalytics.atk.giraph.algorithms.als.AlternatingLeastSquaresComputation.java

License:Apache License

/**
 * Runs one superstep of the alternating least squares computation for a vertex.
 *
 * At superstep 0 the vertex is initialized and halts. On every superstep that
 * is a multiple of (2 * learningCurveOutputInterval), the training cost and the
 * squared validation/test errors over the incoming messages are aggregated.
 * While the superstep is below maxSupersteps, the vertex vector is re-solved
 * from the TRAIN messages (and the bias recomputed when biasOn is set), and the
 * updated vertex value is sent along every edge. The vertex always votes to
 * halt at the end.
 *
 * @param vertex
 *          the vertex being computed; its value is updated in place
 * @param messages
 *          the messages received for this superstep; iterated more than once
 * @throws IOException
 *           declared by the overridden method
 * @throws IllegalArgumentException
 *           if a message carries an edge type other than TRAIN, VALIDATE or TEST
 */
@Override
public void compute(Vertex<CFVertexId, VertexData4CFWritable, EdgeData4CFWritable> vertex,
        Iterable<MessageData4CFWritable> messages) throws IOException {
    long step = getSuperstep();
    if (step == 0) {
        initialize(vertex);
        vertex.voteToHalt();
        return;
    }

    Vector currentValue = vertex.getValue().getVector();
    double currentBias = vertex.getValue().getBias();
    // update aggregators every (2 * interval) super steps
    if ((step % (2 * learningCurveOutputInterval)) == 0) {
        double errorOnTrain = 0d;
        double errorOnValidate = 0d;
        double errorOnTest = 0d;
        int numTrain = 0;
        for (MessageData4CFWritable message : messages) {
            EdgeType et = message.getType();
            double weight = message.getWeight();
            Vector vector = message.getVector();
            double otherBias = message.getBias();
            double predict = currentBias + otherBias + currentValue.dot(vector);
            double e = weight - predict;
            switch (et) {
            case TRAIN:
                errorOnTrain += e * e;
                numTrain++;
                break;
            case VALIDATE:
                errorOnValidate += e * e;
                break;
            case TEST:
                errorOnTest += e * e;
                break;
            default:
                throw new IllegalArgumentException("Unrecognized edge type: " + et);
            }
        }
        // Training cost = mean squared error plus the lambda-weighted penalty on
        // this vertex's bias and vector; it stays 0 when there are no TRAIN messages.
        double costOnTrain = 0d;
        if (numTrain > 0) {
            costOnTrain = errorOnTrain / numTrain
                    + lambda * (currentBias * currentBias + currentValue.dot(currentValue));
        }
        aggregate(SUM_TRAIN_COST, new DoubleWritable(costOnTrain));
        aggregate(SUM_VALIDATE_ERROR, new DoubleWritable(errorOnValidate));
        aggregate(SUM_TEST_ERROR, new DoubleWritable(errorOnTest));
    }

    // update vertex value
    if (step < maxSupersteps) {
        // xxt records the result of x times x transpose
        Matrix xxt = new DenseMatrix(featureDimension, featureDimension).assign(0d);
        // xr records the result of x times rating
        Vector xr = currentValue.clone().assign(0d);
        int numTrain = 0;
        for (MessageData4CFWritable message : messages) {
            if (message.getType() == EdgeType.TRAIN) {
                double weight = message.getWeight();
                Vector vector = message.getVector();
                double otherBias = message.getBias();
                xxt = xxt.plus(vector.cross(vector));
                // The rating is corrected by both biases before accumulation.
                xr = xr.plus(vector.times(weight - currentBias - otherBias));
                numTrain++;
            }
        }
        // Regularize: add (lambda * numTrain) to the diagonal of xxt.
        // NOTE(review): with no TRAIN messages xxt stays all-zero here, so the
        // solve below works on a singular matrix - confirm this is acceptable.
        xxt = xxt.plus(new DiagonalMatrix(lambda * numTrain, featureDimension));
        Matrix bMatrix = new DenseMatrix(featureDimension, 1).assignColumn(0, xr);
        Vector value = new QRDecomposition(xxt).solve(bMatrix).viewColumn(0);
        vertex.getValue().setVector(value);

        // update vertex bias
        if (biasOn) {
            double bias = computeBias(value, messages);
            vertex.getValue().setBias(bias);
        }

        // send out messages
        for (Edge<CFVertexId, EdgeData4CFWritable> edge : vertex.getEdges()) {
            MessageData4CFWritable newMessage = new MessageData4CFWritable(vertex.getValue(), edge.getValue());
            sendMessage(edge.getTargetVertexId(), newMessage);
        }
    }

    vertex.voteToHalt();
}