dr.evomodel.epidemiology.casetocase.periodpriors.NormalPeriodPriorDistribution.java Source code

Java tutorial

Introduction

Here is the source code for dr.evomodel.epidemiology.casetocase.periodpriors.NormalPeriodPriorDistribution.java

Source

/*
 * NormalPeriodPriorDistribution.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 *  BEAST is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301  USA
 */

package dr.evomodel.epidemiology.casetocase.periodpriors;

import dr.inference.loggers.LogColumn;
import dr.inference.model.Parameter;
import dr.math.distributions.NormalDistribution;
import dr.math.distributions.NormalGammaDistribution;
import dr.math.functionEval.GammaFunction;
import dr.xml.*;
import org.apache.commons.math.MathException;
import org.apache.commons.math.distribution.TDistributionImpl;

import java.util.ArrayList;
import java.util.Arrays;

/**
 * Period prior under the assumption that the periods are drawn from a normal distribution with unknown mean
 * and variance. The hyperprior is the conjugate normal-gamma distribution.
 *
 * <p>The hyperparameters are handled throughout as a four-element array in the order
 * {@code {mu, lambda, alpha, beta}}. The class supports two modes of use: a sequential mode, in which values
 * are fed one at a time through {@link #calculateLogPosteriorProbability(double, double)} and the
 * hyperparameters are updated after each value, and a batch mode,
 * {@link #calculateLogLikelihood(double[])}, which evaluates the closed-form marginal likelihood of a whole
 * set of values.</p>
 *
 * @author Matthew Hall
 */
public class NormalPeriodPriorDistribution extends AbstractPeriodPriorDistribution {

    public static final String NORMAL = "normalPeriodPriorDistribution";
    public static final String LOG = "log";
    public static final String ID = "id";
    public static final String MU = "mu";
    public static final String LAMBDA = "lambda";
    public static final String ALPHA = "alpha";
    public static final String BETA = "beta";

    // Positions of the hyperparameters inside the arrays returned by hyperprior.getParameters() and held in
    // currentParameters.
    private static final int MU_INDEX = 0;
    private static final int LAMBDA_INDEX = 1;
    private static final int ALPHA_INDEX = 2;
    private static final int BETA_INDEX = 3;

    private final NormalGammaDistribution hyperprior;

    // Single-valued parameters that only exist to expose posterior summaries to the loggers.
    private final Parameter posteriorMean;
    private final Parameter posteriorBeta;
    private final Parameter posteriorExpectedPrecision;

    // When the degrees of freedom (2 * alpha) of the predictive distribution exceed this value, the standard
    // normal distribution is used in place of the Student t distribution.
    double normalApproximationThreshold = 30;

    // Values absorbed so far in sequential mode, and the hyperparameters after absorbing them.
    private ArrayList<Double> dataValues;
    private double[] currentParameters;

    /**
     * Builds the prior around an existing normal-gamma hyperprior.
     *
     * @param name       the model name, passed to the superclass
     * @param log        passed to the superclass
     * @param hyperprior the normal-gamma hyperprior supplying the initial {@code {mu, lambda, alpha, beta}}
     */
    public NormalPeriodPriorDistribution(String name, boolean log, NormalGammaDistribution hyperprior) {
        super(name, log);
        this.hyperprior = hyperprior;
        posteriorBeta = new Parameter.Default(1);
        posteriorMean = new Parameter.Default(1);
        posteriorExpectedPrecision = new Parameter.Default(1);
        addVariable(posteriorBeta);
        addVariable(posteriorMean);
        addVariable(posteriorExpectedPrecision);
        // Without this the sequential state stays null when this constructor is used directly, and the
        // first predictive evaluation fails with a NullPointerException.
        initialiseSequentialState();
    }

    /**
     * Builds the prior from the four hyperparameters of the normal-gamma hyperprior.
     *
     * @param name     the model name, passed to the superclass
     * @param log      passed to the superclass
     * @param mu_0     prior {@code mu}
     * @param lambda_0 prior {@code lambda}
     * @param alpha_0  prior {@code alpha}
     * @param beta_0   prior {@code beta}
     */
    public NormalPeriodPriorDistribution(String name, boolean log, double mu_0, double lambda_0, double alpha_0,
            double beta_0) {
        this(name, log, new NormalGammaDistribution(mu_0, lambda_0, alpha_0, beta_0));
    }

    /**
     * Discards every value absorbed so far, restores the hyperparameters to those of the hyperprior and sets
     * the accumulated log likelihood back to zero.
     */
    public void reset() {
        initialiseSequentialState();
    }

    // Shared by the constructor and reset(); kept private so the constructor does not call an overridable
    // method.
    private void initialiseSequentialState() {
        dataValues = new ArrayList<Double>();
        currentParameters = hyperprior.getParameters();
        logL = 0;
    }

    /**
     * Returns the log posterior predictive density of {@code newValue} given the values absorbed so far,
     * adds it to the running total and then absorbs {@code newValue} into the hyperparameters.
     *
     * @param newValue the value to evaluate and absorb
     * @param minValue lower truncation point; unless it is {@code Double.NEGATIVE_INFINITY} the density is
     *                 renormalised by the predictive probability of exceeding it
     * @return the (possibly truncated) log predictive density of {@code newValue}
     */
    public double calculateLogPosteriorProbability(double newValue, double minValue) {
        double out = calculateLogPosteriorPredictiveProbability(newValue);
        if (minValue != Double.NEGATIVE_INFINITY) {
            out -= calculateLogPosteriorPredictiveCDF(minValue, true);
        }
        logL += out;
        update(newValue);
        return out;
    }

    /**
     * Delegates to {@link #calculateLogPosteriorPredictiveCDF(double, boolean)}.
     *
     * @param limit the point at which to evaluate the tail probability
     * @param upper {@code true} for the upper tail, {@code false} for the lower tail
     * @return the log tail probability
     */
    public double calculateLogPosteriorCDF(double limit, boolean upper) {
        return calculateLogPosteriorPredictiveCDF(limit, upper);
    }

    /**
     * Log density of {@code value} under the current posterior predictive distribution: a Student t
     * distribution with {@code 2 * alpha} degrees of freedom, location {@code mu} and squared scale
     * {@code beta * (lambda + 1) / (alpha * lambda)}, replaced by a normal distribution with the same
     * location and scale once the degrees of freedom exceed {@code normalApproximationThreshold}.
     *
     * @param value the point at which to evaluate the density
     * @return the log predictive density at {@code value}
     */
    public double calculateLogPosteriorPredictiveProbability(double value) {
        double scale = predictiveScale();
        double scaledValue = (value - currentParameters[MU_INDEX]) / scale;
        double degreesOfFreedom = predictiveDegreesOfFreedom();

        double logStandardDensity;
        if (degreesOfFreedom <= normalApproximationThreshold) {
            logStandardDensity = Math.log(new TDistributionImpl(degreesOfFreedom).density(scaledValue));
        } else {
            logStandardDensity = NormalDistribution.logPdf(scaledValue, 0, 1);
        }

        // Change of variables from the standardised value back to value itself: the density of value is the
        // standard density divided by the scale. Omitting this term makes the sequential product disagree
        // with the closed-form marginal likelihood computed by calculateLogLikelihood.
        return logStandardDensity - Math.log(scale);
    }

    /**
     * Log tail probability of the current posterior predictive distribution at {@code value}, using the same
     * Student t / normal switch as {@link #calculateLogPosteriorPredictiveProbability(double)}.
     *
     * @param value     the point at which to evaluate the tail probability
     * @param upperTail {@code true} for the probability of exceeding {@code value}, {@code false} for the
     *                  probability of not exceeding it
     * @return the log tail probability
     * @throws RuntimeException if the Student t cumulative probability cannot be evaluated
     */
    public double calculateLogPosteriorPredictiveCDF(double value, boolean upperTail) {
        double scaledValue = (value - currentParameters[MU_INDEX]) / predictiveScale();
        // Both standardised distributions are symmetric about zero, so the upper tail at z equals the lower
        // tail at -z.
        double argument = upperTail ? -scaledValue : scaledValue;
        double degreesOfFreedom = predictiveDegreesOfFreedom();

        if (degreesOfFreedom <= normalApproximationThreshold) {
            TDistributionImpl tDist = new TDistributionImpl(degreesOfFreedom);
            try {
                return Math.log(tDist.cumulativeProbability(argument));
            } catch (MathException e) {
                // Same message as before, but keep the original exception as the cause.
                throw new RuntimeException(e.toString(), e);
            }
        }

        // NOTE(review): the boolean argument is presumably the "return log probability" flag of
        // NormalDistribution.standardCDF - confirm against that class.
        return NormalDistribution.standardCDF(argument, true);
    }

    // Degrees of freedom of the current posterior predictive distribution.
    private double predictiveDegreesOfFreedom() {
        return 2 * currentParameters[ALPHA_INDEX];
    }

    // Scale of the current posterior predictive distribution: sqrt(beta * (lambda + 1) / (alpha * lambda)).
    private double predictiveScale() {
        double lambda = currentParameters[LAMBDA_INDEX];
        return Math.sqrt(currentParameters[BETA_INDEX] * (lambda + 1) / (currentParameters[ALPHA_INDEX] * lambda));
    }

    // Absorbs one observation into the hyperparameters (conjugate one-step update) and refreshes the logged
    // posterior summaries.
    private void update(double newData) {
        dataValues.add(newData);

        double lambda_0 = hyperprior.getParameters()[LAMBDA_INDEX];

        double oldMu = currentParameters[MU_INDEX];
        double oldLambda = currentParameters[LAMBDA_INDEX];
        double oldAlpha = currentParameters[ALPHA_INDEX];
        double oldBeta = currentParameters[BETA_INDEX];

        // Number of values absorbed, including the one just added.
        double count = dataValues.size();

        double newMu = (newData - oldMu) / (lambda_0 + count) + oldMu;
        double newLambda = oldLambda + 1;
        double newAlpha = oldAlpha + 0.5;
        double newBeta = oldBeta + oldLambda * Math.pow(newData - oldMu, 2) / (2 * (oldLambda + 1));

        posteriorMean.setParameterValue(0, newMu);
        posteriorBeta.setParameterValue(0, newBeta);
        posteriorExpectedPrecision.setParameterValue(0, newAlpha / newBeta);

        currentParameters = new double[] { newMu, newLambda, newAlpha, newBeta };
    }

    /**
     * Computes the closed-form log marginal likelihood of {@code values} under the hyperprior, stores it as
     * the current log likelihood and refreshes the logged posterior summaries. The sequential state used by
     * {@link #calculateLogPosteriorProbability(double, double)} is not touched.
     *
     * @param values the observed periods; may be empty, in which case the result is zero
     * @return the log marginal likelihood of {@code values}
     */
    public double calculateLogLikelihood(double[] values) {

        int count = values.length;

        double[] infPredictiveDistributionParameters = hyperprior.getParameters();

        double mu_0 = infPredictiveDistributionParameters[MU_INDEX];
        double lambda_0 = infPredictiveDistributionParameters[LAMBDA_INDEX];
        double alpha_0 = infPredictiveDistributionParameters[ALPHA_INDEX];
        double beta_0 = infPredictiveDistributionParameters[BETA_INDEX];

        // count is an int: dividing by 2.0 avoids the truncation that integer division produces for odd
        // sample sizes.
        double halfCount = count / 2.0;

        double lambda_n = lambda_0 + count;
        double alpha_n = alpha_0 + halfCount;

        double sum = 0;
        for (double infPeriod : values) {
            sum += infPeriod;
        }
        // With no data the sample mean is undefined (0/0); any finite value works because every term that
        // uses it is multiplied by count, so fall back to the prior mean rather than propagating NaN.
        double mean = count > 0 ? sum / count : mu_0;

        double sumOfDifferences = 0;
        for (double infPeriod : values) {
            sumOfDifferences += Math.pow(infPeriod - mean, 2);
        }

        posteriorMean.setParameterValue(0, (lambda_0 * mu_0 + sum) / (lambda_0 + count));

        double beta_n = beta_0 + 0.5 * sumOfDifferences
                + lambda_0 * count * Math.pow(mean - mu_0, 2) / (2 * (lambda_0 + count));

        posteriorBeta.setParameterValue(0, beta_n);
        posteriorExpectedPrecision.setParameterValue(0, alpha_n / beta_n);

        logL = GammaFunction.logGamma(alpha_n) - GammaFunction.logGamma(alpha_0) + alpha_0 * Math.log(beta_0)
                - alpha_n * Math.log(beta_n) + 0.5 * Math.log(lambda_0) - 0.5 * Math.log(lambda_n)
                - halfCount * Math.log(2 * Math.PI);

        return logL;
    }

    /**
     * Returns the superclass log columns followed by three columns reporting the posterior mean, the
     * posterior beta and the posterior expected precision (alpha / beta).
     *
     * @return the log columns for this distribution
     */
    public LogColumn[] getColumns() {
        ArrayList<LogColumn> columns = new ArrayList<LogColumn>(Arrays.asList(super.getColumns()));

        columns.add(new LogColumn.Abstract(getModelName() + "_posteriorMean") {
            protected String getFormattedValue() {
                return String.valueOf(posteriorMean.getParameterValue(0));
            }
        });

        columns.add(new LogColumn.Abstract(getModelName() + "_posteriorBeta") {
            protected String getFormattedValue() {
                return String.valueOf(posteriorBeta.getParameterValue(0));
            }
        });

        columns.add(new LogColumn.Abstract(getModelName() + "_posteriorExpectedPrecision") {
            protected String getFormattedValue() {
                return String.valueOf(posteriorExpectedPrecision.getParameterValue(0));
            }
        });

        return columns.toArray(new LogColumn[columns.size()]);
    }

    /**
     * XML parser for the {@code normalPeriodPriorDistribution} element. The {@code id}, {@code mu},
     * {@code lambda}, {@code alpha} and {@code beta} attributes are read unconditionally; {@code log} is
     * read only if present and otherwise treated as {@code false}.
     */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        public String getParserName() {
            return NORMAL;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            String id = (String) xo.getAttribute(ID);

            double mu = xo.getDoubleAttribute(MU);
            double lambda = xo.getDoubleAttribute(LAMBDA);
            double alpha = xo.getDoubleAttribute(ALPHA);
            double beta = xo.getDoubleAttribute(BETA);

            boolean log = xo.hasAttribute(LOG) && xo.getBooleanAttribute(LOG);

            return new NormalPeriodPriorDistribution(id, log, mu, lambda, alpha, beta);

        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private final XMLSyntaxRule[] rules = { AttributeRule.newBooleanRule(LOG, true),
                AttributeRule.newStringRule(ID, false), AttributeRule.newDoubleRule(MU, false),
                AttributeRule.newDoubleRule(LAMBDA, false), AttributeRule.newDoubleRule(ALPHA, false),
                AttributeRule.newDoubleRule(BETA, false) };

        public String getParserDescription() {
            // The two halves were previously concatenated without a separating space ("distributionof").
            return "Calculates the probability of a set of doubles being drawn from the prior posterior distribution "
                    + "of a normal distribution of unknown mean and variance";
        }

        public Class getReturnType() {
            return NormalPeriodPriorDistribution.class;
        }
    };

}