edu.msu.cme.rdp.classifier.train.validation.distance.TaxaSimilarityMain.java Source code

Java tutorial

Introduction

Here is the source code for edu.msu.cme.rdp.classifier.train.validation.distance.TaxaSimilarityMain.java

Source

/*
 * Copyright (C) 2014 wangqion
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 */

package edu.msu.cme.rdp.classifier.train.validation.distance;

import edu.msu.cme.rdp.alignment.AlignmentMode;
import edu.msu.cme.rdp.alignment.pairwise.PairwiseAligner;
import edu.msu.cme.rdp.alignment.pairwise.PairwiseAlignment;
import edu.msu.cme.rdp.alignment.pairwise.ScoringMatrix;
import edu.msu.cme.rdp.alignment.pairwise.rna.DistanceModel;
import edu.msu.cme.rdp.alignment.pairwise.rna.IdentityDistanceModel;
import edu.msu.cme.rdp.alignment.pairwise.rna.OverlapCheckFailedException;
import edu.msu.cme.rdp.classifier.train.LineageSequence;
import edu.msu.cme.rdp.classifier.train.LineageSequenceParser;
import edu.msu.cme.rdp.classifier.train.validation.HierarchyTree;
import edu.msu.cme.rdp.classifier.train.validation.TreeFactory;
import edu.msu.cme.rdp.readseq.utils.kmermatch.KmerMatchCore;
import edu.msu.cme.rdp.readseq.utils.kmermatch.NuclSeqMatch;
import edu.msu.cme.rdp.readseq.utils.orientation.GoodWordIterator;
import java.awt.BasicStroke;
import java.awt.Font;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.TreeSet;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.axis.NumberTickUnit;
import org.jfree.chart.axis.ValueAxis;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.data.statistics.BoxAndWhiskerItem;
import org.jfree.data.statistics.DefaultBoxAndWhiskerCategoryDataset;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;

/**
 *
 * @author wangqion
 */
public class TaxaSimilarityMain {

    /** Reference list of rank names. Not read by this class itself. */
    public static String[] RANKS = { "norank", "domain", "phylum", "class", "order", "family", "genus" };

    /**
     * Sab mode only: for each query, the highest Sab score (percent) against a training sequence in the
     * same taxon at the lowest selected rank. Collected but not reported by this class.
     */
    private final ArrayList<Short> withinLowestRankSabSet = new ArrayList<Short>();
    /** Sab mode only: like withinLowestRankSabSet, but against training sequences in a different taxon. */
    private final ArrayList<Short> diffLowestRankSabSet = new ArrayList<Short>();
    /** Selected ranks, lower-cased, in input order: highest rank first, lowest rank last. */
    private final List<String> ranks = new ArrayList<String>();
    private final DecimalFormat format = new DecimalFormat("#.###");
    /** Key = lower-cased rank; value = histogram of similarity scores indexed by percent (0..100). */
    private final HashMap<String, long[]> sabCoutMap = new HashMap<String, long[]>();
    /** One histogram bin per integer percent from 0 to 100 inclusive. */
    private static final int BINSIZE = 101;
    private final ScoringMatrix scoringMatrix = ScoringMatrix.getDefaultNuclMatrix();
    /** Overlap-trim alignment, so the identity distance counts insertions, deletions and mismatches. */
    private final AlignmentMode mode = AlignmentMode.overlap_trim;
    private static final DistanceModel dist = new IdentityDistanceModel(true);

    /**
     * Creates a calculator for the given ranks and allocates an empty score histogram for each.
     *
     * @param selectedRanks rank names ordered from the highest rank to the lowest; lower-cased before use
     */
    public TaxaSimilarityMain(List<String> selectedRanks) {
        for (String selected : selectedRanks) {
            String rank = selected.toLowerCase();
            ranks.add(rank);
            sabCoutMap.put(rank, new long[BINSIZE]);
        }
    }

    /**
     * Reads rank names from a text file, one per line. Surrounding whitespace is trimmed and blank lines
     * are skipped.
     *
     * @param rankFile path of the rank file
     * @return the rank names in file order
     * @throws IOException if the file cannot be read
     */
    public static List<String> readRanks(String rankFile) throws IOException {
        List<String> ranks = new ArrayList<String>();
        BufferedReader reader = new BufferedReader(new FileReader(new File(rankFile)));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                String rank = line.trim();
                if (!rank.isEmpty()) { // a stray blank line must not become a rank named ""
                    ranks.add(rank);
                }
            }
        } finally {
            reader.close();
        }
        return ranks;
    }

    /**
     * Resolves a lineage against the taxonomy tree and maps each lower-cased hierarchy level to its node.
     *
     * @param root root of the taxonomy tree; its name must equal the first lineage entry
     * @param seqName sequence name, used only in error messages
     * @param ancestors taxon names from the root downwards
     * @return map from lower-cased hierarchy level to the matching tree node
     * @throws IllegalArgumentException if the lineage is null or empty, starts at a different root, or names a
     *         taxon that is not a child of the preceding one
     */
    public HashMap<String, HierarchyTree> getAncestorNodes(HierarchyTree root, String seqName,
            List<String> ancestors) {
        if (ancestors == null || ancestors.isEmpty()) {
            throw new IllegalArgumentException("Sequence " + seqName + " has no lineage");
        }
        if (!ancestors.get(0).equals(root.getName())) {
            throw new IllegalArgumentException(
                    "Sequence " + seqName + " does not have the same root taxon: " + root.getName());
        }
        HashMap<String, HierarchyTree> ancestorNodes = new HashMap<String, HierarchyTree>();
        // Lower-case every key, including the root's, so lookups with the lower-cased selected ranks succeed.
        ancestorNodes.put(root.getTaxonomy().getHierLevel().toLowerCase(), root);
        HierarchyTree curParent = root;
        for (int i = 1; i < ancestors.size(); i++) {
            HierarchyTree node = curParent.getSubclassbyName(ancestors.get(i));
            if (node == null) {
                throw new IllegalArgumentException(
                        "Sequence " + seqName + " cannot find ancestor node: " + ancestors.get(i));
            }
            ancestorNodes.put(node.getTaxonomy().getHierLevel().toLowerCase(), node);
            curParent = node;
        }
        return ancestorNodes;
    }

    /**
     * Builds the taxonomy tree from a taxon file and closes the file afterwards.
     *
     * @return the root of the tree
     */
    private static HierarchyTree buildTaxonomyTree(String taxonFile) throws IOException {
        FileReader taxonReader = new FileReader(taxonFile);
        try {
            TreeFactory factory = new TreeFactory(taxonReader);
            factory.buildTree();
            return factory.getRoot();
        } finally {
            taxonReader.close();
        }
    }

    /**
     * Returns the lowest selected rank at which the two lineages have a taxon with the same name.
     * Ranks are scanned from the last entry of {@link #ranks} to the first. Ranks absent from either
     * lineage are skipped.
     *
     * @return the rank, or null when the lineages share none of the selected ranks
     */
    private String findLowestCommonRank(HashMap<String, HierarchyTree> queryNodes,
            HashMap<String, HierarchyTree> matchNodes) {
        for (int i = ranks.size() - 1; i >= 0; i--) {
            String rank = ranks.get(i);
            HierarchyTree queryTaxon = queryNodes.get(rank);
            HierarchyTree matchTaxon = matchNodes.get(rank);
            if (queryTaxon != null && matchTaxon != null && queryTaxon.getName().equals(matchTaxon.getName())) {
                return rank;
            }
        }
        return null;
    }

    /**
     * Records k-mer Sab scores between every test sequence and every training sequence returned by
     * {@code NuclSeqMatch.findAllMatches}, except a training sequence with the same name as the query.
     *
     * Each score (rounded to a whole percent) is added to the histogram of the lowest selected rank the pair
     * shares. Progress is printed to stdout every 100 queries.
     *
     * @param taxonFile taxonomy file used to build the tree
     * @param trainSeqFile training sequences with lineages
     * @param testSeqFile query sequences with lineages
     * @throws IOException if an input file cannot be read
     * @throws IllegalArgumentException if a lineage does not fit the taxonomy tree
     */
    public void calSabSimilarity(String taxonFile, String trainSeqFile, String testSeqFile) throws IOException {
        HierarchyTree root = buildTaxonomyTree(taxonFile);

        // Lineage of every training sequence keyed by name (a later duplicate name replaces an earlier one).
        HashMap<String, List<String>> lineageMap = new HashMap<String, List<String>>();
        LineageSequenceParser trainParser = new LineageSequenceParser(new File(trainSeqFile));
        try {
            while (trainParser.hasNext()) {
                LineageSequence seq = (LineageSequence) trainParser.next();
                lineageMap.put(seq.getSeqName(), seq.getAncestors());
            }
        } finally {
            trainParser.close();
        }

        // Training-sequence ancestor nodes, resolved on first use and reused by later queries.
        HashMap<String, HashMap<String, HierarchyTree>> matchAncestorCache =
                new HashMap<String, HashMap<String, HierarchyTree>>();
        NuclSeqMatch sabCal = new NuclSeqMatch(trainSeqFile);
        String lowestRank = ranks.isEmpty() ? null : ranks.get(ranks.size() - 1);

        LineageSequenceParser parser = new LineageSequenceParser(new File(testSeqFile));
        try {
            int count = 0;
            while (parser.hasNext()) {
                LineageSequence seq = (LineageSequence) parser.next();
                HashMap<String, HierarchyTree> queryAncestorNodes =
                        getAncestorNodes(root, seq.getSeqName(), seq.getAncestors());
                TreeSet<KmerMatchCore.BestMatch> matchResults = sabCal.findAllMatches(seq);

                short withinLowestRankSab = -1; // -1 = no match seen yet
                short diffLowestRankSab = -1;
                for (KmerMatchCore.BestMatch match : matchResults) {
                    String matchName = match.getBestMatch().getSeqName();
                    if (matchName.equals(seq.getSeqName())) {
                        continue; // do not compare a sequence with itself
                    }
                    short sab = (short) Math.round(100 * match.getSab());

                    HashMap<String, HierarchyTree> matchAncestorNodes = matchAncestorCache.get(matchName);
                    if (matchAncestorNodes == null) {
                        matchAncestorNodes = getAncestorNodes(root, matchName, lineageMap.get(matchName));
                        matchAncestorCache.put(matchName, matchAncestorNodes);
                    }

                    String commonRank = findLowestCommonRank(queryAncestorNodes, matchAncestorNodes);
                    if (commonRank != null) {
                        sabCoutMap.get(commonRank)[sab]++;
                    }

                    // Best score within vs. outside the query's taxon at the lowest selected rank.
                    HierarchyTree queryLowest = queryAncestorNodes.get(lowestRank);
                    HierarchyTree matchLowest = matchAncestorNodes.get(lowestRank);
                    if (queryLowest != null && matchLowest != null
                            && queryLowest.getName().equals(matchLowest.getName())) {
                        withinLowestRankSab = (short) Math.max(withinLowestRankSab, sab);
                    } else {
                        diffLowestRankSab = (short) Math.max(diffLowestRankSab, sab);
                    }
                }
                if (withinLowestRankSab >= 0) {
                    withinLowestRankSabSet.add(withinLowestRankSab);
                }
                if (diffLowestRankSab >= 0) {
                    diffLowestRankSabSet.add(diffLowestRankSab);
                }
                count++;
                if (count % 100 == 0) {
                    System.out.println(count);
                }
            }
        } finally {
            parser.close();
        }
    }

    /**
     * Records pairwise-alignment identity between every test sequence and every training sequence with a
     * different name.
     *
     * Sequences are aligned in overlap-trim mode after replacing U with T. The similarity
     * (100 - 100 * distance, rounded to a whole percent) is added to the histogram of the lowest selected rank
     * the pair shares. Pairs sharing none of the selected ranks are not aligned.
     *
     * @param taxonFile taxonomy file used to build the tree
     * @param trainSeqFile training sequences with lineages
     * @param testSeqFile query sequences with lineages
     * @throws IOException if an input file cannot be read
     * @throws OverlapCheckFailedException if the distance model rejects an alignment
     * @throws IllegalArgumentException if a lineage does not fit the taxonomy tree
     */
    public void calPairwiseSimilaritye(String taxonFile, String trainSeqFile, String testSeqFile)
            throws IOException, OverlapCheckFailedException {
        HierarchyTree root = buildTaxonomyTree(taxonFile);

        ArrayList<LineageSequence> trainSeqList = new ArrayList<LineageSequence>();
        LineageSequenceParser trainParser = new LineageSequenceParser(new File(trainSeqFile));
        try {
            while (trainParser.hasNext()) {
                trainSeqList.add((LineageSequence) trainParser.next());
            }
        } finally {
            trainParser.close();
        }

        // Entry t holds the ancestor nodes of trainSeqList.get(t); null until first needed.
        ArrayList<HashMap<String, HierarchyTree>> trainAncestorCache =
                new ArrayList<HashMap<String, HierarchyTree>>(trainSeqList.size());
        for (int t = 0; t < trainSeqList.size(); t++) {
            trainAncestorCache.add(null);
        }

        LineageSequenceParser parser = new LineageSequenceParser(new File(testSeqFile));
        try {
            while (parser.hasNext()) {
                LineageSequence seq = (LineageSequence) parser.next();
                HashMap<String, HierarchyTree> queryAncestorNodes =
                        getAncestorNodes(root, seq.getSeqName(), seq.getAncestors());
                String querySeq = seq.getSeqString().replace("U", "T");

                for (int t = 0; t < trainSeqList.size(); t++) {
                    LineageSequence trainSeq = trainSeqList.get(t);
                    if (trainSeq.getSeqName().equals(seq.getSeqName())) {
                        continue; // do not compare a sequence with itself
                    }
                    HashMap<String, HierarchyTree> matchAncestorNodes = trainAncestorCache.get(t);
                    if (matchAncestorNodes == null) {
                        matchAncestorNodes = getAncestorNodes(root, trainSeq.getSeqName(), trainSeq.getAncestors());
                        trainAncestorCache.set(t, matchAncestorNodes);
                    }

                    String lowestCommonRank = findLowestCommonRank(queryAncestorNodes, matchAncestorNodes);
                    if (lowestCommonRank == null) {
                        continue; // the pair shares none of the selected ranks
                    }

                    PairwiseAlignment result = PairwiseAligner.align(querySeq,
                            trainSeq.getSeqString().replace("U", "T"), scoringMatrix, mode);
                    double distance = dist.getDistance(result.getAlignedSeqj().getBytes(),
                            result.getAlignedSeqi().getBytes(), 0);
                    // Round rather than truncate: 100 - 100 * 0.07 evaluates to 92.99999999999999.
                    short sab = (short) Math.round(100 - 100 * distance);
                    sabCoutMap.get(lowestCommonRank)[sab]++;
                }
            }
        } finally {
            parser.close();
        }
    }

    /**
     * Writes the per-rank summary statistics and the two charts into {@code outdir}:
     * {@code <plotTitle>.boxchart.txt} (one tab-separated row per rank with scores),
     * {@code <plotTitle>.linechart.png} (cumulative percentage of comparisons per similarity score) and
     * {@code <plotTitle>.boxchart.png} (box-and-whisker plot: 25th/50th/75th percentiles, whiskers at the
     * 2nd and 98th percentiles).
     *
     * @param plotTitle chart title and output file-name prefix
     * @param outdir existing output directory
     * @throws IOException if an output file cannot be written
     */
    public void createPlot(String plotTitle, File outdir) throws IOException {
        XYSeriesCollection dataset = new XYSeriesCollection();
        DefaultBoxAndWhiskerCategoryDataset scatterDataset = new DefaultBoxAndWhiskerCategoryDataset();

        PrintStream boxchart_dataStream = new PrintStream(new File(outdir, plotTitle + ".boxchart.txt"));
        try {
            boxchart_dataStream.println(
                    "#\tkmer" + "\trank" + "\t" + "max" + "\t" + "avg" + "\t" + "min" + "\t" + "Q1" + "\t" + "median"
                            + "\t" + "Q3" + "\t" + "98Pct" + "\t" + "2Pct" + "\t" + "comparisons" + "\t" + "sum");
            for (String rank : ranks) {
                long[] countArray = sabCoutMap.get(rank);
                if (countArray == null) {
                    continue;
                }

                // Pass 1: total comparisons, score sum, and lowest/highest observed score.
                long comparisons = 0;
                double sum = 0.0;
                int max = 0;
                int min = 100;
                for (int c = 0; c < countArray.length; c++) {
                    if (countArray[c] == 0) {
                        continue;
                    }
                    comparisons += countArray[c];
                    sum += countArray[c] * c;
                    if (c < min) {
                        min = c;
                    }
                    if (c > max) {
                        max = c;
                    }
                }
                if (comparisons == 0) {
                    continue; // no scores were recorded for this rank
                }

                // Pass 2: cumulative-percentage curve. Each percentile is the first score whose
                // cumulative percentage of comparisons reaches it.
                XYSeries series = new XYSeries(rank);
                int pct_2 = -1;
                int Q1 = -1;
                int median = -1;
                int Q3 = -1;
                int pct_98 = -1;
                double cum = 0;
                for (int c = 0; c < countArray.length; c++) {
                    if (countArray[c] == 0) {
                        continue;
                    }
                    cum += countArray[c];
                    int pct = (int) Math.floor(100 * cum / comparisons);
                    series.add(c, pct);

                    if (pct_2 == -1 && pct >= 2) {
                        pct_2 = c;
                    }
                    if (Q1 == -1 && pct >= 25) {
                        Q1 = c;
                    }
                    if (median == -1 && pct >= 50) {
                        median = c;
                    }
                    if (Q3 == -1 && pct >= 75) {
                        Q3 = c;
                    }
                    if (pct_98 == -1 && pct >= 98) {
                        pct_98 = c;
                    }
                }
                dataset.addSeries(series);

                double mean = sum / comparisons;
                int minOutlier = 0; // outliers are not plotted
                int maxOutlier = 0;
                BoxAndWhiskerItem item = new BoxAndWhiskerItem(mean, median, Q1, Q3, pct_2, pct_98,
                        minOutlier, maxOutlier, new ArrayList<Double>());
                scatterDataset.add(item, rank, "");

                boxchart_dataStream.println("#\t" + GoodWordIterator.getWordsize() + "\t" + rank + "\t"
                        + max + "\t" + format.format(mean) + "\t" + min + "\t" + Q1 + "\t" + median
                        + "\t" + Q3 + "\t" + pct_98 + "\t" + pct_2 + "\t" + comparisons + "\t" + sum);
            }
        } finally {
            boxchart_dataStream.close();
        }

        Font labelFont = new Font("Helvetica", Font.BOLD, 28);
        JFreeChart chart = ChartFactory.createXYLineChart(plotTitle, "Similarity%", "Percent Comparisons", dataset,
                PlotOrientation.VERTICAL, true, true, false);
        XYPlot plot = (XYPlot) chart.getPlot();
        plot.getRenderer().setStroke(new BasicStroke(2.0f));
        chart.getLegend().setItemFont(new Font("Helvetica", Font.BOLD, 24));
        chart.getTitle().setFont(labelFont);
        plot.getDomainAxis().setLabelFont(labelFont);
        plot.getDomainAxis().setTickLabelFont(labelFont);
        ValueAxis rangeAxis = plot.getRangeAxis();
        rangeAxis.setRange(0, 100);
        rangeAxis.setTickLabelFont(labelFont);
        rangeAxis.setLabelFont(labelFont);
        ((NumberAxis) rangeAxis).setTickUnit(new NumberTickUnit(5));

        PrintStream lineChartStream = new PrintStream(new File(outdir, plotTitle + ".linechart.png"));
        try {
            ChartUtilities.writeScaledChartAsPNG(lineChartStream, chart, 800, 1000, 3, 3);
        } finally {
            lineChartStream.close();
        }

        PrintStream boxChartStream = new PrintStream(new File(outdir, plotTitle + ".boxchart.png"));
        try {
            BoxPlotUtils.createBoxplot(scatterDataset, boxChartStream, plotTitle, "Rank", "Similarity%", labelFont);
        } finally {
            boxChartStream.close();
        }
    }

    /**
     * Computes similarity (Sab score or pairwise alignment) between taxa at the given ranks. It then writes a
     * box-and-whisker plot and an accumulation curve.
     *
     * Every query/training-sequence pair contributes its score to exactly one rank: the lowest selected
     * rank at which both sequences belong to the same taxon. For example, a pair from the same genus is
     * counted for genus. A pair from different genera of the same family is counted for family, and so on.
     * A rank's distribution therefore covers pairs from different child taxa only.
     *
     * Arguments: taxonfile trainset.fasta query.fasta outdir kmersize rankFile sab|pw
     *
     * @param args command-line arguments, see above
     * @throws IOException if an input or output file cannot be accessed
     * @throws OverlapCheckFailedException if the distance model rejects an alignment in pw mode
     */
    public static void main(String[] args) throws IOException, OverlapCheckFailedException {
        String usage = "Usage: taxonfile trainset.fasta query.fasta outdir kmersize rankFile sab|pw \n"
                + "  This program calculates the average similarity (Sab score, or pairwise alignment) within taxa\n"
                + "  and plot the box and whisker plot and accumulation curve plot. \n"
                + "  rankFile: a file contains a list of ranks to be calculated and plotted. One rank per line,"
                + " ordered from the highest rank to the lowest rank (e.g. domain first, genus last). \n"
                + "  Note pw is extremely slower, recommended only for lower ranks such as species, genus and family. ";

        if (args.length != 7) {
            System.err.println(usage);
            System.exit(1);
        }
        String method = args[6];
        boolean useSab = method.equalsIgnoreCase("sab");
        // Reject typos instead of silently falling back to the much slower pairwise mode.
        if (!useSab && !method.equalsIgnoreCase("pw")) {
            System.err.println(usage);
            System.exit(1);
        }
        List<String> ranks = readRanks(args[5]);
        File outdir = new File(args[3]);
        if (!outdir.isDirectory()) {
            System.err.println("outdir must be a directory");
            System.exit(1);
        }
        int kmer = Integer.parseInt(args[4]);
        GoodWordIterator.setWordSize(kmer);
        TaxaSimilarityMain theObj = new TaxaSimilarityMain(ranks);

        // Plot title = query file name up to its first '.'.
        String plotTitle = new File(args[2]).getName();
        int index = plotTitle.indexOf(".");
        if (index != -1) {
            plotTitle = plotTitle.substring(0, index);
        }
        if (useSab) {
            theObj.calSabSimilarity(args[0], args[1], args[2]);
        } else {
            theObj.calPairwiseSimilaritye(args[0], args[1], args[2]);
        }

        theObj.createPlot(plotTitle, outdir);
    }

}