com.anhth12.nn.conf.NeuralNetworkConfiguration.java Source code

Java tutorial

Introduction

Here is the source code for com.anhth12.nn.conf.NeuralNetworkConfiguration.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.nn.conf;

import com.anhth12.models.featuredetectors.rbm.RBM;
import com.anhth12.nn.api.LayerFactory;
import com.anhth12.nn.api.OptimizationAlgorithm;
import com.anhth12.optimize.api.IterationListener;
import com.anhth12.optimize.api.StepFunction;
import com.anhth12.optimize.stepfunction.GradientStepFunction;
import com.anhth12.weights.WeightInit;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.activation.Activations;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * Settings for one network layer: optimizer hyper-parameters (learning rate, momentum, AdaGrad,
 * L2/regularization, dropout, sparsity), weight initialization, activation and loss functions, and
 * RBM unit types.
 *
 * <p>Instances are usually created with {@link Builder}. {@link Builder#list(int)} returns a
 * {@link ListBuilder} that builds one configuration per layer and bundles them into a
 * {@link MultiLayerConfiguration}. Use {@link #clone()} (or the copy constructor) to obtain a copy
 * whose mutable collections are independent of the original.
 *
 * @author anhth12
 */
public class NeuralNetworkConfiguration implements Serializable, Cloneable {

    // Defaults use double literals: the former float literals (1e-1f, 0.3f) widened to
    // 0.10000000149011612 and 0.30000001192092896 instead of 0.1 and 0.3.
    private double sparsity = 0;
    private boolean useAdaGrad = true;
    private double lr = 1e-1;
    // corruption level
    private double corruptionLevel = 0.3;
    private int numIterations;
    double momentum = 0.5;
    double l2 = 0.0;
    boolean useRegularization = false;
    // NOTE(review): presumably maps an iteration number to the momentum used after it — confirm
    // against the optimizer that reads it.
    Map<Integer, Double> momentumAfter = new HashMap<>();
    int resetAdaGradIterations = -1;
    double dropOut = 0;
    WeightInit weightInit = WeightInit.VI;
    OptimizationAlgorithm optimizationAlgo = OptimizationAlgorithm.CONJUGATE_GRADIENT;
    LossFunctions.LossFunction lossFunction = LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY;
    boolean concateBias = false;
    protected boolean constrainGradientToUnitNorm = false;
    // Transient fields are skipped by Java serialization and come back null after
    // deserialization (their field initializers do not run in that case).
    protected transient RandomGenerator rng;
    // Distribution for weight initialization; null after the no-arg constructor, filled with a
    // default by the copy constructor.
    RealDistribution dist;
    protected transient Collection<IterationListener> listeners;
    protected transient StepFunction stepFunction = new GradientStepFunction();
    protected transient LayerFactory layerFactory;

    List<String> gradientList = new ArrayList<>();
    int nOut;
    int nIn;

    ActivationFunction activationFunction;
    private RBM.VisibleUnit visibleUnit = RBM.VisibleUnit.BINARY;
    private RBM.HiddenUnit hiddenUnit = RBM.HiddenUnit.BINARY;
    protected int k = 1;
    int batchSize;

    /**
     * Registers {@code variable} in the gradient list unless it is already present. Insertion
     * order is preserved.
     *
     * @param variable name of the variable to add
     */
    public void addVariable(String variable) {
        if (!gradientList.contains(variable)) {
            gradientList.add(variable);
        }
    }

    /** Creates a configuration with all fields at their declared defaults ({@code dist} is null). */
    public NeuralNetworkConfiguration() {
    }

    /**
     * Copy constructor.
     *
     * <p>The gradient list, the momentum schedule and the listener collection are copied, so later
     * changes to either configuration's collections do not affect the other. The random generator,
     * distribution, activation function, step function and layer factory are shared by reference.
     * If {@code conf} has no distribution, a normal distribution (mean 0, sd 0.01) is created.
     *
     * @param conf configuration to copy; must not be null
     */
    public NeuralNetworkConfiguration(NeuralNetworkConfiguration conf) {
        this.layerFactory = conf.layerFactory;
        this.batchSize = conf.batchSize;
        this.sparsity = conf.sparsity;
        this.useAdaGrad = conf.useAdaGrad;
        this.lr = conf.lr;
        this.momentum = conf.momentum;
        this.l2 = conf.l2;
        this.numIterations = conf.numIterations;
        this.k = conf.k;
        this.corruptionLevel = conf.corruptionLevel;
        this.visibleUnit = conf.visibleUnit;
        this.hiddenUnit = conf.hiddenUnit;
        this.useRegularization = conf.useRegularization;
        this.resetAdaGradIterations = conf.resetAdaGradIterations;
        this.dropOut = conf.dropOut;
        this.weightInit = conf.weightInit;
        this.optimizationAlgo = conf.optimizationAlgo;
        this.lossFunction = conf.lossFunction;
        this.concateBias = conf.concateBias;
        this.constrainGradientToUnitNorm = conf.constrainGradientToUnitNorm;
        this.rng = conf.rng;
        this.dist = conf.dist;
        this.nIn = conf.nIn;
        this.nOut = conf.nOut;
        this.activationFunction = conf.activationFunction;

        // Copy mutable collections so the clone does not alias the source's state.
        if (conf.momentumAfter != null) {
            this.momentumAfter = new HashMap<>(conf.momentumAfter);
        } else {
            this.momentumAfter = null;
        }
        // These were previously not copied at all, so clones silently lost them.
        if (conf.gradientList != null) {
            this.gradientList = new ArrayList<>(conf.gradientList);
        }
        if (conf.listeners != null) {
            this.listeners = new ArrayList<>(conf.listeners);
        }
        // Keep the default step function when the source has none (e.g. after deserialization).
        if (conf.stepFunction != null) {
            this.stepFunction = conf.stepFunction;
        }

        if (this.dist == null) {
            // Fall back to a fixed-seed generator (same seed as the Builder default) rather than
            // handing the distribution a null generator, which it would need for sampling.
            RandomGenerator distRng = this.rng != null ? this.rng : new MersenneTwister(123);
            this.dist = new NormalDistribution(distRng, 0, .01,
                    NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
        }
    }

    /**
     * Returns a copy of this configuration created via the copy constructor.
     *
     * @return a new configuration with the same settings
     */
    @Override
    public NeuralNetworkConfiguration clone() {
        return new NeuralNetworkConfiguration(this);
    }

    // ---------------------------------------------------------------------------------------
    // Getters and setters
    // ---------------------------------------------------------------------------------------

    /** Returns the live gradient-variable list (not a copy). */
    public List<String> getGradientList() {
        return gradientList;
    }

    public void setGradientList(List<String> gradientList) {
        this.gradientList = gradientList;
    }

    public int getnOut() {
        return nOut;
    }

    public void setnOut(int nOut) {
        this.nOut = nOut;
    }

    public int getnIn() {
        return nIn;
    }

    public void setnIn(int nIn) {
        this.nIn = nIn;
    }

    public WeightInit getWeightInit() {
        return weightInit;
    }

    public void setWeightInit(WeightInit weightInit) {
        this.weightInit = weightInit;
    }

    public ActivationFunction getActivationFunction() {
        return activationFunction;
    }

    public void setActivationFunction(ActivationFunction activationFunction) {
        this.activationFunction = activationFunction;
    }

    public RealDistribution getDist() {
        return dist;
    }

    public void setDist(RealDistribution dist) {
        this.dist = dist;
    }

    public boolean isConcateBias() {
        return concateBias;
    }

    public void setConcateBias(boolean concateBias) {
        this.concateBias = concateBias;
    }

    public int getBatchSize() {
        return batchSize;
    }

    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    public LossFunctions.LossFunction getLossFunction() {
        return lossFunction;
    }

    public void setLossFunction(LossFunctions.LossFunction lossFunction) {
        this.lossFunction = lossFunction;
    }

    public double getL2() {
        return l2;
    }

    public void setL2(double l2) {
        this.l2 = l2;
    }

    public boolean isUseRegularization() {
        return useRegularization;
    }

    public void setUseRegularization(boolean useRegularization) {
        this.useRegularization = useRegularization;
    }

    public double getDropOut() {
        return dropOut;
    }

    public void setDropOut(double dropOut) {
        this.dropOut = dropOut;
    }

    public int getResetAdaGradIterations() {
        return resetAdaGradIterations;
    }

    public void setResetAdaGradIterations(int resetAdaGradIterations) {
        this.resetAdaGradIterations = resetAdaGradIterations;
    }

    public double getMomentum() {
        return momentum;
    }

    public void setMomentum(double momentum) {
        this.momentum = momentum;
    }

    /** Returns the live momentum schedule map (not a copy); may be null if set to null. */
    public Map<Integer, Double> getMomentumAfter() {
        return momentumAfter;
    }

    public void setMomentumAfter(Map<Integer, Double> momentumAfter) {
        this.momentumAfter = momentumAfter;
    }

    public boolean isUseAdaGrad() {
        return useAdaGrad;
    }

    public void setUseAdaGrad(boolean useAdaGrad) {
        this.useAdaGrad = useAdaGrad;
    }

    public double getLr() {
        return lr;
    }

    public void setLr(double lr) {
        this.lr = lr;
    }

    public int getNumIterations() {
        return numIterations;
    }

    public void setNumIterations(int numIterations) {
        this.numIterations = numIterations;
    }

    public OptimizationAlgorithm getOptimizationAlgo() {
        return optimizationAlgo;
    }

    public void setOptimizationAlgo(OptimizationAlgorithm optimizationAlgo) {
        this.optimizationAlgo = optimizationAlgo;
    }

    /** Returns the step function; null after Java deserialization because the field is transient. */
    public StepFunction getStepFunction() {
        return stepFunction;
    }

    public void setStepFunction(StepFunction stepFunction) {
        this.stepFunction = stepFunction;
    }

    public int getK() {
        return k;
    }

    public void setK(int k) {
        this.k = k;
    }

    public double getSparsity() {
        return sparsity;
    }

    public void setSparsity(double sparsity) {
        this.sparsity = sparsity;
    }

    public RBM.VisibleUnit getVisibleUnit() {
        return visibleUnit;
    }

    public void setVisibleUnit(RBM.VisibleUnit visibleUnit) {
        this.visibleUnit = visibleUnit;
    }

    public RBM.HiddenUnit getHiddenUnit() {
        return hiddenUnit;
    }

    public void setHiddenUnit(RBM.HiddenUnit hiddenUnit) {
        this.hiddenUnit = hiddenUnit;
    }

    public RandomGenerator getRng() {
        return rng;
    }

    public void setRng(RandomGenerator rng) {
        this.rng = rng;
    }

    /**
     * Returns the iteration listeners, lazily creating (and storing) an empty mutable list when
     * none have been set, so the result is never null.
     *
     * @return the live listener collection
     */
    public Collection<IterationListener> getListeners() {
        if (listeners == null) {
            listeners = new ArrayList<>();
        }
        return listeners;
    }

    public void setListeners(Collection<IterationListener> listeners) {
        this.listeners = listeners;
    }

    public double getCorruptionLevel() {
        return corruptionLevel;
    }

    public void setCorruptionLevel(double corruptionLevel) {
        this.corruptionLevel = corruptionLevel;
    }

    public LayerFactory getLayerFactory() {
        return layerFactory;
    }

    public void setLayerFactory(LayerFactory layerFactory) {
        this.layerFactory = layerFactory;
    }

    public boolean isConstrainGradientToUnitNorm() {
        return constrainGradientToUnitNorm;
    }

    public void setConstrainGradientToUnitNorm(boolean constrainGradientToUnitNorm) {
        this.constrainGradientToUnitNorm = constrainGradientToUnitNorm;
    }

    /**
     * Configures {@code conf} as a classifier: softmax activation, multi-class cross-entropy loss
     * and zero weight initialization.
     *
     * @param conf configuration to modify in place
     * @param rows whether to use row-wise softmax ({@code true}) or column-wise softmax
     */
    public static void setClassifier(NeuralNetworkConfiguration conf, boolean rows) {
        conf.setActivationFunction(rows ? Activations.softMaxRows() : Activations.softmax());
        conf.setLossFunction(LossFunctions.LossFunction.MCXENT);
        conf.setWeightInit(WeightInit.ZERO);
    }

    /**
     * Configures {@code conf} as a classifier using row-wise softmax.
     *
     * @param conf configuration to modify in place
     */
    public static void setClassifier(NeuralNetworkConfiguration conf) {
        setClassifier(conf, true);
    }

    /** Callback used by {@link ListBuilder#override(ConfOverride)} to customize each layer's builder. */
    public static interface ConfOverride {

        /**
         * @param i zero-based layer index
         * @param builder builder for that layer; mutate it in place
         */
        void override(int i, Builder builder);
    }

    /** Fluent builder for {@link NeuralNetworkConfiguration}. */
    public static class Builder {

        private int k = 1;
        private int kernel = 5;
        private double corruptionLevel = 0.3;
        private double sparsity = 0;
        private boolean useAdaGrad = true;
        private double lr = 1e-1;
        private double momentum = 0.5;
        private double l2 = 0;
        private boolean useRegularization = false;
        private Map<Integer, Double> momentumAfter;
        private int resetAdaGradIterations = -1;
        private double dropOut = 0;
        private WeightInit weightInit = WeightInit.VI;
        private OptimizationAlgorithm optimizationAlgo = OptimizationAlgorithm.CONJUGATE_GRADIENT;
        private boolean concatBiases = false;
        private boolean constrainGradientToUnitNorm = false;
        private RandomGenerator rng = new MersenneTwister(123);
        private RealDistribution dist = new NormalDistribution(rng, 0, .01,
                NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
        private LossFunctions.LossFunction lossFunction = LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY;
        private int nIn;
        private int nOut;
        private ActivationFunction activationFunction = Activations.sigmoid();
        private RBM.VisibleUnit visibleUnit = RBM.VisibleUnit.BINARY;
        private RBM.HiddenUnit hiddenUnit = RBM.HiddenUnit.BINARY;
        private int numIterations = 1000;
        private Collection<IterationListener> listeners;
        private StepFunction stepFunction = new GradientStepFunction();
        private LayerFactory layerFactory;
        private int batchSize = 0;

        // The settings below are accepted and carried through clone(), but build() does not yet
        // transfer them to the resulting configuration.
        private boolean applySparsity = false;
        private int renderWeightsEveryNumEpochs = -1;
        private long seed = 123;
        private int[] weightShape;
        private int[] filterSize = { 2, 2, 2, 2 };
        private int[] featureMapSize = { 2, 2 };
        private int numInFeatureMaps = 2;
        // subsampling layers
        private int[] stride = { 2, 2 };

        public Builder batchSize(int batchSize) {
            this.batchSize = batchSize;
            return this;
        }

        public Builder kernel(int kernel) {
            this.kernel = kernel;
            return this;
        }

        public Builder layerFactory(LayerFactory layerFactory) {
            this.layerFactory = layerFactory;
            return this;
        }

        public Builder stepFunction(StepFunction stepFunction) {
            this.stepFunction = stepFunction;
            return this;
        }

        /**
         * Starts a multi-layer configuration with {@code size} layers, each initialized from a
         * copy of this builder.
         *
         * @param size number of layers; must be at least 2
         * @return a list builder over {@code size} independent builder copies
         * @throws IllegalArgumentException if {@code size < 2}
         */
        public ListBuilder list(int size) {
            if (size < 2) {
                throw new IllegalArgumentException("Number of layers must be > 1");
            }

            List<Builder> list = new ArrayList<>(size);
            for (int i = 0; i < size; i++) {
                list.add(clone());
            }
            return new ListBuilder(list);
        }

        /**
         * Returns a copy of this builder.
         *
         * <p>All settings are copied, including {@code k} and {@code corruptionLevel}, which the
         * previous implementation dropped. The listener collection, momentum schedule and
         * shape arrays are copied so per-layer changes do not leak between copies. The random
         * generator, distribution, step function and layer factory are shared.
         *
         * @return an independent builder with the same settings
         */
        @Override
        public Builder clone() {
            Builder copy = new Builder();
            copy.k = k;
            copy.kernel = kernel;
            copy.corruptionLevel = corruptionLevel;
            copy.sparsity = sparsity;
            copy.useAdaGrad = useAdaGrad;
            copy.lr = lr;
            copy.momentum = momentum;
            copy.l2 = l2;
            copy.useRegularization = useRegularization;
            copy.momentumAfter = momentumAfter == null ? null : new HashMap<>(momentumAfter);
            copy.resetAdaGradIterations = resetAdaGradIterations;
            copy.dropOut = dropOut;
            copy.weightInit = weightInit;
            copy.optimizationAlgo = optimizationAlgo;
            copy.concatBiases = concatBiases;
            copy.constrainGradientToUnitNorm = constrainGradientToUnitNorm;
            copy.rng = rng;
            copy.dist = dist;
            copy.lossFunction = lossFunction;
            copy.nIn = nIn;
            copy.nOut = nOut;
            copy.activationFunction = activationFunction;
            copy.visibleUnit = visibleUnit;
            copy.hiddenUnit = hiddenUnit;
            copy.numIterations = numIterations;
            if (listeners != null) {
                copy.listeners = new ArrayList<>(listeners);
            }
            copy.stepFunction = stepFunction;
            copy.layerFactory = layerFactory;
            copy.batchSize = batchSize;
            copy.applySparsity = applySparsity;
            copy.renderWeightsEveryNumEpochs = renderWeightsEveryNumEpochs;
            copy.seed = seed;
            copy.weightShape = weightShape == null ? null : weightShape.clone();
            copy.filterSize = filterSize == null ? null : filterSize.clone();
            copy.featureMapSize = featureMapSize == null ? null : featureMapSize.clone();
            copy.numInFeatureMaps = numInFeatureMaps;
            copy.stride = stride == null ? null : stride.clone();
            return copy;
        }

        /**
         * Adds a single iteration listener, creating the listener collection if needed.
         *
         * @param listener listener to add
         * @return this builder
         */
        public Builder iterationListener(IterationListener listener) {
            if (listeners == null) {
                listeners = new ArrayList<>();
            }
            listeners.add(listener);
            return this;
        }

        public Builder listeners(Collection<IterationListener> listeners) {
            this.listeners = listeners;
            return this;
        }

        public Builder featureMapSize(int[] featureMapSize) {
            this.featureMapSize = featureMapSize;
            return this;
        }

        public Builder stride(int[] stride) {
            this.stride = stride;
            return this;
        }

        /**
         * Sets the filter size; a null argument is ignored.
         *
         * @param filterSize exactly four dimensions
         * @return this builder
         * @throws IllegalArgumentException if {@code filterSize} is non-null and not of length 4
         */
        public Builder filterSize(int... filterSize) {
            if (filterSize == null) {
                return this;
            }
            if (filterSize.length != 4) {
                throw new IllegalArgumentException(
                        "Invalid filter size: must be length 4 but was " + filterSize.length);
            }
            this.filterSize = filterSize;
            return this;
        }

        public Builder weightShape(int[] weightShape) {
            this.weightShape = weightShape;
            return this;
        }

        public Builder iterations(int numIterations) {
            this.numIterations = numIterations;
            return this;
        }

        public Builder dist(RealDistribution dist) {
            this.dist = dist;
            return this;
        }

        public Builder sparsity(double sparsity) {
            this.sparsity = sparsity;
            return this;
        }

        public Builder useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        public Builder learningRate(double lr) {
            this.lr = lr;
            return this;
        }

        public Builder momentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        public Builder k(int k) {
            this.k = k;
            return this;
        }

        public Builder corruptionLevel(double corruptionLevel) {
            this.corruptionLevel = corruptionLevel;
            return this;
        }

        public Builder momentumAfter(Map<Integer, Double> momentumAfter) {
            this.momentumAfter = momentumAfter;
            return this;
        }

        /** Alias of {@link #resetAdaGradIterations(int)}. */
        public Builder adagradResetIterations(int resetAdaGradIterations) {
            this.resetAdaGradIterations = resetAdaGradIterations;
            return this;
        }

        public Builder dropOut(double dropOut) {
            this.dropOut = dropOut;
            return this;
        }

        public Builder applySparsity(boolean applySparsity) {
            this.applySparsity = applySparsity;
            return this;
        }

        public Builder weightInit(WeightInit weightInit) {
            this.weightInit = weightInit;
            return this;
        }

        public Builder render(int renderWeightsEveryNumEpochs) {
            this.renderWeightsEveryNumEpochs = renderWeightsEveryNumEpochs;
            return this;
        }

        public Builder concatBiases(boolean concatBiases) {
            this.concatBiases = concatBiases;
            return this;
        }

        public Builder rng(RandomGenerator rng) {
            this.rng = rng;
            return this;
        }

        /** Records a seed value. Note: it does not reseed {@link #rng(RandomGenerator)}. */
        public Builder seed(long seed) {
            this.seed = seed;
            return this;
        }

        public Builder l2(double l2) {
            this.l2 = l2;
            return this;
        }

        public Builder regularization(boolean useRegularization) {
            this.useRegularization = useRegularization;
            return this;
        }

        public Builder resetAdaGradIterations(int resetAdaGradIterations) {
            this.resetAdaGradIterations = resetAdaGradIterations;
            return this;
        }

        public Builder optimizationAlgo(OptimizationAlgorithm optimizationAlgo) {
            this.optimizationAlgo = optimizationAlgo;
            return this;
        }

        public Builder lossFunction(LossFunctions.LossFunction lossFunction) {
            this.lossFunction = lossFunction;
            return this;
        }

        public Builder constrainGradientToUnitNorm(boolean constrainGradientToUnitNorm) {
            this.constrainGradientToUnitNorm = constrainGradientToUnitNorm;
            return this;
        }

        public Builder nIn(int nIn) {
            this.nIn = nIn;
            return this;
        }

        public Builder nOut(int nOut) {
            this.nOut = nOut;
            return this;
        }

        public Builder activationFunction(ActivationFunction activationFunction) {
            this.activationFunction = activationFunction;
            return this;
        }

        public Builder visibleUnit(RBM.VisibleUnit visibleUnit) {
            this.visibleUnit = visibleUnit;
            return this;
        }

        public Builder hiddenUnit(RBM.HiddenUnit hiddenUnit) {
            this.hiddenUnit = hiddenUnit;
            return this;
        }

        /**
         * Creates a configuration from the current builder settings.
         *
         * <p>{@code useAdaGrad} and {@code batchSize} are now honored; previously build() always
         * set AdaGrad on and left the batch size at 0. If no momentum schedule was set, the
         * configuration keeps its default empty map instead of receiving null.
         *
         * @return a new configuration
         */
        public NeuralNetworkConfiguration build() {
            NeuralNetworkConfiguration ret = new NeuralNetworkConfiguration();
            ret.activationFunction = activationFunction;
            ret.batchSize = batchSize;
            ret.concateBias = concatBiases;
            ret.constrainGradientToUnitNorm = constrainGradientToUnitNorm;
            ret.corruptionLevel = corruptionLevel;
            ret.dist = dist;
            ret.dropOut = dropOut;
            ret.hiddenUnit = hiddenUnit;
            ret.k = k;
            ret.l2 = l2;
            ret.layerFactory = layerFactory;
            ret.listeners = listeners;
            ret.lossFunction = lossFunction;
            ret.lr = lr;
            ret.momentum = momentum;
            if (momentumAfter != null) {
                // Copy so later edits to the builder's map do not alter built configurations.
                ret.momentumAfter = new HashMap<>(momentumAfter);
            }
            ret.nIn = nIn;
            ret.nOut = nOut;
            ret.numIterations = numIterations;
            ret.optimizationAlgo = optimizationAlgo;
            ret.resetAdaGradIterations = resetAdaGradIterations;
            ret.rng = rng;
            ret.sparsity = sparsity;
            ret.stepFunction = stepFunction;
            ret.useAdaGrad = useAdaGrad;
            ret.useRegularization = useRegularization;
            ret.visibleUnit = visibleUnit;
            ret.weightInit = weightInit;
            return ret;
        }
    }

    /** Builds a {@link MultiLayerConfiguration} from one {@link Builder} per layer. */
    public static class ListBuilder {

        private List<Builder> layerwise;
        private int[] hiddenLayerSizes;
        private boolean useDropConnect = false;
        private boolean pretrain = true;

        private Map<Integer, OutputPreprocessor> preProcessors = new HashMap<>();

        /**
         * @param list per-layer builders, in layer order
         */
        public ListBuilder(List<Builder> list) {
            this.layerwise = list;
        }

        /**
         * Retained for source compatibility only: this overload never registered anything (it
         * ignored its argument).
         *
         * @param override ignored
         * @return this builder
         * @deprecated use {@link #preProcessor(int, OutputPreprocessor)} or
         *     {@link #preProcessors(Map)}
         */
        @Deprecated
        public ListBuilder preProcessor(ConfOverride override) {
            return this;
        }

        /**
         * Registers an output preprocessor for the given layer, replacing any previous one.
         *
         * @param layer zero-based layer index
         * @param preProcessor preprocessor for that layer
         * @return this builder
         */
        public ListBuilder preProcessor(int layer, OutputPreprocessor preProcessor) {
            preProcessors.put(layer, preProcessor);
            return this;
        }

        /**
         * Replaces all registered output preprocessors with a copy of the given map.
         *
         * @param preProcessors layer index to preprocessor; null clears all
         * @return this builder
         */
        public ListBuilder preProcessors(Map<Integer, OutputPreprocessor> preProcessors) {
            if (preProcessors == null) {
                this.preProcessors = new HashMap<>();
            } else {
                this.preProcessors = new HashMap<>(preProcessors);
            }
            return this;
        }

        public ListBuilder pretrain(boolean pretrain) {
            this.pretrain = pretrain;
            return this;
        }

        public ListBuilder useDropConnect(boolean useDropConnect) {
            this.useDropConnect = useDropConnect;
            return this;
        }

        /**
         * Applies {@code override} to every layer's builder, passing the zero-based layer index.
         *
         * @param override per-layer customization callback
         * @return this builder
         */
        public ListBuilder override(ConfOverride override) {
            for (int i = 0; i < layerwise.size(); i++) {
                override.override(i, layerwise.get(i));
            }
            return this;
        }

        public ListBuilder hiddenLayerSizes(int... hiddenLayerSizes) {
            this.hiddenLayerSizes = hiddenLayerSizes;
            return this;
        }

        /**
         * Builds every layer's configuration and assembles the multi-layer configuration.
         *
         * @return the multi-layer configuration
         * @throws IllegalStateException if hidden layer sizes were not set, or if the number of
         *     layers is not {@code hiddenLayerSizes.length + 1}
         */
        public MultiLayerConfiguration build() {
            if (hiddenLayerSizes == null) {
                throw new IllegalStateException("Hidden layer sizes must be set before build()");
            }
            if (layerwise.size() != hiddenLayerSizes.length + 1) {
                throw new IllegalStateException("Number of layers (" + layerwise.size()
                        + ") must be equal to the number of hidden layer sizes + 1 ("
                        + (hiddenLayerSizes.length + 1) + ")");
            }

            List<NeuralNetworkConfiguration> list = new ArrayList<>(layerwise.size());
            for (Builder layer : layerwise) {
                list.add(layer.build());
            }

            return new MultiLayerConfiguration.Builder().useDropConnect(useDropConnect).pretrain(pretrain)
                    .preProcessors(preProcessors).hiddenLayerSizes(hiddenLayerSizes).confs(list).build();
        }

    }

}