vagueobjects.ir.lda.online.matrix.MatrixUtil.java Source code

Java tutorial

Introduction

Here is the source code for vagueobjects.ir.lda.online.matrix.MatrixUtil.java

Source

package vagueobjects.ir.lda.online.matrix;
/*
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
 */

import org.apache.commons.math.MathException;
import org.apache.commons.math.distribution.GammaDistributionImpl;
import org.apache.commons.math.special.Gamma;

import java.util.Arrays;

public class MatrixUtil {
    static GammaSampling gd = new GammaSampling();

    public static double sum(Vector v) {
        double s = 0d;
        for (double d : v.data) {
            s += d;
        }
        return s;
    }

    public static Matrix gammaLn(Matrix m) {

        int nc = m.data.length;
        int nr = m.data[0].length;
        double[][] result = new double[nc][];
        for (int k = 0; k < nc; ++k) {
            result[k] = new double[nr];
            for (int w = 0; w < nr; ++w) {
                result[k][w] = Gamma.logGamma(m.data[k][w]);
            }
        }
        return new Matrix(result);
    }

    public static Vector gammaLn(Vector v) {
        int nc = v.data.length;
        double[] result = new double[nc];
        for (int k = 0; k < nc; ++k) {
            result[k] = gammaLn(v.data[k]);
        }
        return new Vector(result);
    }

    public static double gammaLn(double d) {
        return Gamma.logGamma(d);
    }

    public static double min(Matrix m) {
        double s = Double.MAX_VALUE;
        for (double[] row : m.data) {
            for (double d : row) {
                if (d < s) {
                    s = d;
                }
            }
        }
        return s;
    }

    public static double sum(Matrix m) {
        double s = 0d;
        for (double[] row : m.data) {
            for (double d : row) {
                s += d;
            }
        }
        return s;
    }

    public static Vector exp(Vector vector) {
        return new Vector(exp(vector.data));
    }

    public static Matrix exp(Matrix matrix) {
        return new Matrix(exp(matrix.data));
    }

    private static double[][] exp(double[][] array) {
        int nc = array.length;
        int nr = array[0].length;
        double[][] result = new double[nc][];
        for (int k = 0; k < nc; ++k) {
            result[k] = new double[nr];
            for (int w = 0; w < nr; ++w) {
                result[k][w] = Math.exp(array[k][w]);
            }
        }
        return result;
    }

    public static Matrix dirichletExpectation(Matrix matrix) {
        return new Matrix(dirichletExpectation(matrix.data));
    }

    public static Vector dirichletExpectation(Vector vector) {
        return new Vector(dirichletExpectation(vector.data));
    }

    public static Matrix sampleGamma(int W, int K) {

        double[][] result = gd.batch(W, K);
        return new Matrix(result);
    }

    static double[] dirichletExpectation(double[] array) {
        double sum = 0;
        for (double d : array) {
            sum += d;
        }
        double d = Gamma.digamma(sum);
        double[] result = new double[array.length];
        for (int i = 0; i < array.length; ++i) {
            result[i] = Gamma.digamma(array[i]) - d;
        }
        return result;
    }

    /**
     * Digamma function (the first derivative of the logarithm of the gamma function).
     * @param array   - variational parameter
     * @return
     */
    static double[][] dirichletExpectation(double[][] array) {
        int numRows = array.length;
        int numCols = array[0].length;

        double[] vector = new double[numRows];
        Arrays.fill(vector, 0.0);

        for (int k = 0; k < numRows; ++k) {
            for (int w = 0; w < numCols; ++w) {
                try {
                    vector[k] += array[k][w];
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        }
        for (int k = 0; k < numRows; ++k) {
            vector[k] = Gamma.digamma(vector[k]);
        }

        double[][] approx = new double[numRows][];
        for (int k = 0; k < numRows; ++k) {
            approx[k] = new double[numCols];
            for (int w = 0; w < numCols; ++w) {
                double z = Gamma.digamma(array[k][w]);
                approx[k][w] = z - vector[k];
            }
        }
        return approx;
    }

    private static double[] exp(double[] array) {
        int nc = array.length;
        double[] result = new double[nc];
        for (int k = 0; k < nc; ++k) {
            result[k] = Math.exp(array[k]);
        }
        return result;
    }

    static class GammaSampling {
        private final GammaDistributionImpl gd;
        public final static long DEFAULT_SEED = 1000000001L;
        public final static double DEFAULT_ALPHA = 100d;
        public final static double DEFAULT_BETA = 0.01d;

        public GammaSampling() {
            this.gd = new GammaDistributionImpl(DEFAULT_ALPHA, DEFAULT_BETA);
            this.gd.reseedRandomGenerator(DEFAULT_SEED);
        }

        double sample() {
            try {
                return gd.sample();
            } catch (MathException e) {
                throw new RuntimeException(e);
            }
        }

        public double[][] batch(int W, int K) {
            double[][] result = new double[K][];

            for (int k = 0; k < K; ++k) {
                result[k] = new double[W];
                for (int w = 0; w < W; ++w) {
                    result[k][w] = sample();
                }
            }
            return result;
        }
    }
}