org.micromanager.saim.fit.Fitter.java Source code

Java tutorial

Introduction

Here is the source code for org.micromanager.saim.fit.Fitter.java

Source

///////////////////////////////////////////////////////////////////////////////
//FILE:          Fitter.java
//PROJECT:       Micro-Manager 
//SUBSYSTEM:     ASIdiSPIM plugin
//-----------------------------------------------------------------------------
//
// AUTHOR:       Nico Stuurman, Jon Daniels
//
// COPYRIGHT:    University of California, San Francisco, & ASI, 2015, 2016
//
// LICENSE:      This file is distributed under the BSD license.
//               License text is included with the source distribution.
//
//               This file is distributed in the hope that it will be useful,
//               but WITHOUT ANY WARRANTY; without even the implied warranty
//               of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
//               IN NO EVENT SHALL THE COPYRIGHT OWNER OR
//               CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
//               INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES.

package org.micromanager.saim.fit;

import org.apache.commons.math3.analysis.function.Gaussian;
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction;
import org.apache.commons.math3.analysis.solvers.BracketingNthOrderBrentSolver;
import org.apache.commons.math3.analysis.solvers.UnivariateSolver;
import org.apache.commons.math3.fitting.PolynomialCurveFitter;
import org.apache.commons.math3.fitting.WeightedObservedPoints;
import org.jfree.data.xy.XYSeries;

/**
 * Static helpers for fitting data held in a JFreeChart {@link XYSeries} with
 * one of the predefined {@link FunctionType}s (polynomials of degree 1-3 or a
 * Gaussian with a constant offset), and for evaluating the fitted functions.
 *
 * @author nico
 */
public class Fitter {

    private static final String NOFIT = "No fit";
    private static final String POL1 = "Polynomial 1";
    private static final String POL2 = "Polynomial 2";
    private static final String POL3 = "Polynomial 3";
    private static final String GAUSSIAN = "Gaussian";

    /** Maximum number of iterations allowed for the Gaussian fit. */
    private static final int GAUSSIAN_MAX_ITERATIONS = 50;

    /** Number of curve points generated per input data point in getFittedSeries. */
    private static final int CURVE_OVERSAMPLING = 10;

    public static enum FunctionType {
        NoFit, Pol1, Pol2, Pol3, Gaussian
    }

    /**
     * How data points are weighted by {@link #fit(XYSeries, FunctionType, double[], WeightMethod)}.
     * Except for Equal, weights are derived from each point's y value relative
     * to the y range of the data set:
     * Linear: (y - minY) / (maxY - minY); Quadratic: the square of the linear
     * weight; Top50Linear / Top80Linear: the linear weight for points whose
     * y - minY is at least 50% / 80% of the range, and 0 for all other points.
     */
    public static enum WeightMethod {
        Equal, Linear, Quadratic, Top50Linear, Top80Linear
    }

    /**
     * Utility to facilitate fitting data plotted in JFreeChart.
     * Provide data in JFreeChart format (XYSeries), and retrieve univariate
     * function parameters that best fit (using least squares) the data.
     * All data points are weighted equally; this is shorthand for
     * {@code fit(data, type, guess, WeightMethod.Equal)}.
     *
     * @param data xy series in JFreeChart format
     * @param type one of the Fitter.FunctionType predefined functions
     * @param guess initial guess for the fit, or null for none. It is only used
     *        for Gaussian fits: 0: Normalization, 1: Mean, 2: Sigma
     *        (presumably also 3: Offset, matching the 4-element result -
     *        confirm against GaussianWithOffsetCurveFitter)
     * @return array with parameters, whose meaning depends on the FunctionType
     *         (polynomials: coefficients in increasing power of x; Gaussian:
     *         0: Normalization, 1: Mean, 2: Sigma, 3: Offset), or null for
     *         FunctionType.NoFit. Use getFittedSeries to retrieve the XYSeries
     *         predicted by this fit
     */
    public static double[] fit(XYSeries data, FunctionType type, double[] guess) {
        return fit(data, type, guess, WeightMethod.Equal);
    }

    /**
     * Utility to facilitate fitting data plotted in JFreeChart.
     * Provide data in JFreeChart format (XYSeries), and retrieve univariate
     * function parameters that best fit (using weighted least squares) the data.
     * The weight of every data point is determined by weightMethod.
     *
     * @param data xy series in JFreeChart format
     * @param type one of the Fitter.FunctionType predefined functions
     * @param guess initial guess for the fit, or null for none. It is only used
     *        for Gaussian fits: 0: Normalization, 1: Mean, 2: Sigma
     *        (presumably also 3: Offset, matching the 4-element result -
     *        confirm against GaussianWithOffsetCurveFitter)
     * @param weightMethod one of the methods in the WeightMethod enum; null is
     *        treated as WeightMethod.Equal. When all y values are identical
     *        there is no intensity information and every point gets weight 1
     * @return array with parameters, whose meaning depends on the FunctionType
     *         (polynomials: coefficients in increasing power of x; Gaussian:
     *         0: Normalization, 1: Mean, 2: Sigma, 3: Offset), or null for
     *         FunctionType.NoFit. Use getFittedSeries to retrieve the XYSeries
     *         predicted by this fit
     */
    public static double[] fit(XYSeries data, FunctionType type, double[] guess, WeightMethod weightMethod) {

        if (type == FunctionType.NoFit) {
            return null;
        }
        // create the commons math data object from the JFreeChart data object
        final WeightedObservedPoints obs = new WeightedObservedPoints();
        // minimum and range of the y values, used in the weight calculations
        final double minY = data.getMinY();
        final double range = data.getMaxY() - minY;
        for (int i = 0; i < data.getItemCount(); i++) {
            final double x = data.getX(i).doubleValue();
            final double y = data.getY(i).doubleValue();
            obs.add(getWeight(y - minY, range, weightMethod), x, y);
        }

        // Carry out the actual fit
        switch (type) {
        case Pol1:
            return PolynomialCurveFitter.create(1).fit(obs.toList());
        case Pol2:
            return PolynomialCurveFitter.create(2).fit(obs.toList());
        case Pol3:
            return PolynomialCurveFitter.create(3).fit(obs.toList());
        case Gaussian: {
            GaussianWithOffsetCurveFitter gf = GaussianWithOffsetCurveFitter.create()
                    .withMaxIterations(GAUSSIAN_MAX_ITERATIONS);
            if (guess != null) {
                // like withMaxIterations, withStartPoint returns a newly
                // configured fitter; the result must be kept or the guess
                // is silently ignored
                gf = gf.withStartPoint(guess);
            }
            return gf.fit(obs.toList());
        }
        default:
            return null;
        }
    }

    /**
     * Computes the fit weight of a single data point.
     *
     * @param valueMinusMin the point's y value minus the smallest y in the data
     * @param range largest minus smallest y value in the data
     * @param weightMethod weighting scheme; null is treated as Equal
     * @return the weight for this point
     */
    private static double getWeight(double valueMinusMin, double range, WeightMethod weightMethod) {
        // identical y values (range 0) would give 0/0 = NaN weights, which
        // breaks the fit; without intensity information, weight equally
        if (weightMethod == null || weightMethod == WeightMethod.Equal || !(range > 0.0)) {
            return 1.0;
        }
        final double linear = valueMinusMin / range;
        switch (weightMethod) {
        case Quadratic:
            return linear * linear;
        case Top50Linear:
            return valueMinusMin < (0.5 * range) ? 0.0 : linear;
        case Top80Linear:
            return valueMinusMin < (0.8 * range) ? 0.0 : linear;
        case Linear:
        default:
            return linear;
        }
    }

    /**
     * Given a JFreeChart dataset and a fitted function, return a JFreeChart
     * dataset that samples the function over the x range of the original data.
     * The curve contains 10 points per input data point, evenly spaced from the
     * smallest to the largest x value of the data (both inclusive).
     *
     * @param data input JFreeChart data set
     * @param type one of the Fitter.FunctionType predefined functions
     * @param parms parameters describing the function.  These need to match the
     *             selected function or an IllegalArgumentException will be thrown
     *
     * @return JFreeChart dataset with the sampled function values, keyed
     *         "&lt;data key&gt;-Fit". For FunctionType.NoFit a copy of data is
     *         returned (or null if the copy cannot be made)
     */
    public static XYSeries getFittedSeries(XYSeries data, FunctionType type, double[] parms) {

        if (type == FunctionType.NoFit) {
            try {
                return data.createCopy(0, data.getItemCount() - 1);
            } catch (CloneNotSupportedException ex) {
                return null;
            }
        }

        final XYSeries result = new XYSeries(String.valueOf(data.getKey()) + "-Fit", false, true);
        final double minX = data.getMinX();
        final double maxX = data.getMaxX();
        final int nPoints = data.getItemCount() * CURVE_OVERSAMPLING;
        // divide by (nPoints - 1) so the last sample lands on maxX
        final double xStep = nPoints > 1 ? (maxX - minX) / (nPoints - 1) : 0.0;
        switch (type) {
        case Pol1:
        case Pol2:
        case Pol3: {
            checkParms(type, parms);
            final PolynomialFunction polFunction = new PolynomialFunction(parms);
            for (int i = 0; i < nPoints; i++) {
                final double x = minX + i * xStep;
                result.add(x, polFunction.value(x));
            }
            break;
        }
        case Gaussian: {
            checkParms(type, parms);
            final Gaussian.Parametric gf = new Gaussian.Parametric();
            // Gaussian.Parametric takes (norm, mean, sigma); parms[3] is the offset
            final double[] gparms = new double[3];
            System.arraycopy(parms, 0, gparms, 0, 3);
            final double offset = parms[3];
            for (int i = 0; i < nPoints; i++) {
                final double x = minX + i * xStep;
                result.add(x, gf.value(x, gparms) + offset);
            }
            break;
        }
        default:
            break;
        }

        return result;
    }

    /**
     * Finds the x value corresponding to the maximum function value within the
     * range of the provided data set.
     *
     * @param data JFreeChart series, used to bracket the range in which the
     *             maximum will be found (for NoFit: the data searched directly)
     * @param type one of the Fitter.FunctionType predefined functions
     * @param parms parameters describing the function.  These need to match the
     *             selected function or an IllegalArgumentException will be thrown
     *
     * @return x value corresponding to the maximum function value, always
     *         within [minX, maxX] of data. For NoFit, the x of the data point
     *         with the highest y value
     */
    public static double getXofMaxY(XYSeries data, FunctionType type, double[] parms) {
        final double minX = data.getMinX();
        final double maxX = data.getMaxX();
        double xAtMax;
        switch (type) {
        case NoFit: {
            //  find the position in data with the highest y value
            double highestScore = data.getY(0).doubleValue();
            int highestIndex = 0;
            for (int i = 1; i < data.getItemCount(); i++) {
                final double value = data.getY(i).doubleValue();
                if (value > highestScore) {
                    highestScore = value;
                    highestIndex = i;
                }
            }
            return data.getX(highestIndex).doubleValue();
        }
        case Pol1:
        case Pol2:
        case Pol3:
            checkParms(type, parms);
            xAtMax = polynomialXofMax(parms, minX, maxX);
            break;
        case Gaussian:
            checkParms(type, parms);
            if (parms[0] < 0.0) {
                // negative normalization: an inverted peak, which is highest at
                // the end of the range farthest away from the mean
                xAtMax = Math.abs(maxX - parms[1]) > Math.abs(minX - parms[1]) ? maxX : minX;
            } else {
                // the mean is the maximum; it may lie outside our range of
                // X values, which is handled by the sanity checks below
                xAtMax = parms[1];
            }
            break;
        default:
            xAtMax = 0.0;
            break;
        }

        // sanity checks: keep the result within the data range
        if (xAtMax > maxX) {
            xAtMax = maxX;
        }
        if (xAtMax < minX) {
            xAtMax = minX;
        }

        return xAtMax;
    }

    /**
     * Returns the x in [minX, maxX] at which the polynomial with the given
     * coefficients (increasing power of x) is largest. The candidates are both
     * end points and every root of the derivative inside the interval. For
     * polynomials of degree at most 3 the derivative is at most quadratic, so
     * its real roots are computed in closed form. This also handles monotonic
     * polynomials and maxima located at the edge of the range.
     */
    private static double polynomialXofMax(double[] parms, double minX, double maxX) {
        final PolynomialFunction polFunction = new PolynomialFunction(parms);
        // getCoefficients() has high-order zeros stripped, so d has the true degree
        final double[] d = polFunction.polynomialDerivative().getCoefficients();

        double[] stationary = new double[0];
        if (d.length == 2 && d[1] != 0.0) {
            stationary = new double[] { -d[0] / d[1] };
        } else if (d.length == 3 && d[2] != 0.0) {
            final double discriminant = d[1] * d[1] - 4.0 * d[2] * d[0];
            if (discriminant >= 0.0) {
                final double sqrtDisc = Math.sqrt(discriminant);
                stationary = new double[] {
                        (-d[1] + sqrtDisc) / (2.0 * d[2]),
                        (-d[1] - sqrtDisc) / (2.0 * d[2]) };
            }
        }

        double bestX = minX;
        double bestY = polFunction.value(minX);
        final double yAtMaxX = polFunction.value(maxX);
        if (yAtMaxX > bestY) {
            bestX = maxX;
            bestY = yAtMaxX;
        }
        for (double x : stationary) {
            if (x > minX && x < maxX) {
                final double y = polFunction.value(x);
                if (y > bestY) {
                    bestX = x;
                    bestY = y;
                }
            }
        }
        return bestX;
    }

    /**
     * Find the index in the data series with an x value closest to the given
     * searchValue. On ties the lowest index wins.
     *
     * @param data data in XYSeries format (must not be empty)
     * @param searchValue x value that we try to get close to
     * @return index into data with x value closest to searchValue
     */
    public static int getIndex(XYSeries data, double searchValue) {
        int index = 0;
        double diff = dataDiff(data.getX(0), searchValue);
        for (int i = 1; i < data.getItemCount(); i++) {
            final double newVal = dataDiff(data.getX(i), searchValue);
            if (newVal < diff) {
                diff = newVal;
                index = i;
            }
        }
        return index;
    }

    /**
     * Helper function for getIndex: absolute distance between num and val.
     *
     * @param num value from the data series
     * @param val reference value
     * @return |num - val|
     */
    private static double dataDiff(Number num, double val) {
        return Math.abs(num.doubleValue() - val);
    }

    /**
     * Calculates a measure for the goodness of fit as defined here:
     * http://en.wikipedia.org/wiki/Coefficient_of_determination
     * R^2 = 1 - (SSres/SStot)
     * where
     *    SSres = SUM(i) (yi - fi)^2
     * and
     *    SStot = SUM(i) (yi - yavg)^2
     *
     * @param data input data (raw data that were fitted)
     * @param type function type used for fitting
     * @param parms function parameters derived in the fit
     * @return R^2. When all y values are identical SStot is 0, and the result
     *         is NaN or negative infinity
     */
    public static double getRSquare(XYSeries data, FunctionType type, double[] parms) {

        final double yAvg = getYAvg(data);
        double ssTot = 0.0;
        double ssRes = 0.0;
        for (int i = 0; i < data.getItemCount(); i++) {
            final double y = data.getY(i).doubleValue();
            final double f = getFunctionValue(data.getX(i).doubleValue(), type, parms);
            ssTot += (y - yAvg) * (y - yAvg);
            ssRes += (y - f) * (y - f);
        }

        return 1.0 - (ssRes / ssTot);
    }

    /**
     * Returns the average of the ys in a XYSeries
     * @param data input data
     * @return y average (NaN for an empty series)
     */
    public static double getYAvg(XYSeries data) {
        double sum = 0;
        for (int i = 0; i < data.getItemCount(); i++) {
            sum += data.getY(i).doubleValue();
        }
        return sum / data.getItemCount();
    }

    /**
     * Calculate the y value for a given function and x value.
     * Throws an IllegalArgumentException if the parms do not match the function.
     * For NoFit there is no function and xValue is returned unchanged.
     *
     * @param xValue xValue to be used in the function
     * @param type function type
     * @param parms function parameters (e.g., as returned from the fit function)
     * @return function value at xValue
     */
    public static double getFunctionValue(double xValue, FunctionType type, double[] parms) {
        switch (type) {
        case NoFit:
            return xValue;
        case Pol1:
        case Pol2:
        case Pol3:
            checkParms(type, parms);
            return new PolynomialFunction(parms).value(xValue);
        case Gaussian: {
            checkParms(type, parms);
            // Gaussian.Parametric takes (norm, mean, sigma); parms[3] is the offset
            final double[] gparms = new double[3];
            System.arraycopy(parms, 0, gparms, 0, 3);
            return new Gaussian.Parametric().value(xValue, gparms) + parms[3];
        }
        default:
            return 0.0;
        }
    }

    /**
     * Verifies that parms has the number of elements the function type needs.
     *
     * @throws IllegalArgumentException if parms is null or has the wrong size
     */
    private static void checkParms(FunctionType type, double[] parms) {
        final int expected;
        switch (type) {
        case Pol1:
            expected = 2;
            break;
        case Pol2:
            expected = 3;
            break;
        case Pol3:
        case Gaussian:
            expected = 4;
            break;
        default:
            return;
        }
        if (parms == null || parms.length != expected) {
            throw new IllegalArgumentException("Needs a double[] of size " + expected);
        }
    }

    /**
     * @param key function type
     * @return the display name of the function type
     */
    public static String getFunctionTypeAsString(FunctionType key) {
        switch (key) {
        case NoFit:
            return NOFIT;
        case Pol1:
            return POL1;
        case Pol2:
            return POL2;
        case Pol3:
            return POL3;
        case Gaussian:
            return GAUSSIAN;
        default:
            return "";
        }
    }

    /**
     * Inverse of getFunctionTypeAsString.
     *
     * @param key display name of a function type (may be null)
     * @return the matching FunctionType, or NoFit when key is unknown or null
     */
    public static FunctionType getFunctionTypeAsType(String key) {
        if (NOFIT.equals(key)) {
            return FunctionType.NoFit;
        }
        if (POL1.equals(key)) {
            return FunctionType.Pol1;
        }
        if (POL2.equals(key)) {
            return FunctionType.Pol2;
        }
        if (POL3.equals(key)) {
            return FunctionType.Pol3;
        }
        if (GAUSSIAN.equals(key)) {
            return FunctionType.Gaussian;
        }
        return FunctionType.NoFit;
    }

    /**
     * @return display names of all function types, in enum order
     */
    public static String[] getFunctionTypes() {
        return new String[] { NOFIT, POL1, POL2, POL3, GAUSSIAN };
    }

}