Java tutorial
/////////////////////////////////////////////////////////////////////////////// //FILE: Fitter.java //PROJECT: Micro-Manager //SUBSYSTEM: ASIdiSPIM plugin //----------------------------------------------------------------------------- // // AUTHOR: Nico Stuurman, Jon Daniels // // COPYRIGHT: University of California, San Francisco, & ASI, 2015, 2016 // // LICENSE: This file is distributed under the BSD license. // License text is included with the source distribution. // // This file is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty // of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. // // IN NO EVENT SHALL THE COPYRIGHT OWNER OR // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES. package org.micromanager.saim.fit; import org.apache.commons.math3.analysis.function.Gaussian; import org.apache.commons.math3.analysis.polynomials.PolynomialFunction; import org.apache.commons.math3.analysis.solvers.BracketingNthOrderBrentSolver; import org.apache.commons.math3.analysis.solvers.UnivariateSolver; import org.apache.commons.math3.fitting.PolynomialCurveFitter; import org.apache.commons.math3.fitting.WeightedObservedPoints; import org.jfree.data.xy.XYSeries; /** * * @author nico */ public class Fitter { private static final String NOFIT = "No fit"; private static final String POL1 = "Polynomial 1"; private static final String POL2 = "Polynomial 2"; private static final String POL3 = "Polynomial 3"; private static final String GAUSSIAN = "Gaussian"; public static enum FunctionType { NoFit, Pol1, Pol2, Pol3, Gaussian }; public static enum WeightMethod { Equal, Linear, Quadratic, Top50Linear, Top80Linear } /** * Utility to facilitate fitting data plotted in JFreeChart * Provide data in JFReeChart format (XYSeries), and retrieve univariate * function parameters that best fit (using least squares) the data. All data * points will be weighted equally. * * This version provides equal weighting of all datapoints * * @param data xy series in JFReeChart format * @param type one of the Fitter.FunctionType predefined functions * @param guess initial guess for the fit. The number and meaning of these parameters depends on the FunctionType. Implemented: Gaussian: 0: Normalization, 1: Mean 2: Sigma * @return array with parameters, whose meaning depends on the FunctionType. * Use the function getXYSeries to retrieve the XYDataset predicted * by this fit */ public static double[] fit(XYSeries data, FunctionType type, double[] guess) { return fit(data, type, guess, WeightMethod.Equal); } /** * Utility to facilitate fitting data plotted in JFreeChart * Provide data in JFReeChart format (XYSeries), and retrieve univariate * function parameters that best fit (using least squares) the data. All data * points will be weighted equally. * * Various weightmethods are implemented and can be selected using the * weightMethods parameter. * * @param data xy series in JFReeChart format * @param type one of the Fitter.FunctionType predefined functions * @param guess initial guess for the fit. The number and meaning of these parameters depends on the FunctionType. Implemented: Gaussian: 0: Normalization, 1: Mean 2: Sigma * @param weightMethod One of the methods in the WeightMethod enum * @return array with parameters, whose meaning depends on the FunctionType. * Use the function getXYSeries to retrieve the XYDataset predicted * by this fit */ public static double[] fit(XYSeries data, FunctionType type, double[] guess, WeightMethod weightMethod) { if (type == FunctionType.NoFit) { return null; } // create the commons math data object from the JFreeChart data object final WeightedObservedPoints obs = new WeightedObservedPoints(); // range is used in weigt calculations double range = data.getMaxY() - data.getMinY(); for (int i = 0; i < data.getItemCount(); i++) { // add weight based on y intensity and selected weight method double weight = 1.0; // used in Equal method if (weightMethod != WeightMethod.Equal) { double valueMinusMin = data.getY(i).doubleValue() - data.getMinY(); weight = valueMinusMin / range; switch (weightMethod) { case Equal: break; // weight is already linear case Quadratic: weight *= weight; break; case Top50Linear: if (valueMinusMin < (0.5 * range)) weight = 0.0; break; case Top80Linear: if (valueMinusMin < (0.8 * range)) weight = 0.0; break; } } obs.add(weight, data.getX(i).doubleValue(), data.getY(i).doubleValue()); } // Carry out the actual fit double[] result = null; switch (type) { case Pol1: final PolynomialCurveFitter fitter1 = PolynomialCurveFitter.create(1); result = fitter1.fit(obs.toList()); break; case Pol2: final PolynomialCurveFitter fitter2 = PolynomialCurveFitter.create(2); result = fitter2.fit(obs.toList()); break; case Pol3: final PolynomialCurveFitter fitter3 = PolynomialCurveFitter.create(3); result = fitter3.fit(obs.toList()); break; case Gaussian: GaussianWithOffsetCurveFitter gf = GaussianWithOffsetCurveFitter.create(); gf = gf.withMaxIterations(50); if (guess != null) { gf.withStartPoint(guess); } result = gf.fit(obs.toList()); } return result; } /** * Given a JFreeChart dataset and a commons math function, return a JFreeChart * dataset in which the original x values are now accompanied by the y values * predicted by the function * * @param data input JFreeChart data set * @param type one of the Fitter.FunctionType predefined functions * @param parms parameters describing the function. These need to match the * selected function or an IllegalArgumentEception will be thrown * * @return JFreeChart dataset with original x values and fitted y values. */ public static XYSeries getFittedSeries(XYSeries data, FunctionType type, double[] parms) { XYSeries result = new XYSeries((String) data.getKey() + "-Fit", false, true); double minRange = data.getMinX(); double maxRange = data.getMaxX(); double xStep = (maxRange - minRange) / (data.getItemCount() * 10); switch (type) { case NoFit: { try { XYSeries resCopy = data.createCopy(0, data.getItemCount() - 1); return resCopy; } catch (CloneNotSupportedException ex) { return null; } } case Pol1: case Pol2: case Pol3: checkParms(type, parms); PolynomialFunction polFunction = new PolynomialFunction(parms); for (int i = 0; i < data.getItemCount() * 10; i++) { double x = minRange + i * xStep; double y = polFunction.value(x); result.add(x, y); } break; case Gaussian: checkParms(type, parms); Gaussian.Parametric gf = new Gaussian.Parametric(); for (int i = 0; i < data.getItemCount() * 10; i++) { double x = minRange + i * xStep; double[] gparms = new double[3]; System.arraycopy(parms, 0, gparms, 0, 3); double y = gf.value(x, gparms) + parms[3]; result.add(x, y); } break; } return result; } /** * Finds the x value corresponding to the maximum function value within the * range of the provided data set. * * @param type one of the Fitter.FunctionType predefined functions * @param parms parameters describing the function. These need to match the * selected function or an IllegalArgumentEception will be thrown * @param data JFreeChart series, used to bracket the range in which the * maximum will be found * * @return x value corresponding to the maximum function value */ public static double getXofMaxY(XYSeries data, FunctionType type, double[] parms) { double xAtMax = 0.0; double minX = data.getMinX(); double maxX = data.getMaxX(); switch (type) { case NoFit: // find the position in data with the highest y value double highestScore = data.getY(0).doubleValue(); int highestIndex = 0; for (int i = 1; i < data.getItemCount(); i++) { double newVal = data.getY(i).doubleValue(); if (newVal > highestScore) { highestScore = newVal; highestIndex = i; } } return data.getX(highestIndex).doubleValue(); case Pol1: case Pol2: case Pol3: checkParms(type, parms); PolynomialFunction derivativePolFunction = (new PolynomialFunction(parms)).polynomialDerivative(); final double relativeAccuracy = 1.0e-12; final double absoluteAccuracy = 1.0e-8; final int maxOrder = 5; UnivariateSolver solver = new BracketingNthOrderBrentSolver(relativeAccuracy, absoluteAccuracy, maxOrder); xAtMax = solver.solve(100, derivativePolFunction, minX, maxX); break; case Gaussian: // for a Gaussian we can take the mean and be sure it is the maximum // note that this may be outside our range of X values, but // this will be caught by our sanity checks below xAtMax = parms[1]; } // sanity checks if (xAtMax > maxX) xAtMax = maxX; if (xAtMax < minX) xAtMax = minX; return xAtMax; } /** * Find the index in the data series with an x value closest to the given * searchValue * * @param data data in XYSeries format * @param searchValue x value that we try to get close to * @return index into data with x value closest to searhValue */ public static int getIndex(XYSeries data, double searchValue) { int index = 0; double diff = dataDiff(data.getX(0), searchValue); for (int i = 1; i < data.getItemCount(); i++) { double newVal = dataDiff(data.getX(i), searchValue); if (newVal < diff) { diff = newVal; index = i; } } return index; } /** * helper function for getIndex * * @param num * @param val * @return */ private static double dataDiff(Number num, double val) { double diff = num.doubleValue() - val; return Math.sqrt(diff * diff); } /** * Calculates a measure for the goodness of fit as defined here: * http://en.wikipedia.org/wiki/Coefficient_of_determination * R^2 = 1 - (SSres/SStot) * where * SSres = SUM(i) (yi - fi)^2 * end * SStot = SUM(i) (yi - yavg)^2 * * @param data input data (raw data that were fitted * @param type function type used for fitting * @param parms function parameters derived in the fit * @return */ public static double getRSquare(XYSeries data, FunctionType type, double[] parms) { // calculate SStot double yAvg = getYAvg(data); double ssTot = 0.0; for (int i = 0; i < data.getItemCount(); i++) { double y = data.getY(i).doubleValue(); ssTot += (y - yAvg) * (y - yAvg); } // calculate SSres double ssRes = 0.0; for (int i = 0; i < data.getItemCount(); i++) { double y = data.getY(i).doubleValue(); double f = getFunctionValue(data.getX(i).doubleValue(), type, parms); ssRes += (y - f) * (y - f); } return 1.0 - (ssRes / ssTot); } /** * Returns the average of the ys in a XYSeries * @param data input data * @return y average */ public static double getYAvg(XYSeries data) { double avg = 0; for (int i = 0; i < data.getItemCount(); i++) { avg += data.getY(i).doubleValue(); } avg = avg / data.getItemCount(); return avg; } /** * Calculate the y value for a given function and x value * Throws an IllegalArgumentException if the parms do not match the function * @param xValue xValue to be used in the function * @param type function type * @param parms function parameters (e.g., as returned from the fit function) * @return */ public static double getFunctionValue(double xValue, FunctionType type, double[] parms) { switch (type) { case NoFit: { return xValue; } case Pol1: case Pol2: case Pol3: checkParms(type, parms); PolynomialFunction polFunction = new PolynomialFunction(parms); return polFunction.value(xValue); case Gaussian: checkParms(type, parms); Gaussian.Parametric gf = new Gaussian.Parametric(); double[] parms2 = new double[3]; System.arraycopy(parms, 0, parms2, 0, 3); return gf.value(xValue, parms2) + parms[3]; } return 0.0; } private static void checkParms(FunctionType type, double[] parms) { switch (type) { case Pol1: if (parms.length != 2) { throw new IllegalArgumentException("Needs a double[] of size 2"); } break; case Pol2: if (parms.length != 3) { throw new IllegalArgumentException("Needs a double[] of size 3"); } break; case Pol3: if (parms.length != 4) { throw new IllegalArgumentException("Needs a double[] of size 4"); } break; case Gaussian: if (parms.length != 4) { throw new IllegalArgumentException("Needs a double[] of size 4"); } break; } } public static String getFunctionTypeAsString(FunctionType key) { switch (key) { case NoFit: return NOFIT; case Pol1: return POL1; case Pol2: return POL2; case Pol3: return POL2; case Gaussian: return GAUSSIAN; } return ""; } public static FunctionType getFunctionTypeAsType(String key) { if (key.equals(NOFIT)) return FunctionType.NoFit; if (key.equals(POL1)) return FunctionType.Pol1; if (key.equals(POL2)) return FunctionType.Pol2; if (key.equals(POL3)) return FunctionType.Pol3; if (key.equals(GAUSSIAN)) return FunctionType.Gaussian; return FunctionType.NoFit; } public static String[] getFunctionTypes() { return new String[] { NOFIT, POL1, POL2, POL3, GAUSSIAN }; } }