es.csic.iiia.planes.util.InverseWishartDistribution.java Source code

Java tutorial

Introduction

Here is the source code for es.csic.iiia.planes.util.InverseWishartDistribution.java

Source

/*
 * Software License Agreement (BSD License)
 *
 * Copyright 2013 Marc Pujol <mpujol@iiia.csic.es>.
 *
 * Redistribution and use of this software in source and binary forms, with or
 * without modification, are permitted provided that the following conditions
 * are met:
 *
 *   Redistributions of source code must retain the above
 *   copyright notice, this list of conditions and the
 *   following disclaimer.
 *
 *   Redistributions in binary form must reproduce the above
 *   copyright notice, this list of conditions and the
 *   following disclaimer in the documentation and/or other
 *   materials provided with the distribution.
 *
 *   Neither the name of IIIA-CSIC, Artificial Intelligence Research Institute
 *   nor the names of its contributors may be used to
 *   endorse or promote products derived from this
 *   software without specific prior written permission of
 *   IIIA-CSIC, Artificial Intelligence Research Institute
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
package es.csic.iiia.planes.util;

import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;

/**
 * Inverse Wishart distribution implementation, to sample random covariance matrices for
 * multivariate gaussian distributions.
 * <p>
 * The sampling method follows the procedure described by Odell &amp; Feiveson, 1966 to get
 * samples from a Wishart distribution, and then computes the inverse of the obtained samples.
 * Each Wishart draw is {@code L B L^T / df}, where {@code L} is the lower Cholesky factor of the
 * scale matrix and {@code B} is a Wishart(I, df) matrix assembled from independent standard
 * normal and chi-square variates; {@link #sample()} returns the inverse of that draw.
 * <p>
 * No synchronization is performed, so concurrent use of a single instance requires external
 * locking.
 *
 * @author Marc Pujol <mpujol@iiia.csic.es>
 */
public class InverseWishartDistribution {
    private static final Logger LOG = Logger.getLogger(InverseWishartDistribution.class.getName());

    /** Number of Wishart draws {@link #sample()} attempts before giving up on singular draws. */
    private static final int MAX_SAMPLING_ATTEMPTS = 100;

    /** {@code gammas[i]} samples the diagonal variate V_i ~ chi-square(df - i). */
    private final GammaDistribution[] gammas;
    private final double df;
    private final RealMatrix scaleMatrix;
    private final CholeskyDecomposition cholesky;
    private final RandomGenerator random;

    /**
     * Builds a new Inverse Wishart distribution with the given scale and degrees of freedom.
     *
     * @param scaleMatrix scale matrix. It must be square; the Cholesky decomposition computed
     *        here additionally requires it to be symmetric positive definite.
     * @param df degrees of freedom. Must be strictly greater than {@code dim - 1}, where
     *        {@code dim} is the dimension of {@code scaleMatrix}.
     * @throws IllegalArgumentException if {@code scaleMatrix} is not square or {@code df} is not
     *         greater than {@code dim - 1}.
     */
    public InverseWishartDistribution(RealMatrix scaleMatrix, double df) {
        if (!scaleMatrix.isSquare()) {
            throw new IllegalArgumentException("scaleMatrix must be square.");
        }
        final int dim = scaleMatrix.getColumnDimension();
        // The smallest diagonal variate is chi-square(df - (dim - 1)); its degrees of freedom
        // must be strictly positive. Written as !(x > y) so that NaN is rejected as well.
        if (!(df > dim - 1)) {
            throw new IllegalArgumentException("df must be greater than " + (dim - 1)
                    + " for a " + dim + "x" + dim + " scale matrix, but was " + df);
        }

        this.scaleMatrix = scaleMatrix;
        this.df = df;
        this.random = new Well19937c();
        this.gammas = buildChiSquareGammas(dim, df);
        // Commons Math rejects non-symmetric or non-positive-definite matrices here.
        this.cholesky = new CholeskyDecomposition(scaleMatrix);
    }

    /**
     * Builds the generators for the diagonal variates V_i ~ chi-square(df - i), i = 0..dim-1
     * (0-based; this is the paper's chi-square(n - i + 1) with 1-based i).
     * A chi-square(k) variate is a Gamma(shape = k / 2, scale = 2) variate.
     *
     * @param dim dimension of the scale matrix.
     * @param df degrees of freedom, already validated to exceed {@code dim - 1}.
     * @return one gamma distribution per diagonal entry.
     */
    private static GammaDistribution[] buildChiSquareGammas(int dim, double df) {
        final GammaDistribution[] result = new GammaDistribution[dim];
        for (int i = 0; i < dim; i++) {
            // The whole degrees-of-freedom term must be halved to get the gamma shape.
            result[i] = new GammaDistribution((df - i) / 2, 2);
        }
        return result;
    }

    /**
     * Reseeds the random generators: the normal generator receives {@code seed} and the i-th
     * chi-square generator receives {@code seed + i}.
     *
     * @param seed new random seed.
     */
    public void reseedRandomGenerator(long seed) {
        random.setSeed(seed);
        for (int i = 0; i < gammas.length; i++) {
            gammas[i].reseedRandomGenerator(seed + i);
        }
    }

    /**
     * Returns a sample matrix from this distribution.
     * <p>
     * Draws a (df-normalized) Wishart matrix and returns its inverse. Draws that turn out to be
     * singular are discarded and retried, up to 100 attempts in total.
     *
     * @return sampled matrix.
     * @throws IllegalStateException if every attempt produced a singular matrix.
     */
    public RealMatrix sample() {
        for (int attempt = 0; attempt < MAX_SAMPLING_ATTEMPTS; attempt++) {
            try {
                RealMatrix wishart = sampleWishart();
                RealMatrix result = new LUDecomposition(wishart).getSolver().getInverse();
                LOG.log(Level.FINE, "Cov = {0}", result);
                return result;
            } catch (SingularMatrixException ex) {
                LOG.finer("Discarding singular matrix generated by the wishart distribution.");
            }
        }
        throw new IllegalStateException("Unable to generate inverse wishart samples!");
    }

    /**
     * Draws one Wishart(scaleMatrix, df) matrix with the Odell &amp; Feiveson construction and
     * divides it by {@code df}.
     * <p>
     * Formulas in the comments below use the paper's 1-based indices; the arrays are 0-based.
     *
     * @return {@code L B L^T / df}, where {@code L} is the lower Cholesky factor of the scale
     *         matrix and {@code B ~ Wishart(I, df)}.
     */
    private RealMatrix sampleWishart() {
        final int dim = scaleMatrix.getColumnDimension();

        // N_{ij}, i < j: independent standard normals in the strict upper triangle.
        double[][] N = new double[dim][dim];
        for (int j = 0; j < dim; j++) {
            for (int i = 0; i < j; i++) {
                N[i][j] = random.nextGaussian();
            }
        }
        if (LOG.isLoggable(Level.FINEST)) {
            LOG.log(Level.FINEST, "N = {0}", Arrays.deepToString(N));
        }

        // V_j: independent chi-square variates for the diagonal.
        double[] V = new double[dim];
        for (int i = 0; i < dim; i++) {
            V[i] = gammas[i].sample();
        }
        if (LOG.isLoggable(Level.FINEST)) {
            LOG.log(Level.FINEST, "V = {0}", Arrays.toString(V));
        }

        // Build the symmetric matrix B ~ Wishart(I, df).
        double[][] B = new double[dim][dim];

        // b_{11} = V_1 (first j, where sum = 0 because i == j and the inner
        //               loop is never entered).
        // b_{jj} = V_j + \sum_{i=1}^{j-1} N_{ij}^2, j = 2, 3, ..., p
        for (int j = 0; j < dim; j++) {
            double sum = 0;
            for (int i = 0; i < j; i++) {
                sum += N[i][j] * N[i][j];
            }
            B[j][j] = V[j] + sum;
        }
        if (LOG.isLoggable(Level.FINEST)) {
            LOG.log(Level.FINEST, "B*_jj : = {0}", Arrays.deepToString(B));
        }

        // b_{1j} = N_{1j} * \sqrt V_1, mirrored into the lower triangle.
        final double sqrtV0 = dim > 0 ? Math.sqrt(V[0]) : 0;
        for (int j = 1; j < dim; j++) {
            B[0][j] = N[0][j] * sqrtV0;
            B[j][0] = B[0][j];
        }
        if (LOG.isLoggable(Level.FINEST)) {
            LOG.log(Level.FINEST, "B*_1j = {0}", Arrays.deepToString(B));
        }

        // b_{ij} = N_{ij} * \sqrt V_i + \sum_{k=1}^{i-1} N_{ki} * N_{kj}, 1 < i < j,
        // mirrored into the lower triangle.
        for (int j = 1; j < dim; j++) {
            for (int i = 1; i < j; i++) {
                double sum = 0;
                for (int k = 0; k < i; k++) {
                    sum += N[k][i] * N[k][j];
                }
                B[i][j] = N[i][j] * Math.sqrt(V[i]) + sum;
                B[j][i] = B[i][j];
            }
        }
        if (LOG.isLoggable(Level.FINEST)) {
            LOG.log(Level.FINEST, "B* = {0}", Arrays.deepToString(B));
        }

        // A = L B L^T ~ Wishart(scaleMatrix, df).
        RealMatrix BMat = new Array2DRowRealMatrix(B);
        RealMatrix A = cholesky.getL().multiply(BMat).multiply(cholesky.getLT());
        if (LOG.isLoggable(Level.FINER)) {
            LOG.log(Level.FINER, "A* = {0}", Arrays.deepToString(A.getData()));
        }
        // E[A] = df * scaleMatrix, so dividing by df centers the draw on scaleMatrix.
        return A.scalarMultiply(1 / df);
    }

}