dr.inference.model.IndianBuffetProcessPrior.java Source code

Java tutorial

Introduction

Here is the source code for dr.inference.model.IndianBuffetProcessPrior.java

Source

/*
 * IndianBuffetProcessPrior.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 *  BEAST is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301  USA
 */

package dr.inference.model;

import dr.math.Poisson;
import dr.math.distributions.PoissonDistribution;
import org.apache.commons.math.special.Beta;

/**
 * Two-parameter Indian Buffet Process (IBP) prior evaluated on the column pattern of an
 * adaptable-size matrix parameter.
 *
 * <p>With N the row dimension of {@code data}, m_k the number of non-zero entries in column k,
 * K+ the number of columns with at least one non-zero entry, and K_h the number of non-zero
 * columns sharing the same (absolute-valued) pattern h, the log-density computed is
 * <pre>
 *   K+ log(alpha * beta) - sum_h log(K_h!) - alpha * sum_{i=0}^{N-1} beta / (beta + i)
 *       + sum_{k : m_k &gt; 0} log B(m_k, N - m_k + beta)
 * </pre>
 *
 * @author Max Tolkoff
 */
public class IndianBuffetProcessPrior extends AbstractModelLikelihood implements MatrixSizePrior {

    /**
     * Creates the prior and registers {@code alpha}, {@code beta} and {@code data} as variables of
     * this model.
     *
     * @param alpha first IBP parameter; a {@code DefaultBounds(+Infinity, 0, 1)} bound is attached
     * @param beta  second IBP parameter; a {@code DefaultBounds(+Infinity, 0, 1)} bound is attached
     * @param data  matrix whose columns are scored; its row dimension plays the role of N
     */
    public IndianBuffetProcessPrior(Parameter alpha, Parameter beta, AdaptableSizeFastMatrixParameter data) {
        super(null);
        this.alpha = alpha;
        alpha.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0, 1));
        addVariable(alpha);
        this.beta = beta;
        beta.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0, 1));
        addVariable(beta);
        this.data = data;
        addVariable(data);
        // Allocate and fill the per-column statistics for every column up front so that
        // getRowCount() is meaningful before the first likelihood evaluation.
        updateColumnStatistics();
    }

    /**
     * Recomputes, from the current contents of {@code data}, the number of non-zero entries in each
     * column ({@code rowCount}) and whether each column has any non-zero entry
     * ({@code containsNonZeroElements}). Fresh arrays are allocated on every call because the column
     * dimension of {@code data} may change between calls.
     */
    private void updateColumnStatistics() {
        final int rows = data.getRowDimension();
        final int cols = data.getColumnDimension();
        int[] counts = new int[cols];
        boolean[] nonZero = new boolean[cols];
        for (int col = 0; col < cols; col++) {
            for (int row = 0; row < rows; row++) {
                if (data.getParameterValue(row, col) != 0) {
                    counts[col]++;
                }
            }
            nonZero[col] = counts[col] > 0;
        }
        rowCount = counts;
        containsNonZeroElements = nonZero;
        ncols = cols;
    }

    /**
     * Returns true when columns {@code a} and {@code b} of {@code data} agree entry-wise in
     * absolute value.
     */
    private boolean columnsMatch(int a, int b) {
        for (int row = 0; row < data.getRowDimension(); row++) {
            if (Math.abs(data.getParameterValue(row, a)) != Math.abs(data.getParameterValue(row, b))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns log(num!) as a sum of logarithms, so large multiplicities cannot overflow.
     *
     * @throws RuntimeException if {@code num} is negative
     */
    private static double logFactorial(int num) {
        if (num < 0) {
            throw new RuntimeException("Cannot take a negative factorial");
        }
        double result = 0.0;
        for (int i = 2; i <= num; i++) {
            result += Math.log(i);
        }
        return result;
    }

    /**
     * Returns sum_h log(K_h!), where K_h is the number of non-zero columns of {@code data} that share
     * the same absolute-valued pattern. All-zero columns are ignored. Requires
     * {@code containsNonZeroElements} to be current.
     */
    private double computeLogBottom() {
        final int cols = containsNonZeroElements.length;
        boolean[] grouped = new boolean[cols];
        double result = 0.0;
        for (int i = 0; i < cols; i++) {
            if (grouped[i] || !containsNonZeroElements[i]) {
                continue;
            }
            int multiplicity = 1;
            for (int j = i + 1; j < cols; j++) {
                // Columns already assigned to an earlier group must not be counted again.
                if (!grouped[j] && containsNonZeroElements[j] && columnsMatch(i, j)) {
                    grouped[j] = true;
                    multiplicity++;
                }
            }
            result += logFactorial(multiplicity);
        }
        return result;
    }

    /**
     * Returns H = sum_{i=0}^{N-1} beta / (beta + i), with N the row dimension of {@code data}.
     * The value is cached and recomputed only when beta or data has been flagged as changed.
     */
    private double H() {
        if (!betaKnown || !dataKnown) {
            final double betaValue = beta.getParameterValue(0);
            H = 0;
            for (int i = 0; i < data.getRowDimension(); i++) {
                H += betaValue / (betaValue + i);
            }
        }
        return H;
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {

    }

    /**
     * Marks the cached quantities that depend on the changed variable as stale. All statistics are
     * recomputed lazily in {@link #calculateLogLikelihood()}. Alpha enters only the final
     * combination, so for alpha it is enough to invalidate the cached likelihood.
     */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        if (variable == data) {
            dataKnown = false;
        } else if (variable == beta) {
            betaKnown = false;
        }
        likelihoodKnown = false;
    }

    /**
     * Snapshots all cached state. Arrays are copied so that later in-place changes (for example
     * through the array returned by {@link #getRowCount()}) cannot alter the stored snapshot.
     */
    @Override
    protected void storeState() {
        storedBetaKnown = betaKnown;
        storedContainsNonZeroElements = containsNonZeroElements == null ? null : containsNonZeroElements.clone();
        storedDataKnown = dataKnown;
        storedLikelihoodKnown = likelihoodKnown;
        storedLogLikelihood = logLikelihood;
        storedRowCount = rowCount == null ? null : rowCount.clone();
        storedKPlus = KPlus;
        storedH = H;
        storedLogBottom = logBottom;
        storedSum2 = sum2;
        storedncols = ncols;
    }

    /** Restores the cached state captured by the last {@link #storeState()}. */
    @Override
    protected void restoreState() {
        betaKnown = storedBetaKnown;
        containsNonZeroElements = storedContainsNonZeroElements;
        dataKnown = storedDataKnown;
        likelihoodKnown = storedLikelihoodKnown;
        logLikelihood = storedLogLikelihood;
        rowCount = storedRowCount;
        KPlus = storedKPlus;
        H = storedH;
        logBottom = storedLogBottom;
        sum2 = storedSum2;
        ncols = storedncols;
    }

    @Override
    protected void acceptState() {

    }

    @Override
    public Model getModel() {
        return this;
    }

    /** Returns the IBP log-density, recomputing it only when an input has changed. */
    @Override
    public double getLogLikelihood() {
        if (!likelihoodKnown) {
            logLikelihood = calculateLogLikelihood();
            likelihoodKnown = true;
        }
        return logLikelihood;
    }

    /**
     * Computes the IBP log-density described in the class documentation, refreshing the column
     * statistics, log(prod K_h!), K+ and the Beta-function sum as needed.
     */
    private double calculateLogLikelihood() {
        if (!dataKnown) {
            updateColumnStatistics();
            logBottom = computeLogBottom();
        }

        if (!dataKnown || !betaKnown) {
            final int rows = data.getRowDimension();
            final double betaValue = beta.getParameterValue(0);
            sum2 = 0;
            KPlus = 0;
            for (int i = 0; i < rowCount.length; i++) {
                if (containsNonZeroElements[i]) {
                    KPlus++;
                    sum2 += Beta.logBeta(rowCount[i], rows + betaValue - rowCount[i]);
                }
            }
        }

        final double alphaValue = alpha.getParameterValue(0);
        // (alpha*beta)^{K+} / prod_h K_h!  ->  K+ log(alpha*beta) - log(prod_h K_h!).
        // With K+ == 0 the power term is exactly 1; skipping it avoids 0 * log(0) = NaN.
        double p1 = -logBottom;
        if (KPlus > 0) {
            p1 += KPlus * Math.log(alphaValue * beta.getParameterValue(0));
        }
        final double p2 = -alphaValue * H();
        final double p3 = sum2;
        betaKnown = true;
        dataKnown = true;
        return p1 + p2 + p3;
    }

    /**
     * Returns the log-probability of K+ under a zero-truncated Poisson with mean alpha * H.
     */
    @Override
    public double getSizeLogLikelihood() {
        // Brings KPlus and H up to date, reusing the cached likelihood when nothing changed.
        getLogLikelihood();
        PoissonDistribution poisson = new PoissonDistribution(alpha.getParameterValue(0) * H());
        // log(1 - e^{-mean}) written via expm1 to keep precision for small means.
        return poisson.logPdf(KPlus) - Math.log(-Math.expm1(-poisson.mean()));
    }

    /**
     * Returns the number of non-zero entries in each column of {@code data}, refreshed first if the
     * data has changed. The internal array is returned, not a copy.
     */
    public int[] getRowCount() {
        if (!dataKnown) {
            updateColumnStatistics();
        }
        return rowCount;
    }

    public AdaptableSizeFastMatrixParameter getData() {
        return data;
    }

    /** Forces every cached quantity, including the cached likelihood, to be recomputed. */
    @Override
    public void makeDirty() {
        betaKnown = false;
        dataKnown = false;
        likelihoodKnown = false;
    }

    // Cached likelihood and whether it is current.
    boolean likelihoodKnown;
    boolean storedLikelihoodKnown;
    double logLikelihood;
    double storedLogLikelihood;
    // Whether beta-/data-dependent cached quantities are current.
    boolean betaKnown = false;
    boolean dataKnown = false;
    boolean storedDataKnown;
    boolean storedBetaKnown;
    // Number of non-zero entries per column of data (m_k).
    int[] rowCount;
    int[] storedRowCount;
    // Number of columns with at least one non-zero entry (K+).
    int KPlus;
    int storedKPlus;
    boolean[] containsNonZeroElements;
    boolean[] storedContainsNonZeroElements;
    // sum_{i=0}^{N-1} beta / (beta + i)
    double H;
    double storedH;
    // log(prod_h K_h!) over groups of identical non-zero columns.
    double logBottom;
    double storedLogBottom;
    // sum_k log B(m_k, N - m_k + beta) over non-zero columns.
    double sum2;
    double storedSum2;
    int ncols;
    int storedncols;

    AdaptableSizeFastMatrixParameter data;
    Parameter alpha;
    Parameter beta;
}