Java tutorial

The listing below is the complete source of NeuralNetworkConfiguration, a serializable, cloneable holder for the hyperparameters of a single neural-network layer (learning rate, momentum, L2 regularization, dropout, RBM visible/hidden unit types, loss function, and so on). A fluent Builder constructs one configuration at a time; its list(int) method returns a ListBuilder that clones the builder once per layer and assembles the results into a MultiLayerConfiguration. A short usage sketch follows the listing.
/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.nn.conf;

import com.anhth12.models.featuredetectors.rbm.RBM;
import com.anhth12.nn.api.LayerFactory;
import com.anhth12.nn.api.OptimizationAlgorithm;
import com.anhth12.optimize.api.IterationListener;
import com.anhth12.optimize.api.StepFunction;
import com.anhth12.optimize.stepfunction.GradientStepFunction;
import com.anhth12.weights.WeightInit;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.activation.Activations;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 *
 * @author anhth12
 */
public class NeuralNetworkConfiguration implements Serializable, Cloneable {

    private double sparsity = 0f;
    private boolean useAdaGrad = true;
    private double lr = 1e-1f;
    // corruption level
    private double corruptionLevel = 0.3f;
    private int numIterations;
    double momentum = 0.5f;
    double l2 = 0.0f;
    boolean useRegularization = false;
    Map<Integer, Double> momentumAfter = new HashMap<>();
    int resetAdaGradIterations = -1;
    double dropOut = 0;
    // apply sparsity
    WeightInit weightInit = WeightInit.VI;
    OptimizationAlgorithm optimizationAlgo = OptimizationAlgorithm.CONJUGATE_GRADIENT;
    LossFunctions.LossFunction lossFunction = LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY;
    // renderWeightsEveryNumEpochs
    boolean concateBias = false;
    // constrainGradientToUnitNorm
    protected boolean constrainGradientToUnitNorm = false;
    // seeds
    protected transient RandomGenerator rng;
    RealDistribution dist;
    protected transient Collection<IterationListener> listeners;
    protected transient StepFunction stepFunction = new GradientStepFunction();
    protected transient LayerFactory layerFactory;
    List<String> gradientList = new ArrayList<>();
    int nOut;
    int nIn;
    ActivationFunction activationFunction;
    private RBM.VisibleUnit visibleUnit = RBM.VisibleUnit.BINARY;
    private RBM.HiddenUnit hiddenUnit = RBM.HiddenUnit.BINARY;
    protected int k = 1;
    int batchSize;

    public void addVariable(String variable) {
        if (!gradientList.contains(variable)) {
            gradientList.add(variable);
        }
    }

    public NeuralNetworkConfiguration() {
    }

    public NeuralNetworkConfiguration(NeuralNetworkConfiguration conf) {
        this.layerFactory = conf.layerFactory;
        this.batchSize = conf.batchSize;
        this.sparsity = conf.sparsity;
        this.useAdaGrad = conf.useAdaGrad;
        this.lr = conf.lr;
        this.momentum = conf.momentum;
        this.l2 = conf.l2;
        this.numIterations = conf.numIterations;
        this.k = conf.k;
        this.corruptionLevel = conf.corruptionLevel;
        this.visibleUnit = conf.visibleUnit;
        this.hiddenUnit = conf.hiddenUnit;
        this.useRegularization = conf.useRegularization;
        this.momentumAfter = conf.momentumAfter;
        this.resetAdaGradIterations = conf.resetAdaGradIterations;
        this.dropOut = conf.dropOut;
        // this.applySparsity = conf.applySparsity;
        this.weightInit = conf.weightInit;
        this.optimizationAlgo = conf.optimizationAlgo;
        this.lossFunction = conf.lossFunction;
        // this.renderWeightsEveryNumEpochs = neuralNetConfiguration.renderWeightsEveryNumEpochs;
        this.concateBias = conf.concateBias;
        this.constrainGradientToUnitNorm = conf.constrainGradientToUnitNorm;
        this.rng = conf.rng;
        this.dist = conf.dist;
        // this.seed = conf.seed;
        this.nIn = conf.nIn;
        this.nOut = conf.nOut;
        this.activationFunction = conf.activationFunction;
        this.visibleUnit = conf.visibleUnit;
        // this.weightShape = neuralNetConfiguration.weightShape;
        // this.stride = neuralNetConfiguration.stride;
        // this.numFeatureMaps = neuralNetConfiguration.numFeatureMaps;
        // this.filterSize = neuralNetConfiguration.filterSize;
        // this.featureMapSize = neuralNetConfiguration.featureMapSize;
        if (dist == null) {
            this.dist = new NormalDistribution(rng, 0, .01, NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
        }
        this.hiddenUnit = conf.hiddenUnit;
    }

    /**
     * CLONE
     */
    @Override
    public NeuralNetworkConfiguration clone() {
        return new NeuralNetworkConfiguration(this);
    }

    /**
     * GETTER and SETTER
     */
    public List<String> getGradientList() { return gradientList; }

    public void setGradientList(List<String> gradientList) { this.gradientList = gradientList; }

    public int getnOut() { return nOut; }

    public void setnOut(int nOut) { this.nOut = nOut; }

    public int getnIn() { return nIn; }

    public void setnIn(int nIn) { this.nIn = nIn; }

    public WeightInit getWeightInit() { return weightInit; }

    public void setWeightInit(WeightInit weightInit) { this.weightInit = weightInit; }

    public ActivationFunction getActivationFunction() { return activationFunction; }

    public void setActivationFunction(ActivationFunction activationFunction) { this.activationFunction = activationFunction; }

    public RealDistribution getDist() { return dist; }

    public void setDist(RealDistribution dist) { this.dist = dist; }

    public boolean isConcateBias() { return concateBias; }

    public void setConcateBias(boolean concateBias) { this.concateBias = concateBias; }

    public int getBatchSize() { return batchSize; }

    public void setBatchSize(int batchSize) { this.batchSize = batchSize; }

    public LossFunctions.LossFunction getLossFunction() { return lossFunction; }

    public void setLossFunction(LossFunctions.LossFunction lossFunction) { this.lossFunction = lossFunction; }

    public double getL2() { return l2; }

    public void setL2(double l2) { this.l2 = l2; }

    public boolean isUseRegularization() { return useRegularization; }

    public void setUseRegularization(boolean useRegularization) { this.useRegularization = useRegularization; }

    public double getDropOut() { return dropOut; }

    public void setDropOut(double dropOut) { this.dropOut = dropOut; }

    public int getResetAdaGradIterations() { return resetAdaGradIterations; }

    public void setResetAdaGradIterations(int resetAdaGradIterations) { this.resetAdaGradIterations = resetAdaGradIterations; }

    public double getMomentum() { return momentum; }

    public void setMomentum(double momentum) { this.momentum = momentum; }

    public Map<Integer, Double> getMomentumAfter() { return momentumAfter; }

    public void setMomentumAfter(Map<Integer, Double> momentumAfter) { this.momentumAfter = momentumAfter; }

    public boolean isUseAdaGrad() { return useAdaGrad; }

    public void setUseAdaGrad(boolean useAdaGrad) { this.useAdaGrad = useAdaGrad; }

    public double getLr() { return lr; }

    public void setLr(double lr) { this.lr = lr; }

    public int getNumIterations() { return numIterations; }

    public void setNumIterations(int numIterations) { this.numIterations = numIterations; }

    public OptimizationAlgorithm getOptimizationAlgo() { return optimizationAlgo; }

    public void setOptimizationAlgo(OptimizationAlgorithm optimizationAlgo) { this.optimizationAlgo = optimizationAlgo; }
    public StepFunction getStepFunction() { return stepFunction; }

    public void setStepFunction(StepFunction stepFunction) { this.stepFunction = stepFunction; }

    public int getK() { return k; }

    public void setK(int k) { this.k = k; }

    public double getSparsity() { return sparsity; }

    public void setSparsity(double sparsity) { this.sparsity = sparsity; }

    public RBM.VisibleUnit getVisibleUnit() { return visibleUnit; }

    public void setVisibleUnit(RBM.VisibleUnit visibleUnit) { this.visibleUnit = visibleUnit; }

    public RBM.HiddenUnit getHiddenUnit() { return hiddenUnit; }

    public void setHiddenUnit(RBM.HiddenUnit hiddenUnit) { this.hiddenUnit = hiddenUnit; }

    public RandomGenerator getRng() { return rng; }

    public void setRng(RandomGenerator rng) { this.rng = rng; }

    public Collection<IterationListener> getListeners() {
        if (listeners == null) {
            listeners = new ArrayList<>();
        }
        return listeners;
    }

    public void setListeners(Collection<IterationListener> listeners) { this.listeners = listeners; }

    public double getCorruptionLevel() { return corruptionLevel; }

    public void setCorruptionLevel(double corruptionLevel) { this.corruptionLevel = corruptionLevel; }

    public LayerFactory getLayerFactory() { return layerFactory; }

    public void setLayerFactory(LayerFactory layerFactory) { this.layerFactory = layerFactory; }

    public boolean isConstrainGradientToUnitNorm() { return constrainGradientToUnitNorm; }

    public void setConstrainGradientToUnitNorm(boolean constrainGradientToUnitNorm) { this.constrainGradientToUnitNorm = constrainGradientToUnitNorm; }

    /**
     * Set the conf for classification
     *
     * @param conf the configuration to modify
     * @param rows whether to apply softmax over rows (true) or over columns (false)
     */
    public static void setClassifier(NeuralNetworkConfiguration conf, boolean rows) {
        conf.setActivationFunction(rows ? Activations.softMaxRows() : Activations.softmax());
        conf.setLossFunction(LossFunctions.LossFunction.MCXENT);
        conf.setWeightInit(WeightInit.ZERO);
    }

    public static void setClassifier(NeuralNetworkConfiguration conf) {
        setClassifier(conf, true);
    }

    public static interface ConfOverride {
        void override(int i, Builder builder);
    }

    public static class Builder {

        private int k = 1;
        private int kernel = 5;
        private double corruptionLevel = 3e-1f;
        private double sparsity = 0f;
        private boolean useAdaGrad = true;
        private double lr = 1e-1f;
        private double momentum = 0.5f;
        private double l2 = 0f;
        private boolean useRegularization = false;
        private Map<Integer, Double> momentumAfter;
        private int resetAdaGradIterations = -1;
        private double dropOut = 0;
        private boolean applySparsity = false;
        private WeightInit weightInit = WeightInit.VI;
        private OptimizationAlgorithm optimizationAlgo = OptimizationAlgorithm.CONJUGATE_GRADIENT;
        private int renderWeightsEveryNumEpochs = -1;
        private boolean concatBiases = false;
        private boolean constrainGradientToUnitNorm = false;
        private RandomGenerator rng = new MersenneTwister(123);
        private long seed = 123;
        private RealDistribution dist = new NormalDistribution(rng, 0, .01, NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
        private boolean adagrad = true;
        private LossFunctions.LossFunction lossFunction = LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY;
        private int nIn;
        private int nOut;
        private ActivationFunction activationFunction = Activations.sigmoid();
        private RBM.VisibleUnit visibleUnit = RBM.VisibleUnit.BINARY;
        private RBM.HiddenUnit hiddenUnit = RBM.HiddenUnit.BINARY;
        private int numIterations = 1000;
        private int[] weightShape;
        private int[] filterSize = { 2, 2, 2, 2 };
        private int[] featureMapSize = { 2, 2 };
        private int numInFeatureMaps = 2;
        // subsampling layers
        private int[] stride = { 2, 2 };
        private Collection<IterationListener> listeners;
        private StepFunction stepFunction = new GradientStepFunction();
        private LayerFactory layerFactory;
        private int batchSize = 0;

        public Builder batchSize(int batchSize) { this.batchSize = batchSize; return this; }

        public Builder kernel(int kernel) { this.kernel = kernel; return this; }

        public Builder layerFactory(LayerFactory layerFactory) { this.layerFactory = layerFactory; return this; }

        public Builder stepFunction(StepFunction stepFunction) { this.stepFunction = stepFunction; return this; }

        public ListBuilder list(int size) {
            if (size < 2) {
                throw new IllegalArgumentException("Number of layers must be > 1");
            }
            List<Builder> list = new ArrayList<>();
            for (int i = 0; i < size; i++) {
                list.add(clone());
            }
            return new ListBuilder(list);
        }

        // k, kernel and corruptionLevel are copied here as well, so the clones
        // produced by list() do not silently reset them to their defaults.
        public Builder clone() {
            return new Builder().activationFunction(activationFunction).layerFactory(layerFactory)
                    .adagradResetIterations(resetAdaGradIterations).applySparsity(applySparsity)
                    .concatBiases(concatBiases).constrainGradientToUnitNorm(constrainGradientToUnitNorm).dist(dist)
                    .dropOut(dropOut).featureMapSize(featureMapSize).filterSize(filterSize).hiddenUnit(hiddenUnit)
                    .iterations(numIterations).k(k).kernel(kernel).corruptionLevel(corruptionLevel)
                    .l2(l2).learningRate(lr).useAdaGrad(adagrad)
                    .stepFunction(stepFunction).lossFunction(lossFunction).momentumAfter(momentumAfter)
                    .momentum(momentum).listeners(listeners).nIn(nIn).nOut(nOut).optimizationAlgo(optimizationAlgo)
                    .batchSize(batchSize).regularization(useRegularization).render(renderWeightsEveryNumEpochs)
                    .resetAdaGradIterations(resetAdaGradIterations).rng(rng).seed(seed).sparsity(sparsity)
                    .stride(stride).useAdaGrad(useAdaGrad).visibleUnit(visibleUnit).weightInit(weightInit)
                    .weightShape(weightShape);
        }

        public Builder iterationListener(IterationListener listener) {
            if (listeners != null) {
                listeners.add(listener);
            } else {
                listeners = new ArrayList<>();
                listeners.add(listener);
            }
            return this;
        }

        public Builder listeners(Collection<IterationListener> listeners) { this.listeners = listeners; return this; }

        public Builder featureMapSize(int[] featureMapSize) { this.featureMapSize = featureMapSize; return this; }

        public Builder stride(int[] stride) { this.stride = stride; return this; }
        public Builder filterSize(int... filterSize) {
            if (filterSize == null) {
                return this;
            }
            if (filterSize.length != 4) {
                throw new IllegalArgumentException("Invalid filter size: must be length 4");
            }
            this.filterSize = filterSize;
            return this;
        }

        public Builder weightShape(int[] weightShape) { this.weightShape = weightShape; return this; }

        public Builder iterations(int numIterations) { this.numIterations = numIterations; return this; }

        public Builder dist(RealDistribution dist) { this.dist = dist; return this; }

        public Builder sparsity(double sparsity) { this.sparsity = sparsity; return this; }

        public Builder useAdaGrad(boolean useAdaGrad) { this.useAdaGrad = useAdaGrad; return this; }

        public Builder learningRate(double lr) { this.lr = lr; return this; }

        public Builder momentum(double momentum) { this.momentum = momentum; return this; }

        public Builder k(int k) { this.k = k; return this; }

        public Builder corruptionLevel(double corruptionLevel) { this.corruptionLevel = corruptionLevel; return this; }

        public Builder momentumAfter(Map<Integer, Double> momentumAfter) { this.momentumAfter = momentumAfter; return this; }

        public Builder adagradResetIterations(int resetAdaGradIterations) { this.resetAdaGradIterations = resetAdaGradIterations; return this; }

        public Builder dropOut(double dropOut) { this.dropOut = dropOut; return this; }

        public Builder applySparsity(boolean applySparsity) { this.applySparsity = applySparsity; return this; }

        public Builder weightInit(WeightInit weightInit) { this.weightInit = weightInit; return this; }

        public Builder render(int renderWeightsEveryNumEpochs) { this.renderWeightsEveryNumEpochs = renderWeightsEveryNumEpochs; return this; }

        public Builder concatBiases(boolean concatBiases) { this.concatBiases = concatBiases; return this; }

        public Builder rng(RandomGenerator rng) { this.rng = rng; return this; }

        public Builder seed(long seed) { this.seed = seed; return this; }

        public Builder l2(double l2) { this.l2 = l2; return this; }

        public Builder regularization(boolean useRegularization) { this.useRegularization = useRegularization; return this; }

        public Builder resetAdaGradIterations(int resetAdaGradIterations) { this.resetAdaGradIterations = resetAdaGradIterations; return this; }

        public Builder optimizationAlgo(OptimizationAlgorithm optimizationAlgo) { this.optimizationAlgo = optimizationAlgo; return this; }

        public Builder lossFunction(LossFunctions.LossFunction lossFunction) { this.lossFunction = lossFunction; return this; }

        public Builder constrainGradientToUnitNorm(boolean constrainGradientToUnitNorm) { this.constrainGradientToUnitNorm = constrainGradientToUnitNorm; return this; }

        public Builder nIn(int nIn) { this.nIn = nIn; return this; }

        public Builder nOut(int nOut) { this.nOut = nOut; return this; }

        public Builder activationFunction(ActivationFunction activationFunction) { this.activationFunction = activationFunction; return this; }

        public Builder visibleUnit(RBM.VisibleUnit visibleUnit) { this.visibleUnit = visibleUnit; return this; }

        public Builder hiddenUnit(RBM.HiddenUnit hiddenUnit) { this.hiddenUnit = hiddenUnit; return this; }

        // Build
        public NeuralNetworkConfiguration build() {
            NeuralNetworkConfiguration ret = new NeuralNetworkConfiguration();
            ret.activationFunction = activationFunction;
            ret.batchSize = batchSize; // was missing: the configured batch size would otherwise be dropped
            ret.concateBias = concatBiases;
            ret.constrainGradientToUnitNorm = constrainGradientToUnitNorm;
            ret.corruptionLevel = corruptionLevel;
            ret.dist = dist;
            ret.dropOut = dropOut;
            // ret.gradientList =
            ret.hiddenUnit = hiddenUnit;
            ret.k = k;
            ret.l2 = l2;
            ret.lr = lr;
            ret.layerFactory = layerFactory;
            ret.listeners = listeners;
            ret.lossFunction = lossFunction;
            ret.momentum = momentum;
            ret.momentumAfter = momentumAfter;
            ret.nIn = nIn;
            ret.nOut = nOut;
            ret.numIterations = numIterations;
            ret.optimizationAlgo = optimizationAlgo;
            ret.resetAdaGradIterations = resetAdaGradIterations;
            ret.rng = rng;
            ret.sparsity = sparsity;
            ret.stepFunction = stepFunction;
            ret.useRegularization = useRegularization;
            ret.visibleUnit = visibleUnit;
            ret.weightInit = weightInit;
            // Copy the value set via useAdaGrad(...); the original read the separate
            // 'adagrad' flag, which no builder method ever updates.
            ret.useAdaGrad = useAdaGrad;
            return ret;
        }
    }

    public static class ListBuilder {

        private List<Builder> layerwise;
        private int[] hiddenLayerSizes;
        private boolean useDropConnect = false;
        private boolean pretrain = true;
        private Map<Integer, OutputPreprocessor> preProcessors = new HashMap<>();

        public ListBuilder(List<Builder> list) {
            this.layerwise = list;
        }

        // The original signature took a ConfOverride and assigned the preProcessors map
        // to itself, which did nothing; registering a preprocessor per layer index
        // matches how the map is consumed in build() below.
        public ListBuilder preProcessor(int layer, OutputPreprocessor preProcessor) {
            this.preProcessors.put(layer, preProcessor);
            return this;
        }

        public ListBuilder pretrain(boolean pretrain) { this.pretrain = pretrain; return this; }

        public ListBuilder useDropConnect(boolean useDropConnect) { this.useDropConnect = useDropConnect; return this; }

        public ListBuilder override(ConfOverride override) {
            for (int i = 0; i < layerwise.size(); i++) {
                override.override(i, layerwise.get(i));
            }
            return this;
        }

        public ListBuilder hiddenLayerSizes(int... hiddenLayerSizes) { this.hiddenLayerSizes = hiddenLayerSizes; return this; }

        public MultiLayerConfiguration build() {
            if (layerwise.size() != hiddenLayerSizes.length + 1) {
                throw new IllegalStateException("Number of layers must equal the number of hidden layer sizes + 1");
            }
            List<NeuralNetworkConfiguration> list = new ArrayList<>();
            for (int i = 0; i < layerwise.size(); i++) {
                list.add(layerwise.get(i).build());
            }
            return new MultiLayerConfiguration.Builder().useDropConnect(useDropConnect).pretrain(pretrain)
                    .preProcessors(preProcessors).hiddenLayerSizes(hiddenLayerSizes).confs(list).build();
        }
    }
}
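The class above is only a configuration holder, so here is a minimal usage sketch. It is not part of the original file: it assumes the rest of the project referenced by the imports (RBM, Activations, LossFunctions, a LayerFactory implementation) is on the classpath, that MultiLayerConfiguration lives in the same com.anhth12.nn.conf package (as the unqualified reference in ListBuilder.build() suggests), and the layer sizes and hyperparameter values are purely illustrative.

package com.anhth12.nn.conf;

import com.anhth12.models.featuredetectors.rbm.RBM;
import org.nd4j.linalg.api.activation.Activations;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class ConfigurationExample {

    public static void main(String[] args) {
        // Single-layer configuration: every Builder method returns the Builder, so calls chain.
        NeuralNetworkConfiguration conf = new NeuralNetworkConfiguration.Builder()
                .nIn(784).nOut(500)          // illustrative layer sizes
                .learningRate(1e-2)
                .momentum(0.9)
                .iterations(100)
                .visibleUnit(RBM.VisibleUnit.BINARY)
                .hiddenUnit(RBM.HiddenUnit.BINARY)
                .build();

        // Re-purpose an existing configuration as a classifier head:
        // softmax over rows, multi-class cross entropy loss, zero weight init.
        NeuralNetworkConfiguration.setClassifier(conf);

        // Three-layer stack: list(3) clones the builder once per layer,
        // override(...) adjusts individual layers by index, and
        // hiddenLayerSizes(...) must supply exactly (layer count - 1) sizes.
        MultiLayerConfiguration multi = new NeuralNetworkConfiguration.Builder()
                .nIn(784).nOut(10)
                .learningRate(1e-2)
                .list(3)
                .hiddenLayerSizes(500, 250)
                .override(new NeuralNetworkConfiguration.ConfOverride() {
                    @Override
                    public void override(int i, NeuralNetworkConfiguration.Builder builder) {
                        if (i == 2) { // last layer: switch it to a softmax classifier
                            builder.activationFunction(Activations.softmax())
                                   .lossFunction(LossFunctions.LossFunction.MCXENT);
                        }
                    }
                })
                .build();

        System.out.println("Built a single configuration and a 3-layer configuration: " + multi);
    }
}

The anonymous class implements the ConfOverride callback declared inside NeuralNetworkConfiguration. Note that list(3) requires exactly two hiddenLayerSizes entries, because ListBuilder.build() checks that the layer count equals the number of hidden layer sizes plus one.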