Example usage for org.apache.mahout.math Matrix assign

List of usage examples for org.apache.mahout.math Matrix assign

Introduction

On this page you can find example usages of the assign method of org.apache.mahout.math.Matrix.

Prototype

Matrix assign(DoubleFunction function);

Source Link

Document

Apply the function to each element of the receiver

Usage

From source file:org.qcri.pca.PCACommon.java

/**
 * Builds a dense matrix whose entries are all produced by {@code random.nextDouble()}.
 *
 * @param rows number of rows of the new matrix
 * @param cols number of columns of the new matrix
 * @return the newly created, randomly filled matrix
 */
static Matrix randomMatrix(int rows, int cols) {
    // The incoming cell value is ignored; every cell is overwritten with a fresh draw.
    final DoubleFunction randomFill = new DoubleFunction() {
        @Override
        public double apply(double ignoredCellValue) {
            return random.nextDouble();
        }
    };
    final Matrix result = new DenseMatrix(rows, cols);
    result.assign(randomFill);
    return result;
}

From source file:org.qcri.pca.SPCADriver.java

/**
 * Runs the iterative sPCA computation on the distributed input matrix Y.
 *
 * <p>The method computes the mean (and span) of the input, optionally normalizes
 * it, computes its norm, and then repeatedly refines the matrix {@code C} and the
 * scalar {@code ss}. Iteration stops when {@code LAST_ROUND} rounds have run, when
 * the relative change of {@code ss} is no longer above the internal threshold
 * (1e-5), or when the most recently computed reconstruction error is no longer
 * above 0.02. On exit the final {@code C} and {@code ss} are stored back into
 * {@code initVal}, and {@code C} is handed to {@code writeMatrix} with the label
 * {@code "PCs"}.
 *
 * @param conf
 *          the configuration passed to every job; when normalization is enabled
 *          and {@code splitFactor > 1}, {@code mapred.max.split.size} is set on it
 * @param distY
 *          the input matrix Y
 * @param initVal
 *          supplies the initial {@code ss} and {@code C}; updated in place with
 *          the final values
 * @param output
 *          the path handed to the mean/span job, the normalization job and
 *          {@code writeMatrix}
 * @param nRows
 *          number of rows in the input matrix
 * @param nCols
 *          number of columns in the input matrix
 * @param nPCs
 *          number of desired principal components (length of the vector Xm)
 * @param splitFactor
 *          if greater than 1 (and normalization is enabled), the default block
 *          size is divided by this number to obtain the maximum split size
 * @param errSampleRate
 *          the sample rate handed to the reconstruction-error job
 * @param LAST_ROUND
 *          upper bound (exclusive) on the round index
 * @param normalize
 *          1 to normalize the input matrix before iterating; any other value
 *          skips normalization
 * @return the last reconstruction error that was computed
 * @throws Exception
 *           propagated from the jobs and helpers invoked by this method
 */
double runMapReduce(Configuration conf, DistributedRowMatrix distY, InitialValues initVal, Path output,
        final int nRows, final int nCols, final int nPCs, final int splitFactor, final float errSampleRate,
        final int LAST_ROUND, final int normalize) throws Exception {
    int round = 0;
    //The two PPCA variables that improve over each iteration
    double ss = initVal.ss;
    Matrix centralC = initVal.C;
    //initial CtC
    Matrix centralCtC = centralC.transpose().times(centralC);
    final float threshold = 0.00001f;
    // NOTE: sampleRate is fixed at 1 in this method, so the "sampleRate < 1"
    // rescale branches below are currently never taken.
    int sampleRate = 1;
    //1. compute mean and span
    DenseVector ym = new DenseVector(distY.numCols()); //ym=mean(distY)
    MeanAndSpanJob masJob = new MeanAndSpanJob();
    final boolean normalizeMean = (normalize == 1);
    Path meanSpanPath = masJob.compuateMeanAndSpan(distY.getRowPath(), output, ym, normalizeMean, conf,
            "" + round + "-init");
    Path normalizedYPath = null;

    //2. normalize the input matrix Y
    if (normalize == 1) {

        NormalizeJob normalizeJob = new NormalizeJob();
        normalizedYPath = normalizeJob.normalize(conf, distY.getRowPath(), meanSpanPath, output, sampleRate,
                "" + round + "-init");
        distY = new DistributedRowMatrix(normalizedYPath, getTempPath(), nRows, nCols);
        distY.setConf(conf);
        //After normalization, set the split factor
        if (splitFactor > 1) {
            FileSystem fss = FileSystem.get(normalizedYPath.toUri(), conf);
            long blockSize = fss.getDefaultBlockSize() / splitFactor;
            conf.set("mapred.max.split.size", Long.toString(blockSize));
        }
    }
    // Without normalization, keep working on the original rows of Y
    if (normalizedYPath == null)
        normalizedYPath = distY.getRowPath();

    //3. compute the 2-norm of Y
    Norm2Job normJob = new Norm2Job();
    double norm2 = normJob.computeFNorm(conf, normalizedYPath, meanSpanPath, getTempPath(),
            "" + round + "-init");
    if (sampleRate < 1) { // rescale
        norm2 = norm2 / sampleRate;
    }

    DenseVector xm = new DenseVector(nPCs);
    log.info("SSSSSSSSSSSSSSSSSSSSSSSSSSSS " + ss);
    DistributedRowMatrix distY2X = null;
    DistributedRowMatrix distC = null;
    double prevObjective = Double.MAX_VALUE;
    double error = 0;
    double relChangeInObjective = Double.MAX_VALUE;
    double prevError = Double.MAX_VALUE;
    for (; (round < LAST_ROUND && relChangeInObjective > threshold && prevError > 0.02); round++) {
        // Sx = inv( ss * eye(d) + CtC );
        Matrix centralSx = centralCtC.clone();
        centralSx.viewDiagonal().assign(Functions.plus(ss));
        centralSx = inv(centralSx);
        // X = Y * C * Sx' => Y2X = C * Sx'
        Matrix centralY2X = centralC.times(centralSx.transpose());
        distY2X = PCACommon.toDistributedRowMatrix(centralY2X, getTempPath(), getTempPath(), "CSxt" + round);
        // Xm = Ym * Y2X
        PCACommon.denseVectorTimesMatrix(ym, centralY2X, xm);
        // We skip computing X as we generate it on demand using Y and Y2X

        //Compute X'X and Y'X 
        CompositeJob compositeJob = new CompositeJob();
        compositeJob.computeYtXandXtX(distY, distY2X, ym, xm, getTempPath(), conf, "" + round);
        Matrix centralXtX = compositeJob.xtx;
        Matrix centralYtX = compositeJob.ytx;
        if (sampleRate < 1) { // rescale
            centralXtX.assign(Functions.div(sampleRate));
            centralYtX.assign(Functions.div(sampleRate));
        }

        // XtX = X'*X + ss * Sx
        final double finalss = ss;
        centralXtX.assign(centralSx, new DoubleDoubleFunction() {
            @Override
            public double apply(double arg1, double arg2) {
                return arg1 + finalss * arg2;
            }
        });
        // C = (Ye'*X) / SumXtX;
        Matrix invXtX_central = inv(centralXtX);
        centralC = centralYtX.times(invXtX_central);
        distC = PCACommon.toDistributedRowMatrix(centralC, getTempPath(), getTempPath(), "C" + round);
        centralCtC = centralC.transpose().times(centralC);

        // Compute new value for ss
        // ss = ( sum(sum(Ye.^2)) + PCACommon.trace(XtX*CtC) - 2sum(XiCtYit) )
        // /(N*D);
        double ss2 = PCACommon.trace(centralXtX.times(centralCtC));
        VarianceJob varianceJob = new VarianceJob();
        double xctyt = varianceJob.computeVariance(distY, ym, distY2X, xm, distC, getTempPath(), conf,
                "" + round);
        if (sampleRate < 1) { // rescale
            xctyt = xctyt / sampleRate;
        }
        // Multiply in double precision: nRows * nCols as an int product can
        // overflow for large matrices and silently corrupt ss.
        ss = (norm2 + ss2 - 2 * xctyt) / ((double) nRows * nCols);
        log.info("SSSSSSSSSSSSSSSSSSSSSSSSSSSS " + ss + " (" + norm2 + " + " + ss2 + " -2* " + xctyt);
        double traceSx = PCACommon.trace(centralSx);
        double traceXtX = PCACommon.trace(centralXtX);
        double traceC = PCACommon.trace(centralC);
        double traceCtC = PCACommon.trace(centralCtC);
        log.info("TTTTTTTTTTTTTTTTT " + traceSx + " " + traceXtX + " " + traceC + " " + traceCtC);

        double objective = ss;
        relChangeInObjective = Math.abs(1 - objective / prevObjective);
        prevObjective = objective;
        // The pattern uses printf-style specifiers, which {}-style logging APIs
        // do not interpret; format it explicitly so the values are rendered.
        log.info(String.format("Objective:  %.6f    relative change: %.6f \n", objective,
                relChangeInObjective));
        if (!CALCULATE_ERR_ATTHEEND) {
            log.info("Computing the error at round " + round + " ...");
            ReconstructionErrJob errJob = new ReconstructionErrJob();
            error = errJob.reconstructionErr(distY, distY2X, distC, centralC, ym, xm, errSampleRate, conf,
                    getTempPath(), "" + round);
            log.info("... end of computing the error at round " + round);
            // Feeds the loop's "prevError > 0.02" stopping test
            prevError = error;
        }
    }

    if (CALCULATE_ERR_ATTHEEND) {
        log.info("Computing the error at round " + round + " ...");
        ReconstructionErrJob errJob = new ReconstructionErrJob();
        error = errJob.reconstructionErr(distY, distY2X, distC, centralC, ym, xm, errSampleRate, conf,
                getTempPath(), "" + round);
        log.info("... end of computing the error at round " + round);
    }

    // Hand the refined model back to the caller and persist C
    initVal.C = centralC;
    initVal.ss = ss;
    writeMatrix(initVal.C, output, getTempPath(), "PCs");
    return error;

}

From source file:org.qcri.pca.SPCADriver.java

/**
 * Creates an n-by-n dense matrix with ones on the main diagonal and zeros elsewhere.
 *
 * @param n number of rows and columns
 * @return the identity matrix of size {@code n}
 */
private static Matrix eye(int n) {
    final Matrix identity = new DenseMatrix(n, n);
    // Zero every entry first, then overwrite only the diagonal view.
    identity.assign(0.0);
    identity.viewDiagonal().assign(1.0);
    return identity;
}

From source file:org.qcri.sparkpca.PCAUtils.java

/**
 * Creates a dense {@code rows} x {@code cols} matrix and fills each element with
 * a value obtained from {@code random.nextDouble()}.
 *
 * @param rows the row count of the matrix to create
 * @param cols the column count of the matrix to create
 * @return the randomly initialized matrix
 */
static Matrix randomMatrix(int rows, int cols) {
    final Matrix randomized = new DenseMatrix(rows, cols);
    // The element's previous value is irrelevant; it is replaced by a new random draw.
    final DoubleFunction drawNext = new DoubleFunction() {
        @Override
        public double apply(double previous) {
            return random.nextDouble();
        }
    };
    randomized.assign(drawNext);
    return randomized;
}

From source file:org.qcri.sparkpca.PCAUtils.java

/**
* Initialize an identity matrix I/*from ww w  . j a v a  2  s  .  c  o  m*/
*/
private static Matrix eye(int n) {
    Matrix m = new DenseMatrix(n, n);
    m.assign(0);
    m.viewDiagonal().assign(1);
    return m;
}

From source file:org.trustedanalytics.atk.giraph.algorithms.als.AlternatingLeastSquaresComputation.java

License:Apache License

/**
 * Processes one superstep for a vertex.
 *
 * <p>At superstep 0 the vertex is only initialized. At later supersteps the
 * method (1) every {@code 2 * learningCurveOutputInterval} supersteps aggregates
 * the training cost and the validation/test squared errors computed from the
 * incoming messages, and (2) while the superstep is below {@code maxSupersteps},
 * re-solves the vertex vector from the TRAIN messages, optionally updates the
 * bias, and sends the new vertex value along every edge. The vertex votes to
 * halt at the end of every superstep.
 *
 * @param vertex the vertex being computed
 * @param messages the messages received by this vertex
 * @throws IOException declared by the overridden method
 */
@Override
public void compute(Vertex<CFVertexId, VertexData4CFWritable, EdgeData4CFWritable> vertex,
        Iterable<MessageData4CFWritable> messages) throws IOException {
    final long superstep = getSuperstep();
    if (superstep == 0) {
        initialize(vertex);
        vertex.voteToHalt();
        return;
    }

    final Vector ownVector = vertex.getValue().getVector();
    final double ownBias = vertex.getValue().getBias();

    // update aggregators every (2 * interval) super steps
    if ((superstep % (2 * learningCurveOutputInterval)) == 0) {
        double trainSquaredError = 0.0;
        double validateSquaredError = 0.0;
        double testSquaredError = 0.0;
        int trainCount = 0;
        for (MessageData4CFWritable incoming : messages) {
            final EdgeType edgeType = incoming.getType();
            final double rating = incoming.getWeight();
            final Vector neighborVector = incoming.getVector();
            final double neighborBias = incoming.getBias();
            final double predicted = ownBias + neighborBias + ownVector.dot(neighborVector);
            final double residual = rating - predicted;
            final double squaredResidual = residual * residual;
            switch (edgeType) {
            case TRAIN:
                trainSquaredError += squaredResidual;
                trainCount++;
                break;
            case VALIDATE:
                validateSquaredError += squaredResidual;
                break;
            case TEST:
                testSquaredError += squaredResidual;
                break;
            default:
                throw new IllegalArgumentException("Unknown recognized edge type: " + edgeType.toString());
            }
        }
        // The training cost stays 0 when no TRAIN message was received.
        double trainCost = 0.0;
        if (trainCount > 0) {
            final double regularization = lambda * (ownBias * ownBias + ownVector.dot(ownVector));
            trainCost = trainSquaredError / trainCount + regularization;
        }
        aggregate(SUM_TRAIN_COST, new DoubleWritable(trainCost));
        aggregate(SUM_VALIDATE_ERROR, new DoubleWritable(validateSquaredError));
        aggregate(SUM_TEST_ERROR, new DoubleWritable(testSquaredError));
    }

    // Past the superstep limit there is nothing left to update or send.
    if (superstep >= maxSupersteps) {
        vertex.voteToHalt();
        return;
    }

    // update vertex value
    // gram accumulates x times x transpose over the TRAIN messages
    Matrix gram = new DenseMatrix(featureDimension, featureDimension).assign(0.0);
    // weighted accumulates x times the bias-adjusted rating
    Vector weighted = ownVector.clone().assign(0.0);
    int trainEdges = 0;
    for (MessageData4CFWritable incoming : messages) {
        if (incoming.getType() != EdgeType.TRAIN) {
            continue;
        }
        final double rating = incoming.getWeight();
        final Vector neighborVector = incoming.getVector();
        final double neighborBias = incoming.getBias();
        gram = gram.plus(neighborVector.cross(neighborVector));
        weighted = weighted.plus(neighborVector.times(rating - ownBias - neighborBias));
        trainEdges++;
    }
    // add lambda * (number of TRAIN messages) on the diagonal before solving
    gram = gram.plus(new DiagonalMatrix(lambda * trainEdges, featureDimension));
    final Matrix rightHandSide = new DenseMatrix(featureDimension, 1).assignColumn(0, weighted);
    final Vector solved = new QRDecomposition(gram).solve(rightHandSide).viewColumn(0);
    vertex.getValue().setVector(solved);

    // update vertex bias
    if (biasOn) {
        vertex.getValue().setBias(computeBias(solved, messages));
    }

    // send out messages
    for (Edge<CFVertexId, EdgeData4CFWritable> outEdge : vertex.getEdges()) {
        final MessageData4CFWritable outgoing =
                new MessageData4CFWritable(vertex.getValue(), outEdge.getValue());
        sendMessage(outEdge.getTargetVertexId(), outgoing);
    }

    vertex.voteToHalt();
}