multitypetree.distributions.StructuredCoalescentTreeDensity.java Source code

Java tutorial

Introduction

Here is the source code for multitypetree.distributions.StructuredCoalescentTreeDensity.java

Source

/*
 * Copyright (C) 2012 Tim Vaughan
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package multitypetree.distributions;

import beast.core.*;
import beast.core.Input.Validate;
import beast.core.parameter.RealParameter;
import beast.evolution.tree.MigrationModel;
import beast.evolution.tree.MultiTypeNode;
import beast.evolution.tree.MultiTypeTree;
import beast.evolution.tree.MultiTypeTreeFromNewick;
import beast.evolution.tree.Node;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 *
 * @author Tim Vaughan
 */
@Description("Likelihood of ColouredTree under structured coalescent.")
public class StructuredCoalescentTreeDensity extends MultiTypeTreeDistribution {

    public Input<MigrationModel> migrationModelInput = new Input<>("migrationModel",
            "Model of migration between demes.", Validate.REQUIRED);
    public Input<Boolean> checkValidityInput = new Input<>("checkValidity",
            "Explicitly check validity of colouring.  " + "(Default false.)  "
                    + "Useful if operators are in danger of proposing invalid trees.",
            false);

    protected MigrationModel migrationModel;
    protected MultiTypeTree mtTree;
    protected boolean checkValidity;

    protected enum SCEventKind {
        COALESCE, MIGRATE, SAMPLE
    }

    protected class SCEvent {
        double time;
        int type, destType;
        SCEventKind kind;
        Node node;
    }

    protected List<SCEvent> eventList;
    protected List<Integer[]> lineageCountList;

    // Empty constructor as required:
    public StructuredCoalescentTreeDensity() {

        super();

    }

    @Override
    public void initAndValidate() {

        migrationModel = migrationModelInput.get();
        mtTree = mtTreeInput.get();
        checkValidity = checkValidityInput.get();

        eventList = new ArrayList<>();
        lineageCountList = new ArrayList<>();
    }

    @Override
    public double calculateLogP() {

        // Check validity of tree if required:
        if (checkValidity && !mtTree.isValid())
            return Double.NEGATIVE_INFINITY;

        // Ensure sequence of events is up-to-date:
        updateEventSequence();

        // Start from the tips of the tree, working up.
        logP = 0;

        // Note that the first event is always a sample. We begin at the first
        // _interval_ and the event following that interval.
        for (int eventIdx = 1; eventIdx < eventList.size(); eventIdx++) {

            SCEvent event = eventList.get(eventIdx);
            Integer[] lineageCount = lineageCountList.get(eventIdx);
            double delta_t = event.time - eventList.get(eventIdx - 1).time;

            // Interval contribution:
            if (delta_t > 0) {
                double lambda = 0.0;
                for (int c = 0; c < lineageCount.length; c++) {
                    int k = lineageCount[c];
                    double Nc = migrationModel.getPopSize(c);
                    lambda += k * (k - 1) / (2.0 * Nc);

                    for (int cp = 0; cp < lineageCount.length; cp++) {
                        if (cp == c)
                            continue;

                        double m = migrationModel.getRate(c, cp);
                        lambda += k * m;
                    }
                }
                logP += -delta_t * lambda;
            }

            // Event contribution:
            switch (event.kind) {
            case COALESCE:
                double N = migrationModel.getPopSize(event.type);
                logP += Math.log(1.0 / N);
                break;

            case MIGRATE:
                double m = migrationModel.getRate(event.type, event.destType);
                logP += Math.log(m);
                break;

            case SAMPLE:
                // Do nothing here: only effect of sampling event is
                // to change the lineage counts in subsequent intervals.
                break;
            }
        }

        return logP;
    }

    /**
     * Determines the sequence of migration, coalescence and sampling events
     * which make up the coloured tree.
     */
    public void updateEventSequence() {

        // Clean up previous list:
        eventList.clear();
        lineageCountList.clear();
        Node rootNode = mtTree.getRoot();

        // Initialise map of active nodes to active change indices:
        Map<Node, Integer> changeIdx = new HashMap<>();
        changeIdx.put(rootNode, -1);

        // Initialise lineage count per colour array:
        Integer[] lineageCount = new Integer[migrationModel.getNTypes()];
        for (int c = 0; c < migrationModel.getNTypes(); c++)
            if (c == ((MultiTypeNode) rootNode).getNodeType())
                lineageCount[c] = 1;
            else
                lineageCount[c] = 0;

        // Calculate event sequence:
        while (!changeIdx.isEmpty()) {

            SCEvent nextEvent = new SCEvent();
            nextEvent.time = Double.NEGATIVE_INFINITY;
            nextEvent.node = rootNode; // Initial assignment not significant

            // Determine next event
            for (Node node : changeIdx.keySet())
                if (changeIdx.get(node) < 0) {
                    if (node.isLeaf()) {
                        // Next event is a sample
                        if (node.getHeight() > nextEvent.time) {
                            nextEvent.time = node.getHeight();
                            nextEvent.kind = SCEventKind.SAMPLE;
                            nextEvent.type = ((MultiTypeNode) node).getNodeType();
                            nextEvent.node = node;
                        }
                    } else {
                        // Next event is a coalescence
                        if (node.getHeight() > nextEvent.time) {
                            nextEvent.time = node.getHeight();
                            nextEvent.kind = SCEventKind.COALESCE;
                            nextEvent.type = ((MultiTypeNode) node).getNodeType();
                            nextEvent.node = node;
                        }
                    }
                } else {
                    // Next event is a migration
                    double thisChangeTime = ((MultiTypeNode) node).getChangeTime(changeIdx.get(node));
                    if (thisChangeTime > nextEvent.time) {
                        nextEvent.time = thisChangeTime;
                        nextEvent.kind = SCEventKind.MIGRATE;
                        nextEvent.destType = ((MultiTypeNode) node).getChangeType(changeIdx.get(node));
                        if (changeIdx.get(node) > 0)
                            nextEvent.type = ((MultiTypeNode) node).getChangeType(changeIdx.get(node) - 1);
                        else
                            nextEvent.type = ((MultiTypeNode) node).getNodeType();
                        nextEvent.node = node;
                    }
                }

            // Update active node list (changeIdx) and lineage count appropriately:
            switch (nextEvent.kind) {
            case COALESCE:
                Node leftChild = nextEvent.node.getLeft();
                Node rightChild = nextEvent.node.getRight();

                changeIdx.remove(nextEvent.node);
                changeIdx.put(leftChild, ((MultiTypeNode) leftChild).getChangeCount() - 1);
                changeIdx.put(rightChild, ((MultiTypeNode) rightChild).getChangeCount() - 1);
                lineageCount[nextEvent.type]++;
                break;

            case SAMPLE:
                changeIdx.remove(nextEvent.node);
                lineageCount[nextEvent.type]--;
                break;

            case MIGRATE:
                lineageCount[nextEvent.destType]--;
                lineageCount[nextEvent.type]++;
                int oldIdx = changeIdx.get(nextEvent.node);
                changeIdx.put(nextEvent.node, oldIdx - 1);
                break;
            }

            // Add event to list:
            eventList.add(nextEvent);
            lineageCountList.add(Arrays.copyOf(lineageCount, lineageCount.length));
        }

        // Reverse event and lineage count lists (order them from tips to root):
        eventList = Lists.reverse(eventList);
        lineageCountList = Lists.reverse(lineageCountList);

    }

    @Override
    public boolean requiresRecalculation() {
        return true;
    }

    /**
     * Test likelihood result. Duplicate of JUnit test for debugging purposes.
     *
     * @param argv
     * @throws java.lang.Exception
     */
    public static void main(String[] argv) throws Exception {

        // Assemble test MultiTypeTree:
        String newickStr = "(((A[state=1]:0.25)[state=0]:0.25,B[state=0]:0.5)[state=0]:1.5,"
                + "(C[state=0]:1.0,D[state=0]:1.0)[state=0]:1.0)[state=0]:0.0;";

        MultiTypeTreeFromNewick mtTree = new MultiTypeTreeFromNewick();
        mtTree.initByName("newick", newickStr, "typeLabel", "state", "nTypes", 2);

        // Assemble migration model:
        RealParameter rateMatrix = new RealParameter();
        rateMatrix.initByName("value", "2.0 1.0");
        RealParameter popSizes = new RealParameter();
        popSizes.initByName("value", "5.0 10.0");
        MigrationModel migrationModel = new MigrationModel();
        migrationModel.initByName("rateMatrix", rateMatrix, "popSizes", popSizes);

        // Set up likelihood instance:
        StructuredCoalescentTreeDensity likelihood = new StructuredCoalescentTreeDensity();
        likelihood.initByName("migrationModel", migrationModel, "multiTypeTree", mtTree);

        double expResult = -16.52831; // Calculated by hand
        double result = likelihood.calculateLogP();

        System.out.println("Expected result: " + expResult);
        System.out.println("Actual result: " + result);
        System.out.println("Difference: " + String.valueOf(result - expResult));

    }
}