playground.thibautd.initialdemandgeneration.socnetgensimulated.framework.ModelIterator.java Source code

Java tutorial

Introduction

Here is the source code for playground.thibautd.initialdemandgeneration.socnetgensimulated.framework.ModelIterator.java

Source

/* *********************************************************************** *
 * project: org.matsim.*
 * ModelIterator.java
 *                                                                         *
 * *********************************************************************** *
 *                                                                         *
 * copyright       : (C) 2015 by the members listed in the COPYING,        *
 *                   LICENSE and WARRANTY file.                            *
 * email           : info at matsim dot org                                *
 *                                                                         *
 * *********************************************************************** *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *   See also COPYING, LICENSE and WARRANTY file                           *
 *                                                                         *
 * *********************************************************************** */
package playground.thibautd.initialdemandgeneration.socnetgensimulated.framework;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.log4j.Logger;

import playground.thibautd.initialdemandgeneration.socnetgen.framework.SnaUtils;
import org.matsim.contrib.socnetsim.framework.population.SocialNetwork;

/**
 * @author thibautd
 */
/**
 * Calibrates the two thresholds of a social network generation model (primary
 * threshold and secondary reduction) so that the generated network matches a
 * target average personal network size and a target clustering coefficient.
 * <p>
 * The search uses commons-math3's {@link CMAESOptimizer}. Each evaluation of
 * the objective runs the model once through a {@link ModelRunner}, so
 * evaluations are expensive. The fitness of a candidate is
 * {@code (|degree - targetDegree| / precisionDegree)^10
 * + (|clustering - targetClustering| / precisionClustering)^10}.
 * It can only be below 1 when both statistics are within their precision of
 * the target.
 *
 * @author thibautd
 */
public class ModelIterator {
    private static final Logger log = Logger.getLogger(ModelIterator.class);

    /** Number of threads passed to the clustering estimation. */
    private static final int N_THREADS = 4;

    /** Fitness below which the search is considered converged (see class documentation). */
    private static final double FITNESS_CONVERGENCE_THRESHOLD = 1;

    /** Absolute fitness change below which the search is considered converged. */
    private static final double MIN_FITNESS_CHANGE = 1E-8;

    private final double targetClustering;
    private final double targetDegree;

    private final double precisionClustering;
    private final double precisionDegree;

    private final int maxIterations;

    // Only used for diagnostics in the convergence checker (the names presumably
    // come from an earlier Powell-based version of this class).
    private final double powellMinAbsoluteChange;
    private final double powellMinRelativeChange;

    private final List<EvolutionListener> listeners = new ArrayList<>();

    /**
     * Reads targets, precisions, iteration limit and diagnostic thresholds from
     * the configuration, and registers a logging listener.
     *
     * @param config calibration configuration
     */
    public ModelIterator(final SocialNetworkGenerationConfigGroup config) {
        this.targetClustering = config.getTargetClustering();
        this.targetDegree = config.getTargetDegree();

        listeners.add(new EvolutionLogger());

        this.precisionClustering = config.getPrecisionClustering();
        this.precisionDegree = config.getPrecisionDegree();

        this.powellMinAbsoluteChange = config.getPowellMinAbsoluteChange();
        this.powellMinRelativeChange = config.getPowellMinRelativeChange();

        this.maxIterations = config.getMaxIterations();
    }

    /**
     * Runs the model once and stores the resulting average degree and estimated
     * clustering into {@code thresholds}.
     *
     * @param runner runs the generation model
     * @param thresholds thresholds to use; mutated to hold the resulting statistics
     * @return the generated network
     */
    private SocialNetwork generate(final ModelRunner runner, final Thresholds thresholds) {
        log.info("generate network for " + thresholds);
        final long start = System.currentTimeMillis();
        final SocialNetwork sn = runner.runModel(thresholds);

        thresholds.setResultingAverageDegree(SnaUtils.calcAveragePersonalNetworkSize(sn));
        thresholds.setResultingClustering(estimateClustering(sn));

        log.info("generation took " + (System.currentTimeMillis() - start) + " ms");
        return sn;
    }

    /**
     * Searches for thresholds whose generated network is close to the targets.
     * It then runs the model once more with the best thresholds found and
     * returns that network.
     *
     * @param runner runs the generation model for given thresholds
     * @param initialThresholds starting point of the search; only its primary
     *        threshold and secondary reduction are read
     * @return a network generated with the best thresholds found by the optimizer
     */
    public SocialNetwork iterateModelToTarget(final ModelRunner runner, final Thresholds initialThresholds) {
        // Arguments: maxIterations, stopFitness, isActiveCMA, diagonalOnly,
        // checkFeasableCount, random (fixed seed for reproducibility),
        // generateStatistics, checker.
        final MultivariateOptimizer optimizer = new CMAESOptimizer(maxIterations, 1E-9, true, 3, 50,
                new MersenneTwister(42), false, new Convergence());

        final double x = initialThresholds.getPrimaryThreshold();
        final double y = initialThresholds.getSecondaryReduction();

        final PointValuePair result = optimizer.optimize(GoalType.MINIMIZE, new MaxEval(maxIterations),
                new InitialGuess(new double[] { x, y }), new ObjectiveFunction(new Function(runner)),
                // initial step size per dimension: { primary threshold, secondary reduction }
                new CMAESOptimizer.Sigma(new double[] { 5, 500 }), new CMAESOptimizer.PopulationSize(7),
                // lower bounds: primary threshold unbounded, secondary reduction non-negative
                new SimpleBounds(new double[] { Double.NEGATIVE_INFINITY, 0 },
                        // upper bounds: none
                        new double[] { Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY })
        );

        final double[] best = result.getPoint();
        final Thresholds bestThresholds = new Thresholds(best[0], best[1]);
        // The network for the best point is not kept during the search, so it is regenerated here.
        final SocialNetwork bestSn = generate(runner, bestThresholds);

        log.info("best social network found for thresholds: " + bestThresholds);

        return bestSn;
    }

    /**
     * Estimates the clustering coefficient of {@code sn}, using the precision
     * configured for the calibration.
     */
    private double estimateClustering(final SocialNetwork sn) {
        // NOTE(review): the meaning of 1900 and 0.95 is defined by
        // SnaUtils.estimateClusteringCoefficient (0.95 looks like a confidence level) -- confirm there.
        return SnaUtils.estimateClusteringCoefficient(1900, N_THREADS, precisionClustering, 0.95, sn);
    }

    /**
     * Registers a listener notified after every evaluation of the objective function.
     *
     * @param l listener to add
     */
    public void addListener(final EvolutionListener l) {
        listeners.add(l);
    }

    /** Absolute distance between the target clustering and the clustering stored in {@code thresholds}. */
    private double distClustering(final Thresholds thresholds) {
        return Math.abs(targetClustering - thresholds.getResultingClustering());
    }

    /** Absolute distance between the target degree and the average degree stored in {@code thresholds}. */
    private double distDegree(final Thresholds thresholds) {
        return Math.abs(targetDegree - thresholds.getResultingAverageDegree());
    }

    /**
     * Callback invoked after each objective evaluation. It receives the evaluated
     * thresholds (with resulting statistics filled in) and their fitness.
     */
    public static interface EvolutionListener {
        public void handleMove(Thresholds m, double fitness);
    }

    /** Listener that logs every evaluation at INFO level. */
    private static class EvolutionLogger implements EvolutionListener {
        @Override
        public void handleMove(final Thresholds m, final double fitness) {
            log.info("generated network for " + m + " -> fitness=" + fitness);
        }
    }

    /**
     * Objective function minimized by the optimizer.
     * Argument {@code args} is {@code { primaryThreshold, secondaryReduction }}.
     */
    private class Function implements MultivariateFunction {
        private final ModelRunner runner;

        public Function(ModelRunner runner) {
            this.runner = runner;
        }

        @Override
        public double value(final double[] args) {
            final Thresholds thr = new Thresholds(args[0], args[1]);
            generate(runner, thr);

            // The high exponent keeps each term below 1 inside the precision
            // window and makes it grow steeply outside of it.
            final double fitness = Math.pow(distDegree(thr) / precisionDegree, 10)
                    + Math.pow(distClustering(thr) / precisionClustering, 10);

            for (final EvolutionListener l : listeners) {
                l.handleMove(thr, fitness);
            }

            return fitness;
        }
    }

    /**
     * Considers the search converged when either of the two compared fitness
     * values is below {@link #FITNESS_CONVERGENCE_THRESHOLD}, or when the
     * fitness barely changed.
     */
    private class Convergence implements ConvergenceChecker<PointValuePair> {
        @Override
        public boolean converged(final int i, final PointValuePair prev, final PointValuePair curr) {
            final double prevVal = prev.getValue();
            final double currVal = curr.getValue();

            // Both values are tested so the result does not depend on the order
            // in which the optimizer passes the points.
            // NOTE(review): commons-math3's CMAESOptimizer appears to pass the newer
            // point first; checking only the second argument would detect
            // convergence one (expensive) generation late -- confirm for the version in use.
            final boolean conv = Math.min(prevVal, currVal) < FITNESS_CONVERGENCE_THRESHOLD
                    || Math.abs(prevVal - currVal) < MIN_FITNESS_CHANGE;

            if (conv) {
                log.info("convergence checker considers convergence is reached.");
                return true;
            }

            final double abs = Math.abs(currVal - prevVal);
            final double rel = Math.abs((currVal - prevVal) / prevVal);

            if (abs <= powellMinAbsoluteChange || rel <= powellMinRelativeChange) {
                // Diagnostic only: these thresholds do not stop the CMA-ES search.
                log.warn("considered non converged, but objective change is below the configured minimum");
                log.warn("absolute change: " + abs);
                log.warn("relative change: " + rel);
            }

            return false;
        }
    }
}