dr.evomodel.epidemiology.casetocase.CaseToCaseTreeLikelihood.java Source code

Java tutorial

Introduction

Here is the source code for dr.evomodel.epidemiology.casetocase.CaseToCaseTreeLikelihood.java

Source

/*
 * CaseToCaseTreeLikelihood.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 *  BEAST is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301  USA
 */

package dr.evomodel.epidemiology.casetocase;

import java.io.IOException;
import java.io.PrintStream;
import java.util.*;

import dr.app.tools.NexusExporter;
import dr.evolution.tree.FlexibleNode;
import dr.evolution.tree.FlexibleTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evolution.util.Taxon;
import dr.evolution.util.TaxonList;
import dr.evomodel.epidemiology.casetocase.periodpriors.AbstractPeriodPriorDistribution;
import dr.evomodel.tree.TreeModel;
import dr.oldevomodel.treelikelihood.AbstractTreeLikelihood;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import org.apache.commons.math.stat.descriptive.moment.Mean;
import org.apache.commons.math.stat.descriptive.moment.Variance;
import org.apache.commons.math.stat.descriptive.rank.Median;

/**
 * Handles manipulation of the tree partition, and likelihood of the infection times.
 *
 * @author Matthew Hall
 * @author Andrew Rambaut
 * @version $Id: $
 */

public abstract class CaseToCaseTreeLikelihood extends AbstractTreeLikelihood
        implements Loggable, Citable, TreeTraitProvider {

    /** When true, the constructor additionally registers a "NodeNumber" tree trait for debugging. */
    protected static final boolean DEBUG = false;

    // Numerical tolerance. NOTE(review): not referenced in this part of the file; presumably used by subclasses.
    protected static double tolerance = 1E-10;

    /** Number of tips (external nodes) in the tree. */

    protected int noTips;
    /** Number of cases in the outbreak ({@code outbreak.getCases().size()}). */
    protected int noCases;

    /* Latest tip sampling date; the reference point for converting node heights into times. */

    /* Tree traits (e.g. the PARTITIONS_KEY trait) registered by the constructor are held here. */

    private double estimatedLastSampleTime;
    protected TreeTraitProvider.Helper treeTraits = new Helper();

    /**
     * The set of cases
     */
    protected AbstractOutbreak outbreak;

    // Per-case caches indexed by outbreak.getCaseIndex(...). The stored* arrays are snapshots taken by
    // storeState() and reinstated by restoreState(). infectiousTimes and latentPeriods (and their stored
    // copies) are only allocated/copied when the outbreak has latent periods.
    protected double[] infectionTimes;
    private double[] storedInfectionTimes;
    protected double[] infectiousPeriods;
    private double[] storedInfectiousPeriods;
    protected double[] infectiousTimes;
    private double[] storedInfectiousTimes;
    protected double[] latentPeriods;
    private double[] storedLatentPeriods;
    // recalculateCaseFlags[i] == true means the cached timings for case i are stale.
    protected boolean[] recalculateCaseFlags;

    // One Treelet per ever-infected case, filled by explodeTree() for cases that do not yet have one.
    // NOTE(review): neither map is initialised in the visible constructor, and storedElementsAsTrees is not
    // referenced in this part of the file -- confirm that subclasses set them up.
    protected HashMap<AbstractCase, Treelet> elementsAsTrees;
    protected HashMap<AbstractCase, Treelet> storedElementsAsTrees;

    // The root case has no parent branch to place its infection on, so its infection is placed above the root
    // node by at most this amount (scaled by the case's infection branch position).

    protected Parameter maxFirstInfToRoot;

    // Whether the outbreak models a latent stage (infected but not yet infectious).

    protected boolean hasLatentPeriods;

    // PUBLIC STUFF

    // Default model name, and the tree-trait / node-attribute key under which each node's partition is reported.

    public static final String CASE_TO_CASE_TREE_LIKELIHOOD = "caseToCaseTreeLikelihood";
    public static final String PARTITIONS_KEY = "partition";

    /**
     * Creates a likelihood registered under the default name {@link #CASE_TO_CASE_TREE_LIKELIHOOD}.
     *
     * @param tree              the partitioned phylogeny
     * @param caseData          the outbreak (set of cases)
     * @param maxFirstInfToRoot upper bound on the time from the first infection to the root node
     * @throws TaxonList.MissingTaxonException kept for source compatibility with existing callers; the
     *                                         delegated constructor does not itself declare it
     */
    public CaseToCaseTreeLikelihood(PartitionedTreeModel tree,
                                    AbstractOutbreak caseData,
                                    Parameter maxFirstInfToRoot)
            throws TaxonList.MissingTaxonException {
        this(CASE_TO_CASE_TREE_LIKELIHOOD, tree, caseData, maxFirstInfToRoot);
    }

    /**
     * Creates a likelihood with a caller-chosen model name.
     *
     * <p>Registers the outbreak and the tree's branch map as sub-models, sizes the per-case caches from
     * {@code outbreak.size()}, marks every case as needing recalculation, and registers the
     * {@link #PARTITIONS_KEY} tree trait (plus a "NodeNumber" trait when {@link #DEBUG} is set).
     *
     * <p>NOTE(review): {@code elementsAsTrees} is not initialised here; presumably a subclass must create it
     * before {@code explodeTree()} is called -- confirm.
     *
     * @param name              model name
     * @param tree              the partitioned phylogeny
     * @param caseData          the outbreak (set of cases)
     * @param maxFirstInfToRoot upper bound on the time from the first infection to the root node
     * @throws RuntimeException if the inherited {@code stateCount} differs from the tree's tip count
     */
    public CaseToCaseTreeLikelihood(String name, PartitionedTreeModel tree, AbstractOutbreak caseData,
            Parameter maxFirstInfToRoot) {

        super(name, caseData, tree);

        // NOTE(review): stateCount is inherited; presumably one state per case, so a mismatch with the tip
        // count is reported as duplicated tips -- confirm.
        if (stateCount != treeModel.getExternalNodeCount()) {
            throw new RuntimeException("There are duplicate tip outbreak.");
        }

        noTips = tree.getExternalNodeCount();

        // The outbreak is registered as a sub-model below, so changes to case data reach handleModelChangedEvent.

        outbreak = caseData;

        noCases = outbreak.getCases().size();

        addModel(outbreak);

        // Reference time for converting node heights into times.
        estimatedLastSampleTime = getLatestTaxonTime();

        // Listen to the branch map so that reassigning branches to cases flags those cases for recalculation.

        addModel(tree.getBranchMap());

        hasLatentPeriods = outbreak.hasLatentPeriods();

        // Per-case caches, indexed by outbreak.getCaseIndex(...).
        infectionTimes = new double[outbreak.size()];
        infectiousPeriods = new double[outbreak.size()];

        if (hasLatentPeriods) {
            infectiousTimes = new double[outbreak.size()];
            latentPeriods = new double[outbreak.size()];
        }

        // Nothing has been computed yet, so every case starts out stale.
        recalculateCaseFlags = new boolean[outbreak.size()];
        Arrays.fill(recalculateCaseFlags, true);

        this.maxFirstInfToRoot = maxFirstInfToRoot;

        // Expose each node's partition label (as returned by getNodePartition) under PARTITIONS_KEY.
        treeTraits.addTrait(PARTITIONS_KEY, new TreeTrait.S() {
            public String getTraitName() {
                return PARTITIONS_KEY;
            }

            public Intent getIntent() {
                return Intent.NODE;
            }

            public String getTrait(Tree tree, NodeRef node) {
                return getNodePartition(tree, node);
            }
        });

        if (DEBUG) {
            // Debug-only trait exposing each node's number.
            treeTraits.addTrait("NodeNumber", new TreeTrait.S() {
                public String getTraitName() {
                    return "NodeNumber";
                }

                public Intent getIntent() {
                    return Intent.NODE;
                }

                public String getTrait(Tree tree, NodeRef node) {
                    return Integer.toString(node.getNumber());
                }
            });
        }

        likelihoodKnown = false;
    }

    /**
     * @return the outbreak (set of cases) this likelihood was built on
     */
    public AbstractOutbreak getOutbreak() {
        final AbstractOutbreak cases = outbreak;
        return cases;
    }

    /**
     * @return whether the outbreak models latent periods (captured once, at construction)
     */
    public boolean hasLatentPeriods() {
        final boolean latent = hasLatentPeriods;
        return latent;
    }

    /**
     * Finds the most recent sampling date among the tree's tips.
     *
     * @return the largest tip date value, or negative infinity if the tree has no tips
     */
    private double getLatestTaxonTime() {
        double latest = Double.NEGATIVE_INFINITY;
        final int tipCount = treeModel.getExternalNodeCount();
        for (int tip = 0; tip < tipCount; tip++) {
            final Taxon sampled = treeModel.getNodeTaxon(treeModel.getExternalNode(tip));
            final double sampleTime = sampled.getDate().getTimeValue();
            if (sampleTime > latest) {
                latest = sampleTime;
            }
        }
        return latest;
    }

    /**
     * Collects the children of {@code node} into a freshly allocated array, in child-index order.
     */
    private NodeRef[] getChildren(NodeRef node) {
        final int count = treeModel.getChildCount(node);
        final NodeRef[] result = new NodeRef[count];
        for (int c = 0; c < count; c++) {
            result[c] = treeModel.getChild(node, c);
        }
        return result;
    }

    /**
     * Builds a {@link Treelet} for every ever-infected case that has no entry in {@code elementsAsTrees} yet.
     *
     * <p>Each treelet is a copy of the case's portion of the tree, starting below the case's earliest node;
     * branches leading into other cases end in "Transmission_" tips. When the earliest node is a tip, no
     * children are copied. The treelet's zero height is its root height plus the distance from the earliest
     * node up to the case's infection.
     *
     * <p>NOTE(review): assumes {@code elementsAsTrees} has been initialised (not done in the visible constructor).
     */
    protected void explodeTree() {
        for (int caseIndex = 0; caseIndex < outbreak.size(); caseIndex++) {
            final AbstractCase element = outbreak.getCase(caseIndex);
            if (!element.wasEverInfected() || elementsAsTrees.get(element) != null) {
                continue;
            }

            final NodeRef earliest = ((PartitionedTreeModel) treeModel).getEarliestNodeInElement(element);

            // The root has no parent branch, so the maxFirstInfToRoot bound stands in for its length.
            final double branchSpan = treeModel.isRoot(earliest)
                    ? maxFirstInfToRoot.getParameterValue(0)
                    : treeModel.getBranchLength(earliest);
            final double extraHeight = branchSpan * element.getInfectionBranchPosition().getParameterValue(0);

            final FlexibleNode treeletRoot = new FlexibleNode();
            final FlexibleTree scratch = new FlexibleTree(treeletRoot);
            scratch.beginTreeEdit();

            if (!treeModel.isExternal(earliest)) {
                final int childCount = treeModel.getChildCount(earliest);
                for (int c = 0; c < childCount; c++) {
                    copyElementToTreelet(scratch, treeModel.getChild(earliest, c), treeletRoot, element);
                }
            }

            scratch.endTreeEdit();
            scratch.resolveTree();

            elementsAsTrees.put(element, new Treelet(scratch, scratch.getRootHeight() + extraHeight));
        }
    }

    /**
     * Recursively copies the part of {@code element}'s partition rooted at {@code oldNode} into
     * {@code littleTree}, attaching it under {@code newParent}.
     *
     * <p>A node belonging to {@code element} is copied with its branch length (tips keep their taxon id).
     * A node belonging to another case is not descended into: instead a "Transmission_&lt;case name&gt;" tip is
     * added, with a branch running from {@code oldNode}'s parent time to that case's infection time.
     * Nothing happens for a case that was never infected.
     */
    private void copyElementToTreelet(FlexibleTree littleTree, NodeRef oldNode, NodeRef newParent,
            AbstractCase element) {
        if (!element.wasEverInfected()) {
            return;
        }

        final AbstractCase nodeCase = getBranchMap().get(oldNode.getNumber());

        if (nodeCase != element) {
            // The branch into oldNode carries an onward transmission: end it at the infection.
            final NodeRef transmissionTip = new FlexibleNode(new Taxon("Transmission_" + nodeCase.getName()));
            final double branchStart = getNodeTime(treeModel.getParent(oldNode));
            final double branchEnd = getInfectionTime(nodeCase);
            littleTree.addChild(newParent, transmissionTip);
            littleTree.setBranchLength(transmissionTip, branchEnd - branchStart);
            return;
        }

        final boolean isTip = treeModel.isExternal(oldNode);
        final NodeRef copy = isTip
                ? new FlexibleNode(new Taxon(treeModel.getNodeTaxon(oldNode).getId()))
                : new FlexibleNode();
        littleTree.addChild(newParent, copy);
        littleTree.setBranchLength(copy, treeModel.getBranchLength(oldNode));

        if (!isTip) {
            for (int c = 0; c < treeModel.getChildCount(oldNode); c++) {
                copyElementToTreelet(littleTree, treeModel.getChild(oldNode, c), copy, element);
            }
        }
    }

    /**
     * A standalone copy of one case's portion of the tree, paired with a "zero height".
     *
     * <p>As built by {@code explodeTree()}, the zero height is the treelet's root height plus the distance from
     * the case's earliest node up to its infection, i.e. the height of the infection in the treelet's own
     * height scale.
     */
    protected class Treelet extends FlexibleTree {

        private double zeroHeight;

        /**
         * @param tree       the tree to copy
         * @param zeroHeight height of the case's infection in the copied tree's height scale
         */
        protected Treelet(FlexibleTree tree, double zeroHeight) {
            super(tree);
            this.zeroHeight = zeroHeight;
        }

        /**
         * @return the height of the case's infection in this treelet's height scale
         */
        protected double getZeroHeight() {
            return zeroHeight;
        }

        /**
         * Replaces the stored zero height.
         *
         * <p>Previously the parameter was named {@code rootBranchLength} and the body assigned the field to
         * itself, so this setter silently had no effect.
         *
         * @param zeroHeight the new zero height
         */
        protected void setZeroHeight(double zeroHeight) {
            this.zeroHeight = zeroHeight;
        }
    }

    /**
     * Collects the cases assigned to every tip at or below {@code node}.
     *
     * @param node the subtree root
     * @param map  if non-null, receives, for every node visited, its node number mapped to the set computed for
     *             that node; may be {@code null}
     * @return the set of cases found on the descendant tips
     */
    public HashSet<AbstractCase> descendantTipPartitions(NodeRef node,
            HashMap<Integer, HashSet<AbstractCase>> map) {
        final HashSet<AbstractCase> partitions = new HashSet<AbstractCase>();

        if (treeModel.isExternal(node)) {
            partitions.add(getBranchMap().get(node.getNumber()));
        } else {
            final int childCount = treeModel.getChildCount(node);
            for (int c = 0; c < childCount; c++) {
                partitions.addAll(descendantTipPartitions(treeModel.getChild(node, c), map));
            }
        }

        if (map != null) {
            map.put(node.getNumber(), partitions);
        }
        return partitions;
    }

    /**
     * Marks {@code node}, its immediate children, and every ancestor of {@code node} up to the root as needing
     * recalculation.
     *
     * <p>Previously the rootward walk never ran: {@code node}'s own flag was set first, and the loop tested that
     * same flag before moving to the parent, so the loop condition was false on entry. The walk now flags each
     * ancestor. It deliberately does not stop at an already-flagged ancestor, because flags may also be set
     * elsewhere (e.g. on a single node) without their ancestors being flagged. Flagging extra nodes only costs
     * recalculation; missing one would leave stale values.
     *
     * @param tree  the tree whose nodes are being flagged
     * @param node  the node whose subtree changed
     * @param flags per-node-number recalculation flags, updated in place
     */
    protected static void flagForDescendantRecalculation(TreeModel tree, NodeRef node, boolean[] flags) {
        flags[node.getNumber()] = true;
        for (int i = 0; i < tree.getChildCount(node); i++) {
            flags[tree.getChild(node, i).getNumber()] = true;
        }
        NodeRef currentNode = node;
        while (!tree.isRoot(currentNode)) {
            currentNode = tree.getParent(currentNode);
            flags[currentNode.getNumber()] = true;
        }
    }

    /**
     * Applies the static {@code flagForDescendantRecalculation} overload to this likelihood's inherited
     * {@code updateNode} flags.
     */
    public void flagForDescendantRecalculation(TreeModel tree, NodeRef node) {
        final boolean[] nodeFlags = updateNode;
        CaseToCaseTreeLikelihood.flagForDescendantRecalculation(tree, node, nodeFlags);
    }

    // **************************************************************
    // ModelListener IMPLEMENTATION
    // **************************************************************

    /**
     * Flags the cases affected by a sub-model change, re-fires the change and invalidates the likelihood.
     *
     * <ul>
     *   <li>Period-prior models: ignored entirely (no flags, no re-fire).</li>
     *   <li>Tree model: a {@code PartitionsChangedEvent} flags its listed cases.</li>
     *   <li>Branch map: each event in the list flags its old case, its new case, and the case of the changed
     *       node's parent (if any).</li>
     *   <li>Outbreak: a single-case event flags that case; anything else flags every case.</li>
     * </ul>
     *
     * @throws RuntimeException if the branch map fires something other than an {@code ArrayList} of events
     */
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        if (model instanceof AbstractPeriodPriorDistribution) {
            return;
        }

        if (model == treeModel) {
            if (object instanceof PartitionedTreeModel.PartitionsChangedEvent) {
                final PartitionedTreeModel.PartitionsChangedEvent partitionsEvent =
                        (PartitionedTreeModel.PartitionsChangedEvent) object;
                for (AbstractCase changed : partitionsEvent.getCasesToRecalculate()) {
                    recalculateCase(changed);
                }
            }
        } else if (model == getBranchMap()) {
            if (!(object instanceof ArrayList)) {
                throw new RuntimeException("Unanticipated model changed event from BranchMapModel");
            }
            final ArrayList<?> changes = (ArrayList<?>) object;
            for (int k = 0; k < changes.size(); k++) {
                final BranchMapModel.BranchMapChangedEvent change =
                        (BranchMapModel.BranchMapChangedEvent) changes.get(k);

                recalculateCase(change.getOldCase());
                recalculateCase(change.getNewCase());

                final NodeRef changedNode = treeModel.getNode(change.getNodeToRecalculate());
                final NodeRef above = treeModel.getParent(changedNode);
                if (above != null) {
                    recalculateCase(getBranchMap().get(above.getNumber()));
                }
            }
        } else if (model == outbreak) {
            if (object instanceof AbstractCase) {
                recalculateCase((AbstractCase) object);
            } else {
                for (AbstractCase anyCase : outbreak.getCases()) {
                    recalculateCase(anyCase);
                }
            }
        }

        fireModelChanged(model);

        likelihoodKnown = false;
    }

    /**
     * Marks the case at {@code index} so that its cached timings are recomputed on next access.
     */
    protected void recalculateCase(int index) {
        final boolean[] stale = recalculateCaseFlags;
        stale[index] = true;
    }

    /**
     * Marks {@code aCase} for recalculation. Cases that were never infected are ignored.
     */
    protected void recalculateCase(AbstractCase aCase) {
        if (!aCase.wasEverInfected()) {
            return;
        }
        final int caseIndex = outbreak.getCaseIndex(aCase);
        recalculateCase(caseIndex);
    }

    // **************************************************************
    // VariableListener IMPLEMENTATION
    // **************************************************************

    /**
     * Any variable change re-fires a model change and then invalidates the cached likelihood.
     * No per-case flags are touched here.
     */
    protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        fireModelChanged();
        likelihoodKnown = false;
    }

    // **************************************************************
    // Model IMPLEMENTATION
    // **************************************************************

    /**
     * Snapshots the cached per-case timings (infection times and infectious periods, plus infectious times and
     * latent periods when the outbreak has latent periods) so {@code restoreState()} can reinstate them.
     */

    protected void storeState() {
        super.storeState();
        storedInfectionTimes = Arrays.copyOf(infectionTimes, infectionTimes.length);
        storedInfectiousPeriods = Arrays.copyOf(infectiousPeriods, infectiousPeriods.length);
        if (hasLatentPeriods) {
            // Copy using the source array's own length (previously infectionTimes.length was used here).
            storedInfectiousTimes = Arrays.copyOf(infectiousTimes, infectiousTimes.length);
            storedLatentPeriods = Arrays.copyOf(latentPeriods, latentPeriods.length);
        }
    }

    /**
     * Reinstates the per-case timings captured by the last {@code storeState()}.
     *
     * <p>Values are copied back into the live arrays. The previous version assigned the stored arrays to the live
     * fields, which left live and stored fields aliasing the same array. In-place updates (for example by
     * {@code getInfectionTimes(true)}) then silently overwrote the snapshot until the next {@code storeState()}.
     * Copying also keeps any external references to the live arrays in sync.
     */

    protected void restoreState() {
        super.restoreState();
        System.arraycopy(storedInfectionTimes, 0, infectionTimes, 0, storedInfectionTimes.length);
        System.arraycopy(storedInfectiousPeriods, 0, infectiousPeriods, 0, storedInfectiousPeriods.length);
        if (hasLatentPeriods) {
            System.arraycopy(storedInfectiousTimes, 0, infectiousTimes, 0, storedInfectiousTimes.length);
            System.arraycopy(storedLatentPeriods, 0, latentPeriods, 0, storedLatentPeriods.length);
        }
    }

    /**
     * Accepting a proposal requires no work: the live caches already hold the accepted values.
     */
    protected final void acceptState() {
        // intentionally empty
    }

    // **************************************************************
    // Likelihood IMPLEMENTATION
    // **************************************************************

    /**
     * @return the branch-to-case map held by the underlying partitioned tree
     */
    public final BranchMapModel getBranchMap() {
        final PartitionedTreeModel partitioned = getTreeModel();
        return partitioned.getBranchMap();
    }

    /**
     * @return the inherited tree model, viewed as the {@code PartitionedTreeModel} it was constructed with
     */
    public final PartitionedTreeModel getTreeModel() {
        final PartitionedTreeModel partitioned = (PartitionedTreeModel) treeModel;
        return partitioned;
    }

    /**
     * Forces a full recalculation: every case is marked stale and the cached likelihood is invalidated.
     */
    public void makeDirty() {
        Arrays.fill(recalculateCaseFlags, true);
        likelihoodKnown = false;
    }

    /**
     * Refreshes every stale per-case cache and then marks all cases as current.
     *
     * <p>The refresh order (infection times, infectious times, infectious periods, latent periods) is preserved.
     */
    protected void prepareTimings() {
        final boolean latent = hasLatentPeriods;

        infectionTimes = getInfectionTimes(true);
        if (latent) {
            infectiousTimes = getInfectiousTimes(true);
        }
        infectiousPeriods = getInfectiousPeriods(true);
        if (latent) {
            latentPeriods = getLatentPeriods(true);
        }

        // Everything is now up to date.
        Arrays.fill(recalculateCaseFlags, false);
    }

    /**
     * Calculates the log likelihood of the current partition of the tree into cases; implemented by subclasses.
     *
     * @return the log likelihood
     */

    protected abstract double calculateLogLikelihood();

    /**
     * Checks the whole tree for impossible transmissions: an infection after the infector's end time, or (with
     * latent periods) before the infector became infectious. Reads the cached timing arrays.
     *
     * @return {@code true} if no such transmission exists
     */
    protected boolean isAllowed() {
        final NodeRef root = treeModel.getRoot();
        return isAllowed(root);
    }

    /**
     * Recursively checks the subtree rooted at {@code node}.
     *
     * <p>On every branch where the case changes between parent and child node, the child case's cached infection
     * time must not be after the parent case's end time. When latent periods are modelled, it also must not be
     * before the parent case's cached infectious time.
     *
     * <p>Generalised to nodes with any number of children; the previous version hard-coded children 0 and 1.
     * NOTE(review): reads the cached {@code infectionTimes}/{@code infectiousTimes}; presumably callers refresh
     * them first (e.g. via {@code prepareTimings()}).
     *
     * @return {@code true} if every branch in the subtree is allowed
     */
    private boolean isAllowed(NodeRef node) {
        if (!treeModel.isRoot(node)) {
            AbstractCase childCase = getBranchMap().get(node.getNumber());
            AbstractCase parentCase = getBranchMap().get(treeModel.getParent(node).getNumber());
            if (childCase != parentCase) {
                double infectionTime = infectionTimes[outbreak.getCaseIndex(childCase)];
                if (infectionTime > parentCase.getEndTime() || (hasLatentPeriods
                        && infectionTime < infectiousTimes[outbreak.getCaseIndex(parentCase)])) {
                    return false;
                }
            }
        }
        if (treeModel.isExternal(node)) {
            return true;
        }
        for (int i = 0; i < treeModel.getChildCount(node); i++) {
            if (!isAllowed(treeModel.getChild(node, i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Converts a node's height into the time at which the node occurred, measured on the tip-date scale.
     */
    public double getNodeTime(NodeRef node) {
        return estimatedLastSampleTime - treeModel.getNodeHeight(node);
    }

    /**
     * Converts a height (distance before the latest tip) into a time on the tip-date scale.
     */
    public double heightToTime(double height) {
        final double latestSample = estimatedLastSampleTime;
        return latestSample - height;
    }

    /**
     * Converts a time on the tip-date scale into a height (distance before the latest tip). The mapping
     * {@code x -> estimatedLastSampleTime - x} is its own inverse, so the arithmetic matches heightToTime.
     */
    public double timeToHeight(double time) {
        final double latestSample = estimatedLastSampleTime;
        return latestSample - time;
    }

    /**
     * Shorthand for the tree model's height of {@code node}.
     */
    private double getHeight(NodeRef node) {
        final double height = treeModel.getNodeHeight(node);
        return height;
    }

    /**
     * Returns the infection time of {@code thisCase}.
     *
     * <p>If the case is not flagged for recalculation, the cached value is returned. Otherwise the time is
     * computed (not cached) from the case's infection branch position: on the branch above its earliest node, or,
     * for the root case, above the root. Never-infected cases get positive infinity.
     */
    public double getInfectionTime(AbstractCase thisCase) {
        final int caseIndex = outbreak.getCaseIndex(thisCase);
        if (!recalculateCaseFlags[caseIndex]) {
            return infectionTimes[caseIndex];
        }
        if (!thisCase.wasEverInfected()) {
            return Double.POSITIVE_INFINITY;
        }

        final NodeRef earliest = ((PartitionedTreeModel) treeModel).getEarliestNodeInElement(thisCase);
        final NodeRef above = treeModel.getParent(earliest);
        if (above == null) {
            return getRootInfectionTime(getBranchMap());
        }

        final double branchStart = heightToTime(treeModel.getNodeHeight(above));
        // Let the likelihood evaluate to zero due to culling dates if it must...
        final double branchEnd = heightToTime(treeModel.getNodeHeight(earliest));
        return getInfectionTime(branchStart, branchEnd, thisCase);
    }

    /**
     * Places the infection on the branch [min, max] (times). A branch position of 0 maps to {@code max} (the
     * child end) and 1 maps to {@code min} (the parent end).
     */
    private double getInfectionTime(double min, double max, AbstractCase infected) {
        final double position = infected.getInfectionBranchPosition().getParameterValue(0);
        return min + (max - min) * (1 - position);
    }

    /**
     * Returns the live infection-time cache, first refreshing stale entries when {@code recalculate} is set.
     * Flags are not cleared here.
     */
    public double[] getInfectionTimes(boolean recalculate) {
        if (!recalculate) {
            return infectionTimes;
        }
        for (int k = 0; k < noCases; k++) {
            if (recalculateCaseFlags[k]) {
                infectionTimes[k] = getInfectionTime(outbreak.getCase(k));
            }
        }
        return infectionTimes;
    }

    /**
     * Sets the infection of {@code thisCase} to the given time by converting it to a height.
     */
    public void setInfectionTime(AbstractCase thisCase, double time) {
        final double height = timeToHeight(time);
        setInfectionHeight(thisCase, height);
    }

    /**
     * Moves the infection of {@code thisCase} to {@code height} by updating its infection branch position.
     *
     * <p>The admissible range is from the case's earliest node up to its parent. For the root case, the upper end
     * is the earliest node's height plus {@code maxFirstInfToRoot}. Never-infected cases are ignored.
     *
     * @throws RuntimeException if {@code height} lies outside that range
     */
    public void setInfectionHeight(AbstractCase thisCase, double height) {
        if (!thisCase.wasEverInfected()) {
            return;
        }

        final NodeRef lower = ((PartitionedTreeModel) treeModel).getEarliestNodeInElement(thisCase);
        final NodeRef upper = treeModel.getParent(lower);

        final double minHeight = treeModel.getNodeHeight(lower);
        final double maxHeight;
        if (upper == null) {
            maxHeight = minHeight + maxFirstInfToRoot.getParameterValue(0);
        } else {
            maxHeight = treeModel.getNodeHeight(upper);
        }

        if (height < minHeight || height > maxHeight) {
            throw new RuntimeException(
                    "Trying to set an infection time outside the branch on which it must occur");
        }

        thisCase.setInfectionBranchPosition((height - minHeight) / (maxHeight - minHeight));
    }

    /**
     * Returns the time at which {@code thisCase} became infectious.
     *
     * <p>Without latent periods this is the infection time. With them, the cached value is refreshed when the
     * case is flagged: infection time plus the latent period of the case's latent category, or positive infinity
     * for a case that was never infected.
     */
    public double getInfectiousTime(AbstractCase thisCase) {
        if (!hasLatentPeriods) {
            return getInfectionTime(thisCase);
        }

        final int idx = outbreak.getCaseIndex(thisCase);
        if (recalculateCaseFlags[idx]) {
            final double value;
            if (thisCase.wasEverInfected()) {
                final CategoryOutbreak categories = (CategoryOutbreak) outbreak;
                final String latentCategory = categories.getLatentCategory(thisCase);
                final Parameter latentPeriod = categories.getLatentPeriod(latentCategory);
                value = getInfectionTime(thisCase) + latentPeriod.getParameterValue(0);
            } else {
                value = Double.POSITIVE_INFINITY;
            }
            infectiousTimes[idx] = value;
        }
        return infectiousTimes[idx];
    }

    /**
     * Returns the live infectious-time cache, first refreshing stale entries when {@code recalculate} is set.
     */
    public double[] getInfectiousTimes(boolean recalculate) {
        if (!recalculate) {
            return infectiousTimes;
        }
        for (int k = 0; k < noCases; k++) {
            if (recalculateCaseFlags[k]) {
                infectiousTimes[k] = getInfectiousTime(outbreak.getCase(k));
            }
        }
        return infectiousTimes;
    }

    /**
     * Returns the infectious period of {@code thisCase}: its end time minus the time it became infectious, or 0
     * for a case that was never infected. The cached value is refreshed when the case is flagged.
     */
    public double getInfectiousPeriod(AbstractCase thisCase) {
        final int idx = outbreak.getCaseIndex(thisCase);
        if (recalculateCaseFlags[idx]) {
            if (thisCase.wasEverInfected()) {
                // getInfectiousTime already falls back to the infection time when there are no latent periods.
                final double infectiousFrom = getInfectiousTime(thisCase);
                final double cullTime = thisCase.getEndTime();
                infectiousPeriods[idx] = cullTime - infectiousFrom;
            } else {
                infectiousPeriods[idx] = 0;
            }
        }
        return infectiousPeriods[idx];
    }

    /**
     * Returns the live infectious-period cache, first refreshing stale entries when {@code recalculate} is set.
     */
    public double[] getInfectiousPeriods(boolean recalculate) {
        if (!recalculate) {
            return infectiousPeriods;
        }
        for (int k = 0; k < noCases; k++) {
            if (recalculateCaseFlags[k]) {
                infectiousPeriods[k] = getInfectiousPeriod(outbreak.getCase(k));
            }
        }
        return infectiousPeriods;
    }

    /**
     * @return the infectious periods of all ever-infected cases, in case order
     */
    public Double[] getNonzeroInfectiousPeriods() {
        final ArrayList<Double> periods = new ArrayList<Double>();
        for (int k = 0; k < noCases; k++) {
            final AbstractCase candidate = outbreak.getCase(k);
            if (!candidate.wasEverInfected()) {
                continue;
            }
            periods.add(getInfectiousPeriod(candidate));
        }
        return periods.toArray(new Double[periods.size()]);
    }

    /**
     * Returns the latent period of {@code thisCase} (infectious time minus infection time). Returns 0.0 when
     * latent periods are not modelled or the case was never infected. The cached value is refreshed when the
     * case is flagged.
     */
    public double getLatentPeriod(AbstractCase thisCase) {
        if (hasLatentPeriods && thisCase.wasEverInfected()) {
            final int idx = outbreak.getCaseIndex(thisCase);
            if (recalculateCaseFlags[idx]) {
                latentPeriods[idx] = getInfectiousTime(thisCase) - getInfectionTime(thisCase);
            }
            return latentPeriods[idx];
        }
        return 0.0;
    }

    /**
     * Returns the live latent-period cache, first refreshing stale entries when {@code recalculate} is set.
     */
    public double[] getLatentPeriods(boolean recalculate) {
        if (!recalculate) {
            return latentPeriods;
        }
        for (int k = 0; k < noCases; k++) {
            if (recalculateCaseFlags[k]) {
                latentPeriods[k] = getLatentPeriod(outbreak.getCase(k));
            }
        }
        return latentPeriods;
    }

    /**
     * @return the latent periods of all ever-infected cases, in case order
     */
    public Double[] getNonzeroLatentPeriods() {
        final ArrayList<Double> periods = new ArrayList<Double>();
        for (int k = 0; k < noCases; k++) {
            final AbstractCase candidate = outbreak.getCase(k);
            if (!candidate.wasEverInfected()) {
                continue;
            }
            periods.add(getLatentPeriod(candidate));
        }
        return periods.toArray(new Double[periods.size()]);
    }

    /**
     * Returns each case's infected period (end time minus infection time; 0 if never infected).
     *
     * <p>Without latent periods this coincides with the infectious periods, so the live infectious-period cache
     * is returned. Otherwise a fresh array is built and {@code recalculate} is not consulted.
     */
    public double[] getInfectedPeriods(boolean recalculate) {
        if (!hasLatentPeriods) {
            return getInfectiousPeriods(recalculate);
        }
        final double[] periods = new double[noCases];
        for (int k = 0; k < periods.length; k++) {
            periods[k] = getInfectedPeriod(outbreak.getCase(k));
        }
        return periods;
    }

    /**
     * @return the infected periods of all ever-infected cases, in case order
     */
    public Double[] getNonzeroInfectedPeriods() {
        final ArrayList<Double> periods = new ArrayList<Double>();
        for (int k = 0; k < noCases; k++) {
            final AbstractCase candidate = outbreak.getCase(k);
            if (!candidate.wasEverInfected()) {
                continue;
            }
            periods.add(getInfectedPeriod(candidate));
        }
        return periods.toArray(new Double[periods.size()]);
    }

    /**
     * Returns the time from infection to the case's end time, or 0 for a case that was never infected.
     *
     * <p>Uses the {@code wasEverInfected()} accessor like every other method in this class. The previous version
     * read the {@code wasEverInfected} field directly, which bypasses any subclass override of the accessor.
     */
    public double getInfectedPeriod(AbstractCase thisCase) {
        if (!thisCase.wasEverInfected()) {
            return 0;
        }
        return thisCase.getEndTime() - getInfectionTime(thisCase);
    }

    /**
     * Summarises a sample as {mean, median, variance, standard deviation}, using commons-math estimators.
     *
     * @todo this is pretty wasteful since it gets called so many times per log entry
     */
    public static Double[] getSummaryStatistics(Double[] variable) {
        final double[] values = new double[variable.length];
        int next = 0;
        for (Double boxed : variable) {
            values[next++] = boxed;
        }

        final double mean = new Mean().evaluate(values);
        final double median = new Median().evaluate(values);
        final double variance = new Variance().evaluate(values);

        return new Double[]{mean, median, variance, Math.sqrt(variance)};
    }

    /**
     * Computes the root case's infection time. The infection lies above the root node by
     * {@code maxFirstInfToRoot} scaled by the root case's infection branch position.
     */
    private double getRootInfectionTime(BranchMapModel branchMap) {
        final NodeRef root = treeModel.getRoot();
        final AbstractCase rootCase = branchMap.get(root.getNumber());
        final double maxOffset = maxFirstInfToRoot.getParameterValue(0);
        final double position = rootCase.getInfectionBranchPosition().getParameterValue(0);
        final double infectionHeight = treeModel.getNodeHeight(root) + maxOffset * position;
        return heightToTime(infectionHeight);
    }

    /**
     * @return the infection time of the case assigned to the root node
     */
    protected double getRootInfectionTime() {
        final int rootNumber = treeModel.getRoot().getNumber();
        return getInfectionTime(getBranchMap().get(rootNumber));
    }

    /**
     * Writes the tree to {@code fileName} in NEXUS format, labelled with the current branch map.
     */
    public void outputTreeToFile(String fileName, boolean includeTransmissionNodes) {
        final BranchMapModel currentMap = getBranchMap();
        outputTreeToFile(currentMap, fileName, includeTransmissionNodes);
    }

    /**
     * Writes a copy of the tree to {@code fileName} in NEXUS format.
     *
     * <p>Without transmission nodes, each node is annotated with "Number", "Time" and its partition taken from
     * {@code map}. With transmission nodes, the tree comes from {@code addTransmissionNodes}.
     *
     * <p>Fixes: the output stream is now always closed (it previously leaked). {@link PrintStream} swallows write
     * errors, so {@code checkError()} is consulted. Failure messages now include the file name and cause
     * instead of a bare "IOException".
     */
    public void outputTreeToFile(BranchMapModel map, String fileName, boolean includeTransmissionNodes) {
        PrintStream stream = null;
        try {
            FlexibleTree treeCopy;
            if (!includeTransmissionNodes) {
                treeCopy = new FlexibleTree(treeModel);
                for (int j = 0; j < treeCopy.getNodeCount(); j++) {
                    FlexibleNode node = (FlexibleNode) treeCopy.getNode(j);
                    node.setAttribute("Number", node.getNumber());
                    node.setAttribute("Time", heightToTime(node.getHeight()));
                    node.setAttribute(PARTITIONS_KEY, map.get(node.getNumber()));
                }
            } else {
                // NOTE(review): addTransmissionNodes labels partitions from getBranchMap(), not from map.
                treeCopy = addTransmissionNodes(treeModel);
            }
            stream = new PrintStream(fileName);
            NexusExporter testTreesOut = new NexusExporter(stream);
            testTreesOut.exportTree(treeCopy);
            if (stream.checkError()) {
                System.out.println("Error writing tree to " + fileName);
            }
        } catch (IOException e) {
            System.out.println("IOException writing tree to " + fileName + ": " + e.getMessage());
        } finally {
            if (stream != null) {
                stream.close();
            }
        }
    }

    /**
     * Returns a copy of {@code tree} with an extra node inserted at every infection event.
     *
     * <p>Every node is annotated with "Number", "Time" and its partition (from {@code getBranchMap()}). For each
     * ever-infected case, a new node is spliced into the branch above the case's earliest node at the case's
     * infection height, labelled with the partition of the original parent node. For the root case, a new root
     * labelled "Origin" is installed above the old root instead. The result is then rebuilt from its root and
     * checked so that no node is higher than its parent.
     *
     * <p>NOTE(review): lookups use {@code treeModel} while the copy is made from {@code tree}; presumably callers
     * pass the tree model itself (as outputTreeToFile does) -- confirm.
     * NOTE(review): {@code outTree.getNode(infectionNodeNo)} after earlier edits assumes node numbering survives
     * beginTreeEdit/endTreeEdit -- confirm.
     *
     * @throws RuntimeException if the rewired tree has a node higher than its parent (after writing
     *                          "fancyProblem.nex" and running checkPartitions for diagnostics)
     */
    public FlexibleTree addTransmissionNodes(Tree tree) {
        // Refreshes all cached timings and clears the recalculation flags, so getInfectionTime below reads the cache.
        prepareTimings();

        FlexibleTree outTree = new FlexibleTree(tree, true);

        for (int j = 0; j < outTree.getNodeCount(); j++) {
            FlexibleNode node = (FlexibleNode) outTree.getNode(j);
            node.setAttribute("Number", node.getNumber());
            node.setAttribute("Time", heightToTime(node.getHeight()));
            node.setAttribute(PARTITIONS_KEY, getBranchMap().get(node.getNumber()));
        }

        for (AbstractCase aCase : outbreak.getCases()) {
            if (aCase.wasEverInfected()) {
                NodeRef originalNode = ((PartitionedTreeModel) treeModel).getEarliestNodeInElement(aCase);

                int infectionNodeNo = originalNode.getNumber();
                if (!treeModel.isRoot(originalNode)) {
                    // Split the branch above originalNode at the infection height.
                    NodeRef originalParent = treeModel.getParent(originalNode);
                    double nodeTime = getNodeTime(originalNode);
                    double infectionTime = getInfectionTime(aCase);
                    double heightToBreakBranch = getHeight(originalNode) + (nodeTime - infectionTime);
                    FlexibleNode newNode = (FlexibleNode) outTree.getNode(infectionNodeNo);
                    FlexibleNode oldParent = (FlexibleNode) outTree.getParent(newNode);

                    outTree.beginTreeEdit();
                    outTree.removeChild(oldParent, newNode);
                    FlexibleNode infectionNode = new FlexibleNode();
                    infectionNode.setHeight(heightToBreakBranch);
                    infectionNode.setLength(oldParent.getHeight() - heightToBreakBranch);
                    infectionNode.setAttribute(PARTITIONS_KEY, getNodePartition(treeModel, originalParent));
                    infectionNode.setAttribute("Time", heightToTime(heightToBreakBranch));
                    newNode.setLength(nodeTime - infectionTime);

                    outTree.addChild(oldParent, infectionNode);
                    outTree.addChild(infectionNode, newNode);
                    outTree.endTreeEdit();
                } else {
                    // Root case: install a new "Origin" root above the old root, at the infection height.
                    double nodeTime = getNodeTime(originalNode);
                    double infectionTime = getInfectionTime(aCase);
                    double heightToInstallRoot = getHeight(originalNode) + (nodeTime - infectionTime);
                    FlexibleNode newNode = (FlexibleNode) outTree.getNode(infectionNodeNo);
                    outTree.beginTreeEdit();
                    FlexibleNode infectionNode = new FlexibleNode();
                    infectionNode.setHeight(heightToInstallRoot);
                    infectionNode.setAttribute("Time", heightToTime(heightToInstallRoot));
                    infectionNode.setAttribute(PARTITIONS_KEY, "Origin");
                    outTree.addChild(infectionNode, newNode);
                    newNode.setLength(heightToInstallRoot - getHeight(originalNode));
                    outTree.setRoot(infectionNode);
                    outTree.endTreeEdit();
                }
            }
        }

        // Rebuild the tree from its (possibly new) root.
        outTree = new FlexibleTree((FlexibleNode) outTree.getRoot());

        // Sanity check: no node may be higher than its parent.
        for (int i = 0; i < outTree.getNodeCount(); i++) {
            NodeRef node = outTree.getNode(i);
            NodeRef parent = outTree.getParent(node);
            if (parent != null && outTree.getNodeHeight(node) > outTree.getNodeHeight(parent)) {
                // Diagnostic dump. NOTE(review): this PrintStream is never closed.
                try {
                    NexusExporter exporter = new NexusExporter(new PrintStream("fancyProblem.nex"));
                    exporter.exportTree(outTree);
                } catch (IOException e) {
                    e.printStackTrace();
                }
                try {
                    ((PartitionedTreeModel) treeModel).checkPartitions();
                } catch (BadPartitionException e) {
                    System.out.print("Rewiring messed up because of partition problem.");
                }

                throw new RuntimeException("Rewiring messed up; investigate");
            }

        }

        return outTree;
    }

    //************************************************************************
    // Loggable implementation
    //************************************************************************

    /**
     * Builds one log column per case that was ever infected, reporting that case's infector.
     * <p>
     * Each column writes {@code "Start"} when the partitioned tree reports no infector for the
     * case, and otherwise the infector's {@code toString()}.
     *
     * @return the infector columns, one per ever-infected case, in outbreak order
     */
    public LogColumn[] getColumns() {
        // Collect into a list (as passColumns does) so the result does not depend on
        // outbreak.infectedSize() agreeing with the number of wasEverInfected() cases; a mismatch
        // would otherwise overflow a pre-sized array or leave null columns in it.
        List<LogColumn> columns = new ArrayList<LogColumn>();
        for (int i = 0; i < outbreak.size(); i++) {
            final AbstractCase infected = outbreak.getCase(i);
            if (infected.wasEverInfected()) {
                columns.add(new LogColumn.Abstract(infected.toString() + "_infector") {
                    protected String getFormattedValue() {
                        // Look the infector up once per write rather than twice.
                        AbstractCase infector = ((PartitionedTreeModel) treeModel).getInfector(infected);
                        return infector == null ? "Start" : infector.toString();
                    }
                });
            }
        }
        return columns.toArray(new LogColumn[columns.size()]);
    }

    /**
     * Builds the per-case date/period columns and the outbreak-wide summary-statistic columns.
     * <p>
     * Columns are emitted in this order:
     * <ol>
     * <li>each case's infection date;</li>
     * <li>(latent periods only) each case's infectious date, then each case's latent period;</li>
     * <li>each case's infectious period;</li>
     * <li>(latent periods only) each case's infected period (infectious + latent);</li>
     * <li>infectious-period mean, median, var, stdev;</li>
     * <li>(latent periods only) latent-period and infected-period mean, median, var, stdev,
     * followed by each case's infection branch position ({@code _ibp}).</li>
     * </ol>
     * Per-case columns cover only cases that were ever infected, in outbreak order. The summary
     * statistics are read from indices 0-3 of {@code getSummaryStatistics}.
     *
     * @return the columns, in the order described above
     */
    public LogColumn[] passColumns() {
        // Cases that were ever infected, in outbreak order; every per-case column uses these.
        List<AbstractCase> everInfected = new ArrayList<AbstractCase>();
        for (int i = 0; i < outbreak.size(); i++) {
            AbstractCase candidate = outbreak.getCase(i);
            if (candidate.wasEverInfected()) {
                everInfected.add(candidate);
            }
        }

        // Label suffixes for indices 0..3 of the array returned by getSummaryStatistics.
        String[] statNames = {"mean", "median", "var", "stdev"};

        List<LogColumn> out = new ArrayList<LogColumn>();

        for (final AbstractCase infectedCase : everInfected) {
            out.add(new LogColumn.Abstract(infectedCase.toString() + "_infection_date") {
                protected String getFormattedValue() {
                    return String.valueOf(getInfectionTime(infectedCase));
                }
            });
        }

        if (hasLatentPeriods) {
            for (final AbstractCase infectedCase : everInfected) {
                out.add(new LogColumn.Abstract(infectedCase.toString() + "_infectious_date") {
                    protected String getFormattedValue() {
                        return String.valueOf(getInfectiousTime(infectedCase));
                    }
                });
            }
            for (final AbstractCase infectedCase : everInfected) {
                out.add(new LogColumn.Abstract(infectedCase.toString() + "_latent_period") {
                    protected String getFormattedValue() {
                        return String.valueOf(getLatentPeriod(infectedCase));
                    }
                });
            }
        }

        for (final AbstractCase infectedCase : everInfected) {
            out.add(new LogColumn.Abstract(infectedCase.toString() + "_infectious_period") {
                protected String getFormattedValue() {
                    return String.valueOf(getInfectiousPeriod(infectedCase));
                }
            });
        }

        if (hasLatentPeriods) {
            for (final AbstractCase infectedCase : everInfected) {
                out.add(new LogColumn.Abstract(infectedCase.toString() + "_infected_period") {
                    protected String getFormattedValue() {
                        return String.valueOf(getInfectiousPeriod(infectedCase) + getLatentPeriod(infectedCase));
                    }
                });
            }
        }

        for (int s = 0; s < statNames.length; s++) {
            final int statIndex = s;
            out.add(new LogColumn.Abstract("infectious_period." + statNames[s]) {
                protected String getFormattedValue() {
                    return String.valueOf(
                            CaseToCaseTreeLikelihood.getSummaryStatistics(getNonzeroInfectiousPeriods())[statIndex]);
                }
            });
        }

        if (hasLatentPeriods) {
            for (int s = 0; s < statNames.length; s++) {
                final int statIndex = s;
                out.add(new LogColumn.Abstract("latent_period." + statNames[s]) {
                    protected String getFormattedValue() {
                        return String.valueOf(
                                CaseToCaseTreeLikelihood.getSummaryStatistics(getNonzeroLatentPeriods())[statIndex]);
                    }
                });
            }
            for (int s = 0; s < statNames.length; s++) {
                final int statIndex = s;
                out.add(new LogColumn.Abstract("infected_period." + statNames[s]) {
                    protected String getFormattedValue() {
                        return String.valueOf(
                                CaseToCaseTreeLikelihood.getSummaryStatistics(getNonzeroInfectedPeriods())[statIndex]);
                    }
                });
            }
            // NOTE(review): infection branch positions are only logged when hasLatentPeriods is
            // set; confirm this gating is intended.
            for (final AbstractCase infectedCase : everInfected) {
                out.add(new LogColumn.Abstract(infectedCase.toString() + "_ibp") {
                    protected String getFormattedValue() {
                        return String.valueOf(infectedCase.getInfectionBranchPosition().getParameterValue(0));
                    }
                });
            }
        }

        return out.toArray(new LogColumn[out.size()]);
    }

    /**
     * Reports the citation category this likelihood belongs to.
     *
     * @return {@code Citation.Category.TREE_PRIORS}
     */
    @Override
    public Citation.Category getCategory() {
        final Citation.Category category = Citation.Category.TREE_PRIORS;
        return category;
    }

    /**
     * Returns a short human-readable name for this model.
     *
     * @return the fixed description string
     */
    @Override
    public String getDescription() {
        final String description = "Case to Case Transmission Tree model";
        return description;
    }

    /**
     * Lists the publication describing this model.
     *
     * @return a fixed-size list holding the single Hall, Woolhouse and Rambaut citation
     */
    public List<Citation> getCitations() {
        Author[] authors = {
                new Author("M", "Hall"),
                new Author("M", "Woolhouse"),
                new Author("A", "Rambaut")
        };
        // NOTE(review): the year/volume pairing (2016, vol. 11) is not verified here; confirm it
        // against the DOI record.
        Citation paper = new Citation(authors,
                "Epidemic Reconstruction in a Phylogenetics Framework: Transmission Trees as Partitions of the Node Set",
                2016, "PLOS Comput Biol", 11, 0, 0, "10.1371/journal.pcbi.1004613", Citation.Status.PUBLISHED);
        return Arrays.asList(paper);
    }

    // **************************************************************
    // TreeTraitProvider IMPLEMENTATION
    // **************************************************************

    /**
     * Returns every tree trait exposed by this provider.
     *
     * @return the traits held by the {@code treeTraits} helper
     */
    public TreeTrait[] getTreeTraits() {
        TreeTrait[] all = this.treeTraits.getTreeTraits();
        return all;
    }

    /**
     * Looks up a single tree trait by name.
     *
     * @param key the trait name
     * @return whatever the {@code treeTraits} helper returns for {@code key}
     */
    public TreeTrait getTreeTrait(String key) {
        TreeTrait match = this.treeTraits.getTreeTrait(key);
        return match;
    }

    /**
     * Returns the label of the partition (case) that a node belongs to.
     * <p>
     * If {@code tree} is the {@code treeModel} itself, the label comes directly from the branch
     * map. Otherwise {@code tree} is expected to be derived from {@code treeModel}, with its
     * original nodes carrying a {@code "Number"} attribute that holds their node number in
     * {@code treeModel}:
     * <ul>
     * <li>A node with a {@code "Number"} attribute is mapped back to its original node. That node's
     * height must match exactly, and the node takes that node's partition.</li>
     * <li>A node without one is labelled {@code "Start"} if it is the root. Otherwise it takes the
     * partition of its parent's original node.</li>
     * </ul>
     *
     * @param tree the tree containing {@code node}
     * @param node the node to label
     * @return the partition label
     * @throws RuntimeException if a numbered node's height differs from its original node's height
     */
    public String getNodePartition(Tree tree, NodeRef node) {
        if (tree == treeModel) {
            return getBranchMap().get(node.getNumber()).toString();
        }

        // We're trying to annotate a partitioned tree derived from treeModel, we hope.
        Integer originalNumber = (Integer) tree.getNodeAttribute(node, "Number");
        if (originalNumber == null) {
            // The node has no counterpart in treeModel. This used to be detected by catching the
            // NullPointerException from unboxing. The explicit check stops unrelated NPEs (e.g. a
            // missing branch-map entry) from being silently reinterpreted as this case.
            if (tree.isRoot(node)) {
                return "Start";
            }
            NodeRef parent = tree.getParent(node);
            int originalParentNumber = (Integer) tree.getNodeAttribute(parent, "Number");
            return getBranchMap().get(originalParentNumber).toString();
        }

        NodeRef oldNode = treeModel.getNode(originalNumber);
        if (treeModel.getNodeHeight(oldNode) != tree.getNodeHeight(node)) {
            throw new RuntimeException("Can only reconstruct states on treeModel given to constructor or a "
                    + "partitioned tree derived from it");
        }
        return getBranchMap().get(oldNode.getNumber()).toString();
    }

    /**
     * Encodes the transmission tree as an array of parent (infector) indices over the outbreak.
     * <p>
     * The entry at each case's {@code outbreak.getCaseIndex} slot holds the index of that case's
     * infector. For a case that was never infected, the entry is {@code null}.
     *
     * @return an array of length {@code outbreak.size()}
     */
    public Integer[] getParentsArray() {
        Integer[] parents = new Integer[outbreak.size()];
        for (AbstractCase child : outbreak.getCases()) {
            int slot = outbreak.getCaseIndex(child);
            // NOTE(review): getInfector can return null (getColumns logs "Start" in that case);
            // what getCaseIndex does with a null argument is not visible here. Confirm.
            parents[slot] = child.wasEverInfected()
                    ? outbreak.getCaseIndex(((PartitionedTreeModel) treeModel).getInfector(child))
                    : null;
        }
        return parents;
    }

    /**
     * Returns the infector of the outbreak's {@code i}-th case.
     *
     * @param i index of the case within the outbreak
     * @return the infector reported by the partitioned tree model; this may be {@code null}
     *         (callers such as the infector log columns treat null as "Start")
     */
    public AbstractCase getInfector(int i) {
        PartitionedTreeModel partitionedTree = (PartitionedTreeModel) treeModel;
        AbstractCase target = getOutbreak().getCase(i);
        return partitionedTree.getInfector(target);
    }

}