org.drugis.mtc.jags.JagsSyntaxModel.java Source code

Java tutorial

Introduction

Here is the source code for org.drugis.mtc.jags.JagsSyntaxModel.java

Source

/*
 * This file is part of the GeMTC software for MTC model generation and
 * analysis. GeMTC is distributed from http://drugis.org/gemtc.
 * Copyright (C) 2009-2012 Gert van Valkenhoef.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package org.drugis.mtc.jags;

import java.io.IOException;
import java.io.InputStream;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.text.Format;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

import org.apache.commons.collections15.CollectionUtils;
import org.apache.commons.collections15.PredicateUtils;
import org.apache.commons.collections15.Transformer;
import org.apache.commons.lang.StringUtils;
import org.drugis.common.CollectionHelper;
import org.drugis.mtc.data.DataType;
import org.drugis.mtc.model.Measurement;
import org.drugis.mtc.model.Network;
import org.drugis.mtc.model.Study;
import org.drugis.mtc.model.Treatment;
import org.drugis.mtc.parameterization.BasicParameter;
import org.drugis.mtc.parameterization.InconsistencyParameter;
import org.drugis.mtc.parameterization.InconsistencyParameterization;
import org.drugis.mtc.parameterization.InconsistencyStartingValueGenerator;
import org.drugis.mtc.parameterization.NetworkModel;
import org.drugis.mtc.parameterization.NetworkParameter;
import org.drugis.mtc.parameterization.NodeSplitParameterization;
import org.drugis.mtc.parameterization.ParameterComparator;
import org.drugis.mtc.parameterization.Parameterization;
import org.drugis.mtc.parameterization.PriorGenerator;
import org.drugis.mtc.parameterization.SplitParameter;
import org.drugis.mtc.parameterization.StartingValueGenerator;
import org.mvel2.templates.CompiledTemplate;
import org.mvel2.templates.TemplateCompiler;
import org.mvel2.templates.TemplateRuntime;

import com.jgoodies.binding.list.ObservableList;

import edu.uci.ics.jung.graph.util.Pair;

public class JagsSyntaxModel {
    /**
     * Formatter for the numeric constants that are embedded in the generated model text.
     * The symbols are pinned to {@link Locale#ROOT} so that the decimal separator is always "." and the
     * exponent marker is always "E", independent of the JVM default locale; rewriteNumber() relies on
     * finding the "E" marker in the formatted text.
     * NOTE(review): DecimalFormat is not synchronized; this shared instance assumes single-threaded use.
     */
    private static final Format s_format = new DecimalFormat("0.0##E0", new DecimalFormatSymbols(Locale.ROOT));

    /** Writes a parameter as its plain name. */
    private static final Transformer<NetworkParameter, String> s_idTrans = new Transformer<NetworkParameter, String>() {
        public String transform(NetworkParameter input) {
            return input.getName();
        }
    };

    /** Writes a parameter as an expression selecting the column named after it from {@code x}. */
    private static final Transformer<NetworkParameter, String> s_rTransform = new Transformer<NetworkParameter, String>() {
        public String transform(NetworkParameter input) {
            return "x[, \"" + input.getName() + "\"]";
        }
    };

    private final Network d_network;
    private final Parameterization d_pmtz;
    private final PriorGenerator d_priorGen;
    private final boolean d_isJags;
    private final boolean d_inconsistency;
    private final boolean d_nodeSplit;

    /**
     * Create a syntax model for the given network and parameterization.
     * @param network The evidence network to generate text for.
     * @param pmtz The parameterization of the network; its concrete type decides whether the
     * inconsistency and/or node-split specific parts are generated.
     * @param isJags true to generate JAGS syntax, false to generate BUGS syntax.
     */
    public JagsSyntaxModel(Network network, Parameterization pmtz, boolean isJags) {
        this.d_isJags = isJags;
        this.d_network = network;
        this.d_priorGen = new PriorGenerator(network);
        this.d_pmtz = pmtz;
        this.d_nodeSplit = pmtz instanceof NodeSplitParameterization;
        this.d_inconsistency = pmtz instanceof InconsistencyParameterization;
    }

    /**
     * Adapt a number written in scientific E-notation to the target language.
     * @param s The number in E-notation.
     * @return For JAGS, s with its first "E" replaced by "*10^"; for BUGS, s unchanged.
     */
    private String rewriteNumber(String s) {
        if (!d_isJags) {
            return s;
        }
        return s.replaceFirst("E", "*10^");
    }

    /**
     * Render a list of (name, value) assignments as a data file.
     * For JAGS each assignment is written as "name <- value" on its own line; for BUGS the
     * assignments are written as "name = value", comma separated, inside a "list(...)".
     * @param assignments The (name, value) pairs, in output order.
     * @return The text of the data file.
     */
    private String generateDataFile(List<Pair<String>> assignments) {
        final String operator;
        final String separator;
        final String header;
        final String footer;
        if (d_isJags) {
            operator = " <- ";
            separator = "\n";
            header = "";
            footer = "\n";
        } else {
            operator = " = ";
            separator = ",\n";
            header = "list(\n";
            footer = "\n)\n";
        }

        final List<String> rendered = new ArrayList<String>(assignments.size());
        for (Pair<String> assignment : assignments) {
            rendered.add(assignment.getFirst() + operator + assignment.getSecond());
        }

        return header + StringUtils.join(rendered, separator) + footer;
    }

    /**
     * Generate the initial values file for one chain.
     * The file lists, in this order: the network parameters, the baseline effects ("mu"), the
     * relative effects ("delta") and the standard deviation parameter(s).
     * @param generator Source of the starting values.
     * @return The text of the initial values file.
     */
    public String initialValuesText(StartingValueGenerator generator) {
        final List<Pair<String>> assignments = new ArrayList<Pair<String>>(initMetaParameters(generator));
        assignments.addAll(initBaselineEffects(generator));
        assignments.addAll(initRelativeEffects(generator));
        assignments.addAll(initVarianceParameters(generator));
        return generateDataFile(assignments);
    }

    /**
     * Generate the R analysis script: a "deriv" list of derivation functions followed by
     * commented-out example commands that refer to the given file prefix.
     * @param prefix The file prefix to insert in the example read.mtc() call.
     * @return The text of the script.
     */
    public String analysisText(String prefix) {
        final String derivations = getDerivations();
        final String readExample = "# data <- append.derived(read.mtc('" + prefix + "'), deriv)\n";

        final List<String> script = new ArrayList<String>(5);
        script.add("deriv <- list(");
        script.add(derivations);
        script.add("\t)");
        script.add("# source('mtc.R')");
        script.add(readExample);

        return StringUtils.join(script, "\n");
    }

    /**
     * Generate starting values for every parameter of the parameterization.
     * Starting values of basic parameters are remembered while iterating, and passed on when the
     * starting value of an inconsistency parameter is generated.
     * @param generator Source of the starting values.
     * @return One (parameter name, value) pair per parameter, in parameter order.
     * @throws IllegalStateException if a parameter is of an unsupported type.
     */
    private List<Pair<String>> initMetaParameters(StartingValueGenerator generator) {
        final List<Pair<String>> assignments = new ArrayList<Pair<String>>();
        final Map<BasicParameter, Double> knownBasics = new HashMap<BasicParameter, Double>();
        for (NetworkParameter param : d_pmtz.getParameters()) {
            final double start;
            if (param instanceof BasicParameter) {
                final BasicParameter basic = (BasicParameter) param;
                start = generator.getRelativeEffect(basic);
                knownBasics.put(basic, start);
            } else if (param instanceof SplitParameter) {
                // A split parameter starts at the value of the basic parameter for the same comparison.
                final SplitParameter split = (SplitParameter) param;
                start = generator.getRelativeEffect(new BasicParameter(split.getBaseline(), split.getSubject()));
            } else if (param instanceof InconsistencyParameter) {
                start = InconsistencyStartingValueGenerator.generate((InconsistencyParameter) param,
                        (InconsistencyParameterization) d_pmtz, generator, knownBasics);
            } else {
                throw new IllegalStateException(
                        "Unhandled parameter " + param + " of type " + param.getClass().getCanonicalName());
            }
            assignments.add(new Pair<String>(param.getName(), Double.toString(start)));
        }
        return assignments;
    }

    /**
     * Generate starting values for the effect of the baseline treatment of each study.
     * @param generator Source of the starting values.
     * @return A single ("mu", vector) pair, with one vector element per study in study order.
     */
    private List<Pair<String>> initBaselineEffects(StartingValueGenerator generator) {
        final ObservableList<Study> studies = d_network.getStudies();
        final Double[] mu = new Double[studies.size()];
        int index = 0;
        for (Study study : studies) {
            mu[index++] = generator.getTreatmentEffect(study, d_pmtz.getStudyBaseline(study));
        }
        return Collections.singletonList(new Pair<String>("mu", writeVector(mu, d_isJags)));
    }

    /**
     * Generate starting values for the relative effect of each non-baseline arm versus the
     * baseline arm of the same study.
     * @param generator Source of the starting values.
     * @return A single ("delta", matrix) pair; baseline arms and missing arms are left null.
     */
    private List<Pair<String>> initRelativeEffects(final StartingValueGenerator generator) {
        final StudyTreatmentTransformer<Double> startingDelta = new StudyTreatmentTransformer<Double>() {
            public Double transform(final Study s, final Treatment t) {
                final Treatment baseline = d_pmtz.getStudyBaseline(s);
                if (!baseline.equals(t)) {
                    return generator.getRelativeEffect(s, new BasicParameter(baseline, t));
                }
                // No relative effect for the baseline arm itself.
                return null;
            }
        };
        final Double[][] delta = getDoubleMatrix(startingDelta);
        return Collections.singletonList(new Pair<String>("delta", writeMatrix(delta, d_isJags)));
    }

    /**
     * Generate starting values for the standard deviation parameters.
     * @param generator Source of the starting values.
     * @return A pair for "sd.d", followed by a pair for "sd.w" in case of an inconsistency model.
     */
    private List<Pair<String>> initVarianceParameters(StartingValueGenerator generator) {
        final List<Pair<String>> assignments = new ArrayList<Pair<String>>(2);
        final String randomEffectsSd = String.valueOf(generator.getStandardDeviation());
        assignments.add(new Pair<String>("sd.d", randomEffectsSd));
        if (!d_inconsistency) {
            return assignments;
        }
        // Ask the generator again rather than re-using the value generated for sd.d.
        final String inconsistencySd = String.valueOf(generator.getStandardDeviation());
        assignments.add(new Pair<String>("sd.w", inconsistencySd));
        return assignments;
    }

    /**
     * Generate the R functions that derive relative effects which are not themselves parameters.
     * A treatment pair is skipped when the basic parameter for it, in either direction, is part
     * of the parameterization.
     * @return The function definitions, one per line, comma separated.
     */
    private String getDerivations() {
        final Transformer<Pair<Treatment>, String> derive = new Transformer<Pair<Treatment>, String>() {
            public String transform(Pair<Treatment> input) {
                final Treatment from = input.getFirst();
                final Treatment to = input.getSecond();
                final BasicParameter forward = new BasicParameter(from, to);
                final BasicParameter backward = new BasicParameter(to, from);
                if (d_pmtz.getParameters().contains(forward) || d_pmtz.getParameters().contains(backward)) {
                    // Already a parameter: nothing to derive; removed from the result below.
                    return null;
                }
                final String expression = expressRelativeEffect(from, to, s_rTransform);
                return "\t`" + forward + "` = function(x) { " + expression + " }";
            }
        };
        final Collection<String> definitions = NetworkModel.transformTreatmentPairs(d_network, derive);
        CollectionUtils.filter(definitions, PredicateUtils.notNullPredicate());
        return StringUtils.join(definitions, ",\n");
    }

    /**
     * Read and compile a template from a class path resource.
     * The resource stream is closed once the template has been compiled (or compilation failed).
     * @param path The resource path, resolved relative to this class.
     * @return The compiled template.
     */
    public CompiledTemplate readTemplate(String path) {
        // A missing resource yields a null stream, which is passed on to the compiler as before.
        final InputStream stream = getClass().getResourceAsStream(path);
        try {
            return TemplateCompiler.compileTemplate(stream);
        } finally {
            if (stream != null) {
                try {
                    stream.close();
                } catch (IOException ignored) {
                    // The outcome of the compilation is already decided at this point;
                    // a failure to close the stream must not mask it.
                }
            }
        }
    }

    /**
     * Generate the model definition by executing "modelTemplate.txt".
     * @return The text of the model file.
     */
    public String modelText() {
        final CompiledTemplate template = readTemplate("modelTemplate.txt");
        final Map<String, Object> context = new HashMap<String, Object>();

        context.put("dichotomous", d_network.getType().equals(DataType.RATE));
        context.put("nodeSplit", d_nodeSplit);
        if (d_nodeSplit) {
            context.put("indirectNode", getIndirectEvidenceExpression());
        }
        context.put("inconsistency", d_inconsistency);
        context.put("relativeEffectMatrix", getRelativeEffectMatrix());

        // The template is given a precision (1 / sigma^2) rather than a standard deviation.
        final double sigma = d_priorGen.getVagueNormalSigma();
        final String precision = rewriteNumber(s_format.format(1 / (sigma * sigma)));
        context.put("priorPrecision", precision);
        final String upperLimit = rewriteNumber(s_format.format(d_priorGen.getRandomEffectsSigma()));
        context.put("stdDevUpperLimit", upperLimit);

        context.put("parameters", d_pmtz.getParameters());
        context.put("inconsClass", InconsistencyParameter.class);
        return String.valueOf(TemplateRuntime.execute(template, context));
    }

    /**
     * Generate the assignment that defines the indirect parameter of a node-split model.
     * Must only be called when the parameterization is a NodeSplitParameterization.
     * @return "[indirect parameter] <- [expression]".
     */
    private String getIndirectEvidenceExpression() {
        final NodeSplitParameterization nodeSplit = (NodeSplitParameterization) d_pmtz;
        final String target = nodeSplit.getIndirectParameter().toString();
        final String expression = writeExpression(nodeSplit.parameterizeIndirect(), s_idTrans);
        return target + " <- " + expression;
    }

    /**
     * Express the relative effect of t2 versus t1 in terms of the parameters.
     * @param t1 The first treatment.
     * @param t2 The second treatment.
     * @param transform How to write each parameter.
     * @return "0" if the treatments are equal, otherwise the expression for the parameterization of (t1, t2).
     */
    private String expressRelativeEffect(Treatment t1, Treatment t2,
            Transformer<NetworkParameter, String> transform) {
        return t1.equals(t2) ? "0" : writeExpression(d_pmtz.parameterize(t1, t2), transform);
    }

    /**
     * Generate the definition of every cell d[i,j] of the relative effect matrix, using 1-based
     * treatment indices, iterating over j fastest.
     * @return One "d[i,j] <- expression" line per ordered treatment pair (including i = j).
     */
    private String getRelativeEffectMatrix() {
        final ObservableList<Treatment> treatments = d_network.getTreatments();
        final StringBuilder text = new StringBuilder();
        for (int row = 1; row <= treatments.size(); ++row) {
            for (int col = 1; col <= treatments.size(); ++col) {
                if (text.length() > 0) {
                    text.append("\n");
                }
                final String effect = expressRelativeEffect(treatments.get(row - 1), treatments.get(col - 1), s_idTrans);
                text.append("\td[").append(row).append(",").append(col).append("] <- ").append(effect);
            }
        }
        return text.toString();
    }

    /**
     * Generate the run script by executing the JAGS or BUGS script template.
     * @param prefix Passed to the template as "prefix".
     * @param nchains The number of chains; the template also receives the chain numbers 1..nchains.
     * @param tuning Passed to the template as "tuning".
     * @param simulation Passed to the template as "simulation".
     * @return The text of the script.
     */
    public String scriptText(String prefix, int nchains, int tuning, int simulation) {
        final String templatePath = d_isJags ? "jagsScriptTemplate.txt" : "bugsScriptTemplate.txt";
        final CompiledTemplate template = readTemplate(templatePath);

        final List<Integer> chainNumbers = new ArrayList<Integer>();
        for (int chain = 1; chain <= nchains; ++chain) {
            chainNumbers.add(chain);
        }

        // In a node-split model, the indirect parameter is listed after the regular parameters.
        final List<NetworkParameter> monitored = new ArrayList<NetworkParameter>(d_pmtz.getParameters());
        if (d_nodeSplit) {
            monitored.add(((NodeSplitParameterization) d_pmtz).getIndirectParameter());
        }

        final Map<String, Object> context = new HashMap<String, Object>();
        context.put("prefix", prefix);
        context.put("nchains", nchains);
        context.put("chains", chainNumbers);
        context.put("tuning", tuning);
        context.put("simulation", simulation);
        context.put("inconsistency", d_inconsistency);
        context.put("parameters", monitored);
        return String.valueOf(TemplateRuntime.execute(template, context));
    }

    /**
     * Generate the data file.
     * The file lists "split" (node-split models only), "ns", "na" and "t", followed by "r" and "n"
     * for RATE data or "m" and "e" for CONTINUOUS data.
     * @return The text of the data file.
     * @throws IllegalArgumentException if the data type of the network is neither RATE nor CONTINUOUS.
     */
    public String dataText() {
        List<Pair<String>> list = new ArrayList<Pair<String>>();
        if (d_nodeSplit) {
            final NodeSplitParameterization nodeSplit = (NodeSplitParameterization) d_pmtz;
            final ObservableList<Treatment> treatments = d_network.getTreatments();
            Treatment baseline = nodeSplit.getDirectParameter().getBaseline();
            Treatment subject = nodeSplit.getDirectParameter().getSubject();
            // Treatments are referred to by their 1-based position in the network's treatment list.
            Integer[] split = new Integer[] { treatments.indexOf(baseline) + 1, treatments.indexOf(subject) + 1 };
            list.add(new Pair<String>("split", writeVector(split, d_isJags)));
        }
        list.add(new Pair<String>("ns", writeNumber(d_network.getStudies().size(), d_isJags)));
        list.add(new Pair<String>("na", writeVector(getArmCounts(), d_isJags)));
        list.add(new Pair<String>("t", writeMatrix(getTreatmentMatrix(), d_isJags)));

        switch (d_network.getType()) {
        case RATE:
            list.add(new Pair<String>("r", writeMatrix(getResponderMatrix(), d_isJags)));
            list.add(new Pair<String>("n", writeMatrix(getSampleSizeMatrix(), d_isJags)));
            break;
        case CONTINUOUS:
            list.add(new Pair<String>("m", writeMatrix(getMeanMatrix(), d_isJags)));
            list.add(new Pair<String>("e", writeMatrix(getStdErrMatrix(), d_isJags)));
            break;
        default:
            // This method writes the data file, so the message says so (it used to mention starting values).
            throw new IllegalArgumentException(
                    "Don't know how to generate a data file for " + d_network.getType() + " data");
        }

        return generateDataFile(list);
    }

    /**
     * Count the measurements (arms) of each study.
     * @return The number of measurements per study, in study order.
     */
    private Integer[] getArmCounts() {
        final ObservableList<Study> studies = d_network.getStudies();
        final Integer[] armCounts = new Integer[studies.size()];
        int index = 0;
        for (Study study : studies) {
            armCounts[index++] = study.getMeasurements().size();
        }
        return armCounts;
    }

    /**
     * Determine the largest number of measurements (arms) in any single study.
     * @return The maximum arm count, or 0 if the network has no studies.
     */
    public int getMaxArmCount() {
        int widest = 0;
        for (Study study : d_network.getStudies()) {
            final int arms = study.getMeasurements().size();
            if (arms > widest) {
                widest = arms;
            }
        }
        return widest;
    }

    /**
     * Build the (study, arm) matrix of treatment numbers, where a treatment's number is its
     * 1-based position in the network's treatment list.
     * @return The matrix, containing null for missing arms.
     */
    public Integer[][] getTreatmentMatrix() {
        final StudyTreatmentTransformer<Integer> treatmentNumber = new StudyTreatmentTransformer<Integer>() {
            public Integer transform(Study study, Treatment treatment) {
                final int zeroBased = d_network.getTreatments().indexOf(treatment);
                return zeroBased + 1;
            }
        };
        return getIntegerMatrix(treatmentNumber);
    }

    /**
     * Build the (study, arm) matrix of responder counts.
     * @return The matrix, containing null for missing arms.
     */
    public Integer[][] getResponderMatrix() {
        final StudyTreatmentTransformer<Integer> responders = new StudyTreatmentTransformer<Integer>() {
            public Integer transform(Study study, Treatment treatment) {
                final Measurement arm = NetworkModel.findMeasurement(study, treatment);
                return arm.getResponders();
            }
        };
        return getIntegerMatrix(responders);
    }

    /**
     * Build the (study, arm) matrix of means.
     * @return The matrix, containing null for missing arms.
     */
    public Double[][] getMeanMatrix() {
        final StudyTreatmentTransformer<Double> mean = new StudyTreatmentTransformer<Double>() {
            public Double transform(Study study, Treatment treatment) {
                final Measurement arm = NetworkModel.findMeasurement(study, treatment);
                return arm.getMean();
            }
        };
        return getDoubleMatrix(mean);
    }

    /**
     * Build the (study, arm) matrix of standard errors, computed as the standard deviation
     * divided by the square root of the sample size.
     * @return The matrix, containing null for missing arms.
     */
    public Double[][] getStdErrMatrix() {
        final StudyTreatmentTransformer<Double> stdErr = new StudyTreatmentTransformer<Double>() {
            public Double transform(Study study, Treatment treatment) {
                final Measurement arm = NetworkModel.findMeasurement(study, treatment);
                final double stdDev = arm.getStdDev();
                final double rootN = Math.sqrt(arm.getSampleSize());
                return stdDev / rootN;
            }
        };
        return getDoubleMatrix(stdErr);
    }

    /**
     * Build the (study, arm) matrix of sample sizes.
     * @return The matrix, containing null for missing arms.
     */
    public Integer[][] getSampleSizeMatrix() {
        final StudyTreatmentTransformer<Integer> sampleSize = new StudyTreatmentTransformer<Integer>() {
            public Integer transform(Study study, Treatment treatment) {
                final Measurement arm = NetworkModel.findMeasurement(study, treatment);
                return arm.getSampleSize();
            }
        };
        return getIntegerMatrix(sampleSize);
    }

    /**
     * Computes a value for one arm of a study.
     * @param <O> The type of value that is computed.
     */
    private interface StudyTreatmentTransformer<O> {
        /**
         * @param s The study.
         * @param t The treatment of the arm within that study.
         * @return The value for the (study, treatment) combination.
         */
        O transform(Study s, Treatment t);
    }

    /**
     * Create a matrix of Doubles with a row per study and a column per arm, and fill it using the
     * given transformer. Arms are listed in the order given by {@link #getTreatments(Study)}.
     * @param transformer Takes a pair of (Study, Treatment) and gives the corresponding number.
     * @return The generated matrix, containing "null" for missing combinations.
     */
    public Double[][] getDoubleMatrix(StudyTreatmentTransformer<Double> transformer) {
        final int studyCount = d_network.getStudies().size();
        final int armCount = getMaxArmCount();
        final Double[][] matrix = new Double[studyCount][armCount];
        getMatrix(matrix, transformer);
        return matrix;
    }

    /**
     * Create a matrix of Integers with a row per study and a column per arm, and fill it using the
     * given transformer. Arms are listed in the order given by {@link #getTreatments(Study)}.
     * @param transformer Takes a pair of (Study, Treatment) and gives the corresponding number.
     * @return The generated matrix, containing "null" for missing combinations.
     */
    public Integer[][] getIntegerMatrix(StudyTreatmentTransformer<Integer> transformer) {
        final int studyCount = d_network.getStudies().size();
        final int armCount = getMaxArmCount();
        final Integer[][] matrix = new Integer[studyCount][armCount];
        getMatrix(matrix, transformer);
        return matrix;
    }

    /**
     * Fill the given matrix with the number for each (study, arm) combination: row i holds the
     * arms of the i-th study, in the order given by {@link #getTreatments(Study)}. Cells beyond
     * the number of arms of a study are not written.
     * @param <N> The type of number.
     * @param m The array to write results in.
     * @param transformer Takes a pair of (Study, Treatment) and gives the corresponding number.
     */
    public <N extends Number> void getMatrix(N[][] m, StudyTreatmentTransformer<N> transformer) {
        int row = 0;
        for (Study study : d_network.getStudies()) {
            int col = 0;
            for (Treatment treatment : getTreatments(study)) {
                m[row][col++] = transformer.transform(study, treatment);
            }
            ++row;
        }
    }

    /**
     * List the treatments of the study with the study baseline moved to the front; the other
     * treatments keep the order in which NetworkModel.getTreatments() returns them.
     * Note that the list returned by NetworkModel.getTreatments() is modified in place.
     */
    private List<Treatment> getTreatments(Study study) {
        final List<Treatment> ordered = NetworkModel.getTreatments(study);
        final Treatment first = d_pmtz.getStudyBaseline(study);
        ordered.remove(first);
        ordered.add(0, first);
        return ordered;
    }

    /**
     * Convert a number to a String so that it can be read by S-Plus/R
     * @param x the number to convert to a string; may be null.
     * @param jags true to write for JAGS/R, false to write for BUGS.
     * @return "NA" if x is null; otherwise the default string form of x, suffixed with "L" if x is
     * of an integral type and jags is true.
     */
    public static String writeNumber(Number x, boolean jags) {
        if (x == null) {
            return "NA";
        }
        String suffix = jags && isInteger(x) ? "L" : "";
        return String.valueOf(x) + suffix;
    }

    /**
     * @return true if x is one of the boxed integral types (Integer, Long, Short, Byte).
     */
    private static boolean isInteger(Number x) {
        return x instanceof Integer || x instanceof Long || x instanceof Short || x instanceof Byte;
    }

    /**
     * Convert a matrix m -- where m(i)(j) is the number in the i-th row and j-th column -- to S-Plus/R format.
     * The number of columns is taken from the first row; a matrix without rows is written with
     * dimensions 0 x 0 and no cells.
     * @param m the matrix to convert; rows are expected to be of equal length.
     * @param jags true for column-major format (R/S-Plus/JAGS), false for row-major (BUGS).
     * @return "structure(c(cells), .Dim = c(rows, cols))", with the cells prefixed by ".Data = " for BUGS.
     */
    public static String writeMatrix(Number[][] m, boolean jags) {
        final int rows = m.length;
        // A matrix without rows has no first row to take the width from.
        final int cols = rows == 0 ? 0 : m[0].length;

        // Flatten in the cell order of the target format, then let writeVector format the cells.
        final Number[] cells = new Number[rows * cols];
        for (int i = 0; i < cells.length; ++i) {
            cells[i] = jags ? m[i % rows][i / rows] : m[i / cols][i % cols];
        }

        return "structure(" + (jags ? "" : ".Data = ") + writeVector(cells, jags) + ", .Dim = c("
                + writeNumber(rows, jags) + ", " + writeNumber(cols, jags) + "))";
    }

    /**
     * Convert a vector v to S-Plus/R format.
     * @param v the vector to convert; null elements are written as "NA".
     * @param jags true for R/S-Plus/JAGS format, false for BUGS.
     * @return "c(v1, v2, ...)", or "c()" for an empty vector.
     */
    public static String writeVector(Number[] v, boolean jags) {
        final StringBuilder text = new StringBuilder("c(");
        for (int i = 0; i < v.length; ++i) {
            if (i > 0) {
                text.append(", ");
            }
            text.append(writeNumber(v[i], jags));
        }
        return text.append(")").toString();
    }

    /**
     * Transform the given map from parameters to -1 or +1 to a sum expression.
     * The terms are written in the order defined by ParameterComparator and joined by " + ";
     * a parameter mapped to -1 is prefixed with "-".
     * @param pmtz Map from each parameter to its coefficient.
     * @param transform How to write each parameter.
     * @return The sum expression; empty if the map is empty.
     */
    public static String writeExpression(Map<NetworkParameter, Integer> pmtz,
            Transformer<NetworkParameter, String> transform) {
        final Set<NetworkParameter> ordered = new TreeSet<NetworkParameter>(new ParameterComparator());
        ordered.addAll(pmtz.keySet());

        final StringBuilder sum = new StringBuilder();
        boolean first = true;
        for (NetworkParameter param : ordered) {
            if (!first) {
                sum.append(" + ");
            }
            first = false;
            if (pmtz.get(param) == -1) {
                sum.append("-");
            }
            sum.append(transform.transform(param));
        }
        return sum.toString();
    }
}