ubic.basecode.math.linearmodels.MeanVarianceEstimator.java Source code

Java tutorial

Introduction

Here is the source code for ubic.basecode.math.linearmodels.MeanVarianceEstimator.java

Source

/*
 * The baseCode project
 * 
 * Copyright (c) 2011 University of British Columbia
 * 
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 * 
 * http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 */
package ubic.basecode.math.linearmodels;

import static cern.jet.math.Functions.chain;
import static cern.jet.math.Functions.div;
import static cern.jet.math.Functions.log2;
import static cern.jet.math.Functions.minus;
import static cern.jet.math.Functions.mult;
import static cern.jet.math.Functions.plus;
import static cern.jet.math.Functions.sqrt;

import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.analysis.interpolation.LinearInterpolator;
import org.apache.commons.math3.analysis.interpolation.LoessInterpolator;
import org.apache.commons.math3.analysis.polynomials.PolynomialSplineFunction;
import org.apache.commons.math3.exception.OutOfRangeException;

import cern.colt.function.IntIntDoubleFunction;
import cern.colt.list.DoubleArrayList;
import cern.colt.list.IntArrayList;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.stat.Descriptive;
import ubic.basecode.dataStructure.matrix.DoubleMatrix;
import ubic.basecode.math.DescriptiveWithMissing;
import ubic.basecode.math.linalg.QRDecomposition;

/**
 * Estimate mean-variance relationship and use this to compute weights for least squares fitting. R's limma.voom()
 * Charity Law and Gordon Smyth. See Law et al.
 * {@link http://genomebiology.biomedcentral.com/articles/10.1186/gb-2014-15-2-r29}
 * <p>
 * Running voom() on data matrices with NaNs is not currently supported.
 * 
 * @author ptan
 */
public class MeanVarianceEstimator {

    /**
     * Default loess span (This is the default value used by limma-voom)
     */
    public static final double BANDWIDTH = 0.5;

    /**
     * Default number of loess robustness iterations; 0 is probably fine.
     */
    public static final int ROBUSTNESS_ITERS = 3;

    /**
     * Similar implementation of R's stats.approxfun(..., rule = 2) where values outside the interval ['min(x)',
     * 'max(x)'] gets the value at the closest data extreme. Also performs sorting based on xTrain.
     * 
     * @param x the training set of x values
     * @param y the training set of y values
     * @param xInterpolate the set of x values to interpolate
     * @return yInterpolate the interpolated set of y values
     */
    protected static double[] approx(double[] x, double[] y, double[] xInterpolate) {

        assert x != null;
        assert y != null;
        assert xInterpolate != null;
        assert x.length == y.length;

        double[] yInterpolate = new double[xInterpolate.length];
        LinearInterpolator linearInterpolator = new LinearInterpolator();

        // make sure that x is strictly increasing
        DoubleMatrix2D matrix = new DenseDoubleMatrix2D(x.length, 2);
        matrix.viewColumn(0).assign(x);
        matrix.viewColumn(1).assign(y);
        matrix = matrix.viewSorted(0);
        double[] sortedX = matrix.viewColumn(0).toArray();
        double[] sortedY = matrix.viewColumn(1).toArray();

        // make sure x is within the domain
        DoubleArrayList xList = new DoubleArrayList(sortedX);
        double x3ListMin = Descriptive.min(xList);
        double x3ListMax = Descriptive.max(xList);
        PolynomialSplineFunction fun = linearInterpolator.interpolate(sortedX, sortedY);
        for (int i = 0; i < xInterpolate.length; i++) {
            try {
                // approx(...,rule=2)
                if (xInterpolate[i] > x3ListMax) {
                    yInterpolate[i] = fun.value(x3ListMax);
                } else if (xInterpolate[i] < x3ListMin) {
                    yInterpolate[i] = fun.value(x3ListMin);
                } else {
                    yInterpolate[i] = fun.value(xInterpolate[i]);
                }
            } catch (OutOfRangeException e) {
                // this shouldn't happen anymore
                yInterpolate[i] = Double.NaN;
            }
        }

        return yInterpolate;
    }

    /**
     * Normalized variables on log2 scale
     */
    private DoubleMatrix2D E;

    /**
     * Size of each library (column).
     */
    private DoubleMatrix1D librarySize;

    /**
     * Loess fit (x, y)
     */
    private DoubleMatrix2D loess;

    /**
     * Matrix that contains the mean and variance of the data. Matrix is sorted by increasing mean. Useful for plotting.
     * mean <- fit$Amean + mean(log2(lib.size+1)) - log2(1e6) variance <- sqrt(fit$sigma)
     */
    private DoubleMatrix2D meanVariance;

    /**
     * inverse variance weights
     */
    private DoubleMatrix2D weights = null;

    /**
     * Preferred interface if you want control over how the design is set up. Executes voom() to calculate weights.
     * 
     * @param designMatrix
     * @param data a normalized count matrix
     * @param librarySize library size (matrix column sum)
     */
    public MeanVarianceEstimator(DesignMatrix designMatrix, DoubleMatrix<String, String> data,
            DoubleMatrix1D librarySize) {

        DoubleMatrix2D b = new DenseDoubleMatrix2D(data.asArray());
        this.librarySize = librarySize;
        this.E = b;

        mv();
        voom(designMatrix.getDoubleMatrix());
    }

    /**
     * Executes voom() to calculate weights.
     * 
     * @param designMatrix
     * @param data a normalized count matrix
     * @param librarySize library size (matrix column sum)
     */
    public MeanVarianceEstimator(DesignMatrix designMatrix, DoubleMatrix2D data, DoubleMatrix1D librarySize) {

        this.librarySize = librarySize;
        this.E = data;

        mv();
        voom(designMatrix.getDoubleMatrix());
    }

    /**
     * Generic method for calculating mean, variance and the loess fit. voom() is not executed and therefore no weights
     * are calculated.
     * 
     * @param data a normalized count matrix
     */
    public MeanVarianceEstimator(DoubleMatrix2D data) {
        this.E = data;

        mv();
    }

    /**
     * @return total library size
     */
    public DoubleMatrix1D getLibrarySize() {
        return this.librarySize;
    }

    /**
     * @return the loess fit of the mean-variance relationship
     */
    public DoubleMatrix2D getLoess() {
        return this.loess;
    }

    /**
     * @return the mean and variance of the normalized data, columns 0 and 1 respectively
     */
    public DoubleMatrix2D getMeanVariance() {
        return this.meanVariance;
    }

    /**
     * @return log2 counts per million. t(log2(t(counts+0.5)/(lib.size+1)*1e6))
     */
    public DoubleMatrix2D getNormalizedValue() {
        return this.E;
    }

    /**
     * @return inverse variance weights or null if the DesignMatrix was not provided.
     */
    public DoubleMatrix2D getWeights() {
        return this.weights;
    }

    /**
     * First ensures that x values are strictly increasing and performs a loess fit afterwards. The loess fit are
     * determined by <code>BANDWIDTH</code> and <code>ROBUSTNESS_ITERS</code>.
     * 
     * @param xy
     * @return loessFit or null if there are less than 3 data points
     */
    private DoubleMatrix2D loessFit(DoubleMatrix2D xy) {
        assert xy != null;

        DoubleMatrix1D sx = xy.viewColumn(0);
        DoubleMatrix1D sy = xy.viewColumn(1);
        Map<Double, Double> map = new TreeMap<>();
        for (int i = 0; i < sx.size(); i++) {
            if (Double.isNaN(sx.get(i)) || Double.isInfinite(sx.get(i)) || Double.isNaN(sy.get(i))
                    || Double.isInfinite(sy.get(i))) {
                continue;
            }
            map.put(sx.get(i), sy.get(i));
        }
        DoubleMatrix2D xyChecked = new DenseDoubleMatrix2D(map.size(), 2);
        xyChecked.viewColumn(0).assign(ArrayUtils.toPrimitive(map.keySet().toArray(new Double[0])));
        xyChecked.viewColumn(1).assign(ArrayUtils.toPrimitive(map.values().toArray(new Double[0])));

        // in R:
        // loess(c(1:5),c(1:5)^2,f=0.5,iter=3)
        // Note: we start to loose some precision here in comparison with R's loess FIXME why? does it matter?
        DoubleMatrix2D loessFit = new DenseDoubleMatrix2D(xyChecked.rows(), xyChecked.columns());
        // try {
        // fit a loess curve
        LoessInterpolator loessInterpolator = new LoessInterpolator(MeanVarianceEstimator.BANDWIDTH,
                MeanVarianceEstimator.ROBUSTNESS_ITERS);

        double[] loessY = loessInterpolator.smooth(xyChecked.viewColumn(0).toArray(),
                xyChecked.viewColumn(1).toArray());

        loessFit.viewColumn(0).assign(xyChecked.viewColumn(0));
        loessFit.viewColumn(1).assign(loessY);

        return loessFit;
    }

    /**
     * Performs row-wise mean (x) and variance (y) and performs a loess fit. Note this lowess fit is different than the
     * one we use with voom, which is fit to the quarter-root variances (as per Smythe; tends to be more symmetric).
     * Handles missing data.
     * <p>
     * FIXME I'm not sure the lowess fits are useful since we don't use them for analysis, and they are suboptimal being
     * fit to the variance rather than the quarter-root variance.
     */
    private void mv() {
        assert this.E != null;

        // mean-variance
        DoubleMatrix1D Amean = new DenseDoubleMatrix1D(E.rows());
        DoubleMatrix1D variance = Amean.like();
        for (int i = 0; i < Amean.size(); i++) {
            DoubleArrayList row = new DoubleArrayList(E.viewRow(i).toArray());
            double rowMean = DescriptiveWithMissing.mean(row);
            double rowVar = DescriptiveWithMissing.variance(row);
            Amean.set(i, rowMean);
            variance.set(i, rowVar);

        }

        this.meanVariance = new DenseDoubleMatrix2D(E.rows(), 2);
        this.meanVariance.viewColumn(0).assign(Amean);
        this.meanVariance.viewColumn(1).assign(variance);

        /*
         * fit a loess curve.
         */
        this.loess = loessFit(this.meanVariance);
    }

    /**
     * Performs the heavy duty work of calculating the weights. See Law et al.
     * {@link http://genomebiology.biomedcentral.com/articles/10.1186/gb-2014-15-2-r29}
     * 
     * @param designMatrix
     * @throws IllegalArgumentException if there are missing values.
     */
    private void voom(DoubleMatrix2D designMatrix) {
        assert designMatrix != null;
        assert this.meanVariance != null;
        assert this.E != null;
        assert this.librarySize != null;

        Algebra solver = new Algebra();

        DoubleMatrix2D A = designMatrix;
        weights = new DenseDoubleMatrix2D(E.rows(), E.columns());

        // perform a linear fit to obtain the mean-variance relationship
        // fit3<-lm(t(yCpm) ~ as.matrix(design.matrix[,2]))
        // or gFit <- lmFit(yCpm, design=design.matrix)
        LeastSquaresFit lsf = new LeastSquaresFit(A, E);

        // calculate fit$Amean by doing rowSums(CPM) (see limma.getEAWP())
        DoubleMatrix1D Amean = this.meanVariance.viewColumn(0);

        // sx <- fit$Amean + mean(log2(lib.size+1)) - log2(1e6)
        DoubleMatrix1D sx = Amean.copy();
        sx.assign(plus(librarySize.copy().assign(chain(log2, plus(1))).zSum() / librarySize.size()));
        sx.assign(minus(Math.log(Math.pow(10, 6)) / Math.log(2)));

        // help("MArrayLM-class")
        // fit$sigma <- sqrt(sum(out$residuals^2)/out$df.residual)
        // sy <- sqrt(fit$sigma)
        DoubleMatrix2D residuals = lsf.getResiduals();
        DoubleMatrix1D sy = new DenseDoubleMatrix1D(residuals.rows());
        for (int row = 0; row < residuals.rows(); row++) {
            double sum = 0;
            for (int column = 0; column < residuals.columns(); column++) {
                Double val = residuals.get(row, column);
                if (!Double.isNaN(val)) {
                    sum += val * val;
                }
            }
            sy.set(row, sum);
        }
        // if you have missing values in the expression matrix
        // you'll get a residual dof of 0
        if (lsf.isHasMissing()) {
            // calculate it per row
            List<Integer> dofs = lsf.getResidualDofs();
            assert dofs.size() == sy.size();
            for (int i = 0; i < sy.size(); i++) {
                sy.set(i, Math.sqrt(sy.get(i) / dofs.get(i)));
            }
        } else {
            int dof = lsf.getResidualDof();
            assert dof != 0;
            sy.assign(chain(sqrt, div(dof)));
        }
        sy.assign(sqrt); // we're fitting the quarter-root variances.

        // only accepts array in strictly increasing order (drop duplicates)
        // so combine sx and sy and sort
        DoubleMatrix2D voomXY = new DenseDoubleMatrix2D(sx.size(), 2);
        voomXY.viewColumn(0).assign(sx);
        voomXY.viewColumn(1).assign(sy);
        DoubleMatrix2D fit = loessFit(voomXY);

        // quarterroot fitted counts
        DoubleMatrix2D fittedValues = null;
        QRDecomposition qr = new QRDecomposition(A);
        DoubleMatrix2D coeff = lsf.getCoefficients();

        if (qr.getRank() < A.columns()) {
            // j <- fit$pivot[1:fit$rank]
            // fitted.values <- fit$coef[,j,drop=F] %*% t(fit$design[,j,drop=F]);
            IntArrayList pivot = qr.getPivotOrder();

            IntArrayList subindices = (IntArrayList) pivot.partFromTo(0, qr.getRank() - 1);
            int[] coeffAllCols = new int[coeff.columns()];
            int[] desAllRows = new int[A.rows()];
            for (int i = 0; i < coeffAllCols.length; i++) {
                coeffAllCols[i] = i;
            }
            for (int i = 0; i < desAllRows.length; i++) {
                desAllRows[i] = i;
            }
            DoubleMatrix2D coeffSlice = coeff.viewSelection(subindices.elements(), coeffAllCols);
            DoubleMatrix2D ASlice = A.viewSelection(desAllRows, subindices.elements());
            fittedValues = solver.mult(coeffSlice.viewDice(), ASlice.viewDice());
        } else {
            // fitted.values <- fit$coef %*% t(fit$design)
            fittedValues = solver.mult(coeff.viewDice(), A.viewDice());
        }

        // fitted.cpm <- 2^fitted.values
        // fitted.count <- 1e-6 * t(t(fitted.cpm)*(lib.size+1))
        // fitted.logcount <- log2(fitted.count)
        DoubleMatrix2D fittedCpm = fittedValues.copy().forEachNonZero(new IntIntDoubleFunction() {
            @Override
            public double apply(int row, int column, double third) {
                return Math.pow(2, third);
            }
        });
        DoubleMatrix2D fittedCount = fittedCpm.copy();
        DoubleMatrix1D libSizePlusOne = librarySize.assign(plus(1));
        for (int i = 0; i < fittedCount.rows(); i++) {
            fittedCount.viewRow(i).assign(libSizePlusOne, mult);
            fittedCount.viewRow(i).assign(mult(Math.pow(10, -6)));
        }
        DoubleMatrix2D fittedLogCount = fittedCount.copy().assign(log2);

        // interpolate points using the loess curve
        // f <- approxfun(l, rule=2)
        // apply trend to individual observations
        // w <- 1 / f(fitted.logcount)^4
        // 2D to 1D
        double[] xInterpolate = new double[fittedLogCount.rows() * fittedLogCount.columns()];
        int idx = 0;
        for (int col = 0; col < fittedLogCount.columns(); col++) {
            for (int row = 0; row < fittedLogCount.rows(); row++) {
                xInterpolate[idx] = fittedLogCount.get(row, col);
                idx++;
            }
        }
        assert fit != null;
        double[] yInterpolate = MeanVarianceEstimator.approx(fit.viewColumn(0).toArray(),
                fit.viewColumn(1).toArray(), xInterpolate);

        // 1D to 2D
        idx = 0;
        for (int col = 0; col < weights.columns(); col++) {
            for (int row = 0; row < weights.rows(); row++) {
                weights.set(row, col, (1.0 / Math.pow(yInterpolate[idx], 4)));
                idx++;
            }
        }

    }
}