com.davidsoergel.stats.DistributionXYSeries.java Source code

Java tutorial

Introduction

Here is the source code for com.davidsoergel.stats.DistributionXYSeries.java

Source

/*
 * Copyright (c) 2006-2013  David Soergel  <dev@davidsoergel.com>
 * Licensed under the Apache License, Version 2.0
 * http://www.apache.org/licenses/LICENSE-2.0
 */

package com.davidsoergel.stats;

import com.davidsoergel.dsutils.DSArrayUtils;
import com.davidsoergel.dsutils.math.MersenneTwisterFast;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import org.apache.log4j.Logger;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

/**
 * Stores a set of y values for each x value, allowing computation of mean, standard deviation, and quantiles for each
 * x.
 * <p/>
 * The reason we don't extend SimpleXYSeries is to avoid redundant storage of the points
 *
 * @author <a href="mailto:dev@davidsoergel.com">David Soergel</a>
 * @version $Id$
 */
public class DistributionXYSeries //extends SimpleXYSeries
{
    private static final Logger logger = Logger.getLogger(DistributionXYSeries.class);

    // need to store the sorted x's redundantly because there is no KeySortedMultimap
    SortedSet<Double> keys = new TreeSet<Double>();
    Map<Double, Multiset<Double>> yValsPerX = new HashMap<Double, Multiset<Double>>();
    //   Map<Double, Set<Double>> yValsPerX = new HashMap<Double, Set<Double>>();

    protected double xMin = Double.POSITIVE_INFINITY;
    protected double xMax = Double.NEGATIVE_INFINITY;
    private double yMin = Double.POSITIVE_INFINITY;
    private double yMax = Double.NEGATIVE_INFINITY;

    // --------------------- GETTER / SETTER METHODS ---------------------

    public double getXMax() {
        return xMax;
    }

    public double getXMin() {
        return xMin;
    }

    public double getYMin() {

        return yMin;
    }

    public double getYMax() {
        return yMax;
    }

    //@Override
    public void addPoint(double x, double y) //throws StatsException
    {
        //   super.addPoint(x, y);
        keys.add(x);
        getYMultiset(x).add(y);
        updateXBounds(x);
        updateYBounds(y);
    }

    private void updateXBounds(double x) {
        if (x < xMin) {
            xMin = x;
        }
        if (x > xMax) {
            xMax = x;
        }
    }

    private void updateYBounds(double y) {
        if (y < yMin) {
            yMin = y;
        }
        if (y > yMax) {
            yMax = y;
        }
    }

    public void addPoints(double x, Collection<Double> ys)// throws StatsException
    {
        //   super.addPoint(x, y);
        keys.add(x);
        getYMultiset(x).addAll(ys);
        updateXBounds(x);
        for (Double y : ys) {
            updateYBounds(y);
        }
    }

    private Multiset<Double> getYMultiset(double x) {
        Multiset<Double> result = yValsPerX.get(x);

        if (result == null) {
            result = HashMultiset.create();
            yValsPerX.put(x, result);
        }
        return result;
    }

    // support error bars
    public Set<Double> uniqueXValues() {
        return yValsPerX.keySet();
    }

    public double meanYAtX(double x) {
        return DSArrayUtils.mean(getYMultiset(x));
    }

    public double stddevYAtX(double x, double mean) {
        return DSArrayUtils.stddev(getYMultiset(x), mean);
    }

    public double meanYAtXCumulative(double xTop) {
        return DSArrayUtils.mean(getYList(0, xTop));
    }

    public double stddevYAtXCumulative(double xTop, double mean) {
        return DSArrayUtils.stddev(getYList(0, xTop), mean);
    }

    public BinnedXYSeries makeBinnedXYSeries(double halfBinWidth, double binStep) {

        try {
            BinnedXYSeries result = new BinnedXYSeries(this);

            //special case for binStep == 0: use current x values as centers

            if (binStep == 0) {
                //double halfBinWidth = binWidth / 2.0;
                /*
                            for (SimpleXYSeries.XYPoint p : points)
                               {
                    
                               double bottom = p.x - halfBinWidth;
                               double top = p.x + halfBinWidth;
                    
                               result.addBin(bottom, top, getYArray(bottom, top));
                               }
                */
                for (double x : keys) {
                    //double bottom = x - halfBinWidth;
                    //double top = x + halfBinWidth;

                    result.addBin(x, halfBinWidth);
                }
            } else {
                //   double xMin = keys.first();
                //   double xMax = keys.last();
                for (double center = xMin; center < xMax; center += binStep) {
                    result.addBin(center, halfBinWidth);//, getYArray(bottom, top));
                }
            }
            /*
            for (Double x : keys)
               {
               result.addPointsToBins(x, yValsPerX.get(x));  // forget the original x value; collapse all the points to the bin center
               }
               */
            return result;
        } catch (DistributionException e) {
            logger.error("Error", e);
            throw new Error(e);
        }
    }

    public BinnedXYSeries binXTiles(int numBins) {
        //   double xMin = keys.first();
        //   double xMax = keys.last();
        double binWidth = ((xMax - xMin) / numBins);
        return makeBinnedXYSeries(binWidth / 2.0, binWidth);
    }

    public BinnedXYSeries binXQuantiles(int numQuantiles) {
        try {
            BinnedXYSeries result = new BinnedXYSeries(this);

            Multiset<Double> allXValues = HashMultiset.create();
            for (Map.Entry<Double, Multiset<Double>> entry : yValsPerX.entrySet()) {
                allXValues.add(entry.getKey(), entry.getValue().size());
            }

            EqualWeightHistogram1D theBaseHistogram = new EqualWeightHistogram1D(numQuantiles,
                    DSArrayUtils.toPrimitiveDoubleArray(allXValues));

            int numBins = theBaseHistogram.getBins();
            for (int i = 0; i < numBins; i++) {
                //   double bottom = theBaseHistogram.bottomOfBin(i);
                //   double top = theBaseHistogram.topOfBin(i);
                double center = theBaseHistogram.centerOfBin(i);
                double halfBinWidth = theBaseHistogram.halfWidthOfBin(i);
                result.addBin(center, halfBinWidth);
                /*      for (Double x : keys.tailSet(bottom).headSet(top))
                   {
                   result.addPoints(center, yValsPerX.get(x));  // forget the original x value; collapse all the points to the bin center
                   }*/
            }
            return result;
        } catch (DistributionException e) {
            logger.error("Error", e);
            throw new Error(e);
        } catch (StatsException e) {
            logger.error("Error", e);
            throw new Error(e);
        }
    }

    //   public double[] getYArray(double bottom, double top)
    //      {
    //      return DSArrayUtils.toPrimitiveArray(getYList(bottom,top));
    /*
       List<Double> result = new ArrayList<Double>();
       for (Double x : keys.tailSet(bottom).headSet(top))
     {
     result.addAll(yValsPerX.get(x));
     }
       return DSArrayUtils.toPrimitiveArray(result);
       */
    //      }

    /**
     * Returns a list of all Y values associated with X values in the given range, including the bottom but not including
     * the top point
     *
     * @param bottom the lowest X value to consider, inclusive
     * @param top    the highest X value to consider, exclusive
     * @return
     */
    public List<Double> getYList(double bottom, double top) {
        List<Double> result = new ArrayList<Double>();
        try {
            //bottom = Math.max(bottom, keys.first());  // not necessary

            // here is an annoyance: the headSet(top) call can't ever include keys.last(),
            // because either top == keys.last() in which case it's an open interval,
            // or top > keys.last(), in which case headSet throws an exception because
            // it's operating on a tailSet and top is outside the bounds.

            double last = keys.last();
            double topTrim = Math.min(top, last); // avoid the exception...

            for (Double x : keys.tailSet(bottom).headSet(topTrim)) {
                // check that a multiset iterator returns the duplicates
                int i = result.size();
                Multiset<Double> multiset = yValsPerX.get(x);
                result.addAll(multiset);
                assert result.size() == i + multiset.size();
            }

            if (top > last) // top is exclusive
            {
                result.addAll(yValsPerX.get(last)); // but do include the point
            }
        } catch (IllegalArgumentException e) {
            // there are no points in the requested range, so we just return an empty list
        }
        return result;
    }

    /*
       public Set<Double> getUniqueYForXRange(double bottom, double top)
          {
          // PERF when bins overlap, the points in the intersection are added to the result set repeatedly
        
          Set<Double> result = new HashSet<Double>();
          Map<Double, Collection<Double>> map = yValsPerX.asMap();
          for (Map.Entry<Double, Collection<Double>> entry : map.entrySet())
     {
     Double key = entry.getKey();
     if (key >= bottom && key < top)
        {
        result.addAll(entry.getValue());
        }
     }
          return result;
          }
    */

    public int size() {
        int result = 0;
        for (Multiset<Double> doubles : yValsPerX.values()) {
            result += doubles.size();
        }
        return result;
    }

    public SimpleXYSeries asSimpleXYSeries() throws StatsException {
        SimpleXYSeries result = new SimpleXYSeries();
        for (Double x : keys) {
            for (Double y : getYMultiset(x)) {
                result.addPoint(x, y);
            }
        }
        return result;
    }

    LinearRegression regression = null;
    double correlation;

    public double pearsonCorrelation() throws StatsException {
        if (regression == null) {
            computeRegression();
        }
        return correlation;
    }

    public double R2() throws StatsException {
        if (regression == null) {
            computeRegression();
        }
        logger.debug("Correlations: " + regression.R2 + ", " + correlation);
        return regression.R2;
    }

    public double regressionM() throws StatsException {
        if (regression == null) {
            computeRegression();
        }
        return regression.slope;
    }

    public double regressionB() throws StatsException {
        if (regression == null) {
            computeRegression();
        }
        return regression.intercept;
    }

    private void computeRegression() throws StatsException {
        double[] xs = new double[size()];
        double[] ys = new double[size()];
        int i = 0;
        for (Double x : keys) {
            for (Double y : getYMultiset(x)) {
                xs[i] = x;
                ys[i] = y;
                i++;
            }
        }

        regression = new LinearRegression(xs, ys);
        correlation = PearsonCorrelation.computeCorrelationCoefficient(xs, ys);
    }

    public void addAllVsAllPoints(List<Double> xValues, List<Double> yValues, boolean xlog, boolean ylog,
            double xnoise, double ynoise) {
        for (double dx : xValues) {
            //   double dx = dxv.getValue();
            if (xlog) {
                dx = Math.log10(dx);
            }

            for (double dy : yValues) {
                //double dy = dyv.getValue();

                if (ylog) {
                    dy = Math.log10(dy);
                }
                double dxn = dx;
                if (xnoise != 0) {
                    dxn += xnoise * (MersenneTwisterFast.gaussian() - 0.5);
                }

                if (ynoise != 0) {
                    dy += ynoise * (MersenneTwisterFast.gaussian() - 0.5);
                }

                //try
                //   {
                addPoint(dxn, dy);
                //   }
                //catch (StatsException e)
                //   {
                //   logger.warn("warn", e);
                //   }
            }
        }
    }

    public void addParallelArrayPoints(List<Double> xValues, List<Double> yValues, boolean xlog, boolean ylog,
            double xnoise, double ynoise) {
        Iterator<Double> yIterator = yValues.iterator();
        for (double dx : xValues) {
            //   double dx = dxv.getValue();
            if (xlog) {
                dx = Math.log10(dx);
            }

            if (xnoise != 0) {
                dx += xnoise * (MersenneTwisterFast.gaussian() - 0.5);
            }
            double dy = yIterator.next();

            //double dy = dyv.getValue();

            if (ylog) {
                dy = Math.log10(dy);
            }

            if (ynoise != 0) {
                dy += ynoise * (MersenneTwisterFast.gaussian() - 0.5);
            }

            //try
            //   {
            addPoint(dx, dy);
            //   }
            //catch (StatsException e)
            //   {
            //   logger.warn("warn", e);
            //   }
        }
    }

    public void addParallelArrayPoints(List<Double> xValues, List<Double> yValues, boolean xlog, boolean ylog,
            double xnoise, double ynoise, List<Double> maskValues, double maskMin, double maskMax,
            List<Double> mask2Values, double mask2Min, double mask2Max) {
        Iterator<Double> yIterator = yValues.iterator();
        Iterator<Double> maskIterator = maskValues.iterator();
        Iterator<Double> mask2Iterator = mask2Values.iterator();
        for (double dx : xValues) {
            double dMask = maskIterator.next();
            double dMask2 = mask2Iterator.next();
            double dy = yIterator.next();

            if (dMask >= maskMin && dMask <= maskMax && dMask2 >= mask2Min && dMask2 <= mask2Max) {
                //   double dx = dxv.getValue();
                if (xlog) {
                    dx = Math.log10(dx);
                }

                if (xnoise != 0) {
                    dx += xnoise * (MersenneTwisterFast.gaussian() - 0.5);
                }

                //double dy = dyv.getValue();

                if (ylog) {
                    dy = Math.log10(dy);
                }

                if (ynoise != 0) {
                    dy += ynoise * (MersenneTwisterFast.gaussian() - 0.5);
                }

                //try
                //   {
                addPoint(dx, dy);
                //   }
                //catch (StatsException e)
                //   {
                //   logger.warn("warn", e);
                //   }
            } else {
                addPoint(dx, Double.NaN);
            }
        }
    }

    public void addParallelArrayPoints(List<Double> xValues, List<Double> yValues, boolean xlog, boolean ylog,
            double xnoise, double ynoise, List<Double> maskValues, double maskMin, double maskMax) {
        Iterator<Double> yIterator = yValues.iterator();
        Iterator<Double> maskIterator = maskValues.iterator();
        for (double dx : xValues) {
            double dMask = maskIterator.next();
            double dy = yIterator.next();

            if (dMask >= maskMin && dMask <= maskMax) {
                //   double dx = dxv.getValue();
                if (xlog) {
                    dx = Math.log10(dx);
                }

                if (xnoise != 0) {
                    dx += xnoise * (MersenneTwisterFast.gaussian() - 0.5);
                }

                //double dy = dyv.getValue();

                if (ylog) {
                    dy = Math.log10(dy);
                }

                if (ynoise != 0) {
                    dy += ynoise * (MersenneTwisterFast.gaussian() - 0.5);
                }

                //try
                //   {
                addPoint(dx, dy);
                //   }
                //catch (StatsException e)
                //   {
                //   logger.warn("warn", e);
                //   }
            } else {
                addPoint(dx, Double.NaN);
            }
        }
    }
}