uk.ac.diamond.scisoft.analysis.diffraction.FittingUtils.java Source code

Java tutorial

Introduction

Here is the source code for uk.ac.diamond.scisoft.analysis.diffraction.FittingUtils.java

Source

/*-
 * Copyright 2014 Diamond Light Source Ltd.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 */

package uk.ac.diamond.scisoft.analysis.diffraction;

import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.SimplePointChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateFunctionPenaltyAdapter;
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.MultiDirectionalSimplex;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer;
import org.apache.commons.math3.random.Well19937c;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Utilities for fitting
 */
public class FittingUtils {
    private static Logger logger = LoggerFactory.getLogger(FittingUtils.class);

    public static Long seed = null;

    private static final int MAX_ITER = 10000;
    private static final int MAX_EVAL = 1000000;
    private static final double REL_TOL = 1e-7;
    private static final double ABS_TOL = 1e-15;

    public enum Optimizer {
        Simplex, CMAES, BOBYQA
    }

    static Optimizer optimizer = Optimizer.CMAES;

    /**
     * @param n
     * @return optimizer
     */
    public static MultivariateOptimizer createOptimizer(int n) {
        return createOptimizer(optimizer, n);
    }

    /**
     * @param opt
     * @param n
     * @return optimizer
     */
    public static MultivariateOptimizer createOptimizer(Optimizer opt, int n) {
        switch (opt) {
        case BOBYQA:
            return new BOBYQAOptimizer(n + 2);
        case CMAES:
            return new CMAESOptimizer(MAX_ITER, 0., true, 0, 10,
                    seed == null ? new Well19937c() : new Well19937c(seed), false,
                    new SimplePointChecker<PointValuePair>(REL_TOL, ABS_TOL));
        case Simplex:
        default:
            return new SimplexOptimizer(new SimplePointChecker<PointValuePair>(REL_TOL * 1e3, ABS_TOL * 1e3));
        }
    }

    /**
     * Optimize given function
     * @param f
     * @param opt
     * @param min
     * @return residual
     */
    public static double optimize(FitFunction f, MultivariateOptimizer opt, double min) {
        double res = Double.NaN;
        try {
            PointValuePair result;

            if (opt instanceof BOBYQAOptimizer) {
                result = opt.optimize(new InitialGuess(f.getInitial()), GoalType.MINIMIZE, new ObjectiveFunction(f),
                        new MaxEval(MAX_EVAL), f.getBounds());
            } else if (opt instanceof CMAESOptimizer) {
                int p = (int) Math.ceil(4 + Math.log(f.getN())) + 1;
                logger.trace("Population size: {}", p);
                result = opt.optimize(new InitialGuess(f.getInitial()), GoalType.MINIMIZE, new ObjectiveFunction(f),
                        new CMAESOptimizer.Sigma(f.getSigma()), new CMAESOptimizer.PopulationSize(p),
                        new MaxEval(MAX_EVAL), f.getBounds());
            } else {
                int n = f.getN();
                double offset = 1e12;
                double[] scale = new double[n];
                for (int i = 0; i < n; i++) {
                    scale[i] = offset * 0.25;
                }
                SimpleBounds bnds = f.getBounds();
                MultivariateFunctionPenaltyAdapter of = new MultivariateFunctionPenaltyAdapter(f, bnds.getLower(),
                        bnds.getUpper(), offset, scale);
                result = opt.optimize(new InitialGuess(f.getInitial()), GoalType.MINIMIZE,
                        new ObjectiveFunction(of), new MaxEval(MAX_EVAL), new MultiDirectionalSimplex(n));
                //            new NelderMeadSimplex(n));
            }

            // logger.info("Q-space fit: rms = {}, x^2 = {}", opt.getRMS(), opt.getChiSquare());
            double ires = f.value(opt.getStartPoint());
            logger.trace("Residual: {} from {}", result.getValue(), ires);
            res = result.getValue();
            if (res < min)
                f.setParameters(result.getPoint());
            logger.trace("Used {} evals and {} iters", opt.getEvaluations(), opt.getIterations());
            //         System.err.printf("Used %d evals and %d iters\n", opt.getEvaluations(), opt.getIterations());
            // logger.info("Q-space fit: rms = {}, x^2 = {}", opt.getRMS(), opt.getChiSquare());
        } catch (IllegalArgumentException e) {
            logger.error("Start point has wrong dimension", e);
            // should not happen!
        } catch (TooManyEvaluationsException e) {
            throw new IllegalArgumentException("Could not fit as optimizer did not converge");
            //            logger.error("Convergence problem: max iterations ({}) exceeded", opt.getMaxIterations());
        }

        return res;
    }

    /**
     * Basic fit function interface
     */
    public interface FitFunction extends MultivariateFunction {
        /**
         * Set stored parameters
         * @param arg
         */
        public void setParameters(double[] arg);

        /**
         * @return stored parameters
         */
        public double[] getParameters();

        /**
         * @return standard deviations which may be used for sampling parameter space
         */
        public double[] getSigma();

        /**
         * @return bounds on parameters
         */
        public SimpleBounds getBounds();

        /**
         * @return initial parameters
         */
        public double[] getInitial();

        /**
         * Set initial parameters
         * @param init
         */
        public void setInitial(double... init);

        /**
         * @return number of parameters
         */
        public int getN();
    }

}