mulan.classifier.neural.BPMLL.java Source code

Java tutorial

Introduction

Here is the source code for mulan.classifier.neural.BPMLL.java

Source

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    BPMLL.java
 *    Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
 */
package mulan.classifier.neural;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Set;

import mulan.classifier.InvalidDataException;
import mulan.classifier.MultiLabelLearnerBase;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.neural.model.ActivationTANH;
import mulan.classifier.neural.model.BasicNeuralNet;
import mulan.classifier.neural.model.NeuralNet;
import mulan.core.WekaException;
import mulan.data.DataUtils;
import mulan.data.InvalidDataFormatException;
import mulan.data.MultiLabelInstances;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;

/**
 * The implementation of Back-Propagation Multi-Label Learning (BPMLL) learner.
 * The learned model is stored in a {@link NeuralNet} neural network. The model of the
 * learner is built by {@link BPMLLAlgorithm} from the given training data set.
 *
 * <!-- technical-bibtex-start -->
 *
 * <!-- technical-bibtex-end -->
 *
 * @author Jozef Vilcek
 * @see BPMLLAlgorithm
 */
public class BPMLL extends MultiLabelLearnerBase {

    private static final long serialVersionUID = 2153814250172139021L;
    // bias value passed to the neural network on construction
    private static final double NET_BIAS = 1;
    // relative decrease of the epoch error below which training is terminated
    private static final double ERROR_SMALL_CHANGE = 0.000001;
    // filter used to convert nominal input attributes into binary-numeric
    private NominalToBinary nominalToBinaryFilter;
    // algorithm parameters
    private int epochs = 100;
    // seed for the network's random generator; null means no explicit seed is passed
    private final Long randomnessSeed;
    private double weightsDecayCost = 0.00001;
    private double learningRate = 0.05;
    private int[] hiddenLayersTopology;
    // members related to normalization of attributes
    private boolean normalizeAttributes = true;
    // normalizer created by the last build; null if that build did not normalize
    private NormalizationFilter normalizer;
    private NeuralNet model;
    private ThresholdFunction thresholdF;

    /**
     * Creates a new instance of {@link BPMLL} learner without an explicit seed
     * for the pseudo-random generator of the neural network.
     */
    public BPMLL() {
        randomnessSeed = null;
    }

    /**
     * Creates a new instance of {@link BPMLL} learner.
     * @param randomnessSeed the seed value for pseudo-random generator
     */
    public BPMLL(long randomnessSeed) {
        this.randomnessSeed = randomnessSeed;
    }

    /**
     * Sets the topology of hidden layers for neural network.
     * The length of passed array defines number of hidden layers.
     * The value at particular index of array defines number of neurons in that layer.
     * An empty array results in a network without hidden layers.
     * If <code>null</code> is specified (the default), a single hidden layer is created
     * when the learner is built, sized to 20% of the network input dimension
     * (but at least one neuron).
     * <br>
     * The network is created when learner is being built.
     * The input and output layer is determined from input training data.
     * The passed array is copied, so later changes to it do not affect the learner.
     *
     * @param hiddenLayers the number of neurons for each hidden layer, or <code>null</code>
     * @throws IllegalArgumentException if any value in the array is less or equal to zero
     */
    public void setHiddenLayers(int[] hiddenLayers) {
        if (hiddenLayers == null) {
            hiddenLayersTopology = null;
            return;
        }
        for (int value : hiddenLayers) {
            if (value <= 0) {
                throw new IllegalArgumentException("Invalid hidden layer topology definition. "
                        + "Number of neurons in hidden layer must be larger than zero.");
            }
        }
        // defensive copy, symmetric with getHiddenLayers()
        hiddenLayersTopology = Arrays.copyOf(hiddenLayers, hiddenLayers.length);
    }

    /**
     * Gets an array defining topology of hidden layers of the underlying neural model.
     *
     * @return a copy of the topology array, or <code>null</code> if no topology is set
     */
    public int[] getHiddenLayers() {
        return hiddenLayersTopology == null ? null
                : Arrays.copyOf(hiddenLayersTopology, hiddenLayersTopology.length);
    }

    /**
     * Sets the learning rate. Must be greater than 0 and no more than 1.<br>
     * Default value is 0.05.
     *
     * @param learningRate the learning rate
     * @throws IllegalArgumentException if passed value is invalid
     */
    public void setLearningRate(double learningRate) {
        if (learningRate <= 0 || learningRate > 1) {
            throw new IllegalArgumentException("The learning rate must be greater than 0 and no more than 1. "
                    + "Entered value is : " + learningRate);
        }
        this.learningRate = learningRate;
    }

    /**
     * Gets the learning rate. The default value is 0.05.
     * @return learning rate
     */
    public double getLearningRate() {
        return learningRate;
    }

    /**
     * Sets the regularization cost term for weights decay.
     * Must be greater than 0 and no more than 1.<br>
     * Default value is 0.00001.
     *
     * @param weightsDecayCost the weights decay cost term
     * @throws IllegalArgumentException if passed value is invalid
     */
    public void setWeightsDecayRegularization(double weightsDecayCost) {
        if (weightsDecayCost <= 0 || weightsDecayCost > 1) {
            throw new IllegalArgumentException(
                    "The weights decay regularization cost term must be greater than 0 and no more than 1. "
                            + "The passed  value is : " + weightsDecayCost);
        }
        this.weightsDecayCost = weightsDecayCost;
    }

    /**
     * Gets a value of the regularization cost term for weights decay.
     * @return regularization cost
     */
    public double getWeightsDecayRegularization() {
        return weightsDecayCost;
    }

    /**
     * Sets the number of training epochs. Must be greater than 0.<br>
     * Default value is 100.
     *
     * @param epochs the number of training epochs
     * @throws IllegalArgumentException if passed value is invalid
     */
    public void setTrainingEpochs(int epochs) {
        if (epochs <= 0) {
            throw new IllegalArgumentException(
                    "The number of training epochs must be greater than zero. Entered value is : " + epochs);
        }
        this.epochs = epochs;
    }

    /**
     * Gets number of training epochs.
     * Default value is 100.
     * @return training epochs
     */
    public int getTrainingEpochs() {
        return epochs;
    }

    /**
     * Sets whether attributes of instances data (except label attributes) should
     * be normalized prior to building the learner. The normalizer is created with
     * the bounds -0.8 and 0.8.<br>
     * The setting takes effect the next time the learner is built. When making a
     * prediction, the input instance is normalized if and only if the model in use
     * was built with normalization.<br>
     * Default is true (normalization of attributes takes place).
     *
     * @param normalize flag if normalization of attributes should be used
     */
    public void setNormalizeAttributes(boolean normalize) {
        normalizeAttributes = normalize;
    }

    /**
     * Gets a value if normalization of attributes should take place.
     * Default value is true.
     * @return a value if normalization of attributes should take place
     */
    public boolean getNormalizeAttributes() {
        return normalizeAttributes;
    }

    /**
     * Builds the neural network model and the threshold function from the
     * given training data. The passed data set is cloned before it is processed.
     * <br>
     * Training runs for at most {@link #getTrainingEpochs()} epochs and is terminated
     * earlier as soon as the error summed over one epoch does not decrease, relative
     * to the previous epoch, by more than a small fraction.
     *
     * @param instances the training data
     * @throws InvalidDataException if the data are not in the expected format
     *         or do not contain any instances
     * @throws Exception if building of the model fails
     */
    protected void buildInternal(final MultiLabelInstances instances) throws Exception {

        // delete filter if available from previous build, a new one will be created if necessary
        nominalToBinaryFilter = null;

        MultiLabelInstances trainInstances = instances.clone();
        List<DataPair> trainData = prepareData(trainInstances);
        if (trainData.isEmpty()) {
            throw new InvalidDataException("The training data set does not contain any instances.");
        }
        int inputsDim = trainData.get(0).getInput().length;
        model = buildNeuralNetwork(inputsDim);
        BPMLLAlgorithm learnAlg = new BPMLLAlgorithm(model, weightsDecayCost);

        double prevError = Double.MAX_VALUE;
        for (int epoch = 0; epoch < epochs; epoch++) {
            // NOTE: the shuffle is deliberately left with its fixed seed (it does not
            // use randomnessSeed) so that the ordering of training pairs is unchanged.
            Collections.shuffle(trainData, new Random(1));

            // error statistics are tracked per epoch, so that the termination
            // criterion below compares two consecutive epochs
            double error = 0;
            int processedInstances = 0;
            for (DataPair trainPair : trainData) {
                double result = learnAlg.learn(trainPair.getInput(), trainPair.getOutput(), learningRate);
                // NaN results are left out of the error statistics
                if (!Double.isNaN(result)) {
                    error += result;
                    processedInstances++;
                }
            }

            if (getDebug() && epoch % 10 == 0) {
                debug("Training epoch : " + epoch + "  Model error : " + error / processedInstances);
            }

            double errorDiff = prevError - error;
            if (errorDiff <= ERROR_SMALL_CHANGE * prevError) {
                if (getDebug()) {
                    debug("Global training error does not decrease enough. Training terminated.");
                }
                break;
            }
            prevError = error;
        }

        thresholdF = buildThresholdFunction(trainData);
    }

    @Override
    public TechnicalInformation getTechnicalInformation() {

        TechnicalInformation technicalInfo = new TechnicalInformation(Type.ARTICLE);
        technicalInfo.setValue(Field.AUTHOR, "Zhang, M.L., Zhou, Z.H.");
        technicalInfo.setValue(Field.YEAR, "2006");
        technicalInfo.setValue(Field.TITLE,
                "Multi-label neural networks with applications to functional genomics and text categorization");
        technicalInfo.setValue(Field.JOURNAL, "IEEE Transactions on Knowledge and Data Engineering");
        technicalInfo.setValue(Field.VOLUME, "18");
        technicalInfo.setValue(Field.PAGES, "1338-1351");
        return technicalInfo;
    }

    /**
     * Creates the threshold function from the expected outputs of the training data
     * and the outputs the trained model produces for the same inputs.
     *
     * @param trainData the data pairs the model was trained on
     * @return the threshold function
     */
    private ThresholdFunction buildThresholdFunction(List<DataPair> trainData) {

        int numExamples = trainData.size();
        // rows are assigned in the loop below, so only the outer arrays are allocated
        double[][] idealLabels = new double[numExamples][];
        double[][] modelConfidences = new double[numExamples][];

        for (int example = 0; example < numExamples; example++) {
            DataPair dataPair = trainData.get(example);
            idealLabels[example] = dataPair.getOutput();
            modelConfidences[example] = model.feedForward(dataPair.getInput());
        }

        return new ThresholdFunction(idealLabels, modelConfidences);
    }

    /**
     * Creates the neural network with the input layer of the given dimension, the
     * configured hidden layers and the output layer with one neuron per label.
     * If no hidden layers topology is set, a single hidden layer with 20% of
     * the input dimension (at least one neuron) is used.
     *
     * @param inputsDim the dimension of the network input
     * @return the created neural network
     */
    private NeuralNet buildNeuralNetwork(int inputsDim) {

        if (hiddenLayersTopology == null) {
            // rounding yields zero for very small inputs, which would create an empty layer
            int hiddenUnits = Math.max(1, Math.round(0.2f * inputsDim));
            // NOTE(review): the derived default is stored in the field, so it is visible via
            // getHiddenLayers() and reused by later builds, even on data of another dimension.
            hiddenLayersTopology = new int[] { hiddenUnits };
        }

        int[] networkTopology = new int[hiddenLayersTopology.length + 2];
        networkTopology[0] = inputsDim;
        System.arraycopy(hiddenLayersTopology, 0, networkTopology, 1, hiddenLayersTopology.length);
        networkTopology[networkTopology.length - 1] = numLabels;

        Random random = randomnessSeed == null ? null : new Random(randomnessSeed);
        return new BasicNeuralNet(networkTopology, NET_BIAS, ActivationTANH.class, random);
    }

    /**
     * Prepares {@link MultiLabelInstances} data for the learning algorithm.
     * <br>
     * The attributes are checked for correct format and nominal input attributes
     * are converted to binary. If normalization is enabled, a normalizer is created
     * for the data. Finally {@link Instance} instances are converted to
     * {@link DataPair} instances, which will be used for the algorithm.
     *
     * @param mlData the training data
     * @return the data pairs created from the training data
     * @throws InvalidDataException if the attributes are not in correct format
     */
    private List<DataPair> prepareData(MultiLabelInstances mlData) {

        Instances data = checkAttributesFormat(mlData.getDataSet(), mlData.getFeatureAttributes());
        if (data == null) {
            throw new InvalidDataException("Attributes are not in correct format. "
                    + "Input attributes (all but the label attributes) must be nominal or numeric.");
        }

        try {
            mlData = mlData.reintegrateModifiedDataSet(data);
            // the filter may have changed the attribute positions, refresh the label indices
            this.labelIndices = mlData.getLabelIndices();
        } catch (InvalidDataFormatException e) {
            throw new InvalidDataException("Failed to create a multilabel data set from modified instances.");
        }

        // always reset the normalizer, so that prediction reflects how THIS model was built
        // and no normalizer of a previous build is left behind
        normalizer = normalizeAttributes ? new NormalizationFilter(mlData, true, -0.8, 0.8) : null;

        return DataPair.createDataPairs(mlData, true);
    }

    /**
     * Checks {@link Instances} data if attributes (all but the label attributes)
     * are numeric or nominal. Nominal attributes are transformed to binary by use of
     * {@link NominalToBinary} filter.
     *
     * @param dataSet instances data to be checked
     * @param inputAttributes input/feature attributes which format need to be checked
     * @return data set if it passed checks; otherwise <code>null</code>
     * @throws WekaException if the {@link NominalToBinary} filter can not be applied
     */
    private Instances checkAttributesFormat(Instances dataSet, Set<Attribute> inputAttributes) {

        // comma separated list of 1-based indices of nominal attributes
        StringBuilder nominalAttrRange = new StringBuilder();
        String rangeDelimiter = ",";
        for (Attribute attribute : inputAttributes) {
            if (attribute.isNumeric()) {
                continue;
            }
            if (!attribute.isNominal()) {
                // fail check if any other attribute type than nominal or numeric is used
                return null;
            }
            if (nominalAttrRange.length() > 0) {
                nominalAttrRange.append(rangeDelimiter);
            }
            nominalAttrRange.append(attribute.index() + 1);
        }

        // nothing to convert
        if (nominalAttrRange.length() == 0) {
            return dataSet;
        }

        // convert any nominal attributes to binary
        try {
            nominalToBinaryFilter = new NominalToBinary();
            nominalToBinaryFilter.setAttributeIndices(nominalAttrRange.toString());
            nominalToBinaryFilter.setInputFormat(dataSet);
            return Filter.useFilter(dataSet, nominalToBinaryFilter);
        } catch (Exception exception) {
            nominalToBinaryFilter = null;
            if (getDebug()) {
                debug("Failed to apply NominalToBinary filter to the input instances data. " + "Error message: "
                        + exception.getMessage());
            }
            throw new WekaException("Failed to apply NominalToBinary filter to the input instances data.",
                    exception);
        }
    }

    /**
     * Makes a prediction for the given instance. The instance is passed through the
     * same nominal-to-binary conversion and normalization as the training data were.
     * If the instance has more attributes than the model has inputs, the attributes
     * at the label indices are assumed to be the true labels and are left out.
     * <br>
     * The returned confidences are translated from the bipolar output of the
     * network to binary form via <code>(value + 1) / 2</code>; the bipartition is
     * determined by comparing the network outputs with the computed threshold.
     *
     * @param instance the instance to make the prediction for
     * @return the prediction with the bipartition and the confidences of labels
     * @throws InvalidDataException if the instance is not consistent with the data
     *         the model was built for
     */
    public MultiLabelOutput makePredictionInternal(Instance instance) throws InvalidDataException {

        Instance inputInstance;
        if (nominalToBinaryFilter != null) {
            try {
                nominalToBinaryFilter.input(instance);
                inputInstance = nominalToBinaryFilter.output();
                inputInstance.setDataset(null);
            } catch (Exception ex) {
                throw new InvalidDataException("The input instance for prediction is invalid. "
                        + "Instance is not consistent with the data the model was built for.");
            }
        } else {
            // work on a copy, the passed instance must stay untouched
            inputInstance = DataUtils.createInstance(instance, instance.weight(), instance.toDoubleArray());
        }

        final int inputDim = model.getNetInputSize();
        final int numAttributes = inputInstance.numAttributes();
        if (numAttributes < inputDim) {
            throw new InvalidDataException("Input instance do not have enough attributes "
                    + "to be processed by the model. Instance is not consistent with the data the model was built for.");
        }

        // if instance has more attributes than model input, we assume that true outputs
        // are there, so we mark them to be left out
        boolean[] isLabelAttribute = new boolean[numAttributes];
        if (numAttributes > inputDim) {
            for (int index : this.labelIndices) {
                if (index >= 0 && index < numAttributes) {
                    isLabelAttribute[index] = true;
                }
            }
        }

        // normalize if and only if the model was built on normalized data; the normalizer is
        // checked instead of the flag, because the flag may have been changed after the build
        if (normalizer != null) {
            normalizer.normalize(inputInstance);
        }

        double[] inputPattern = new double[inputDim];
        int indexCounter = 0;
        for (int attrIndex = 0; attrIndex < numAttributes; attrIndex++) {
            if (isLabelAttribute[attrIndex]) {
                continue;
            }
            if (indexCounter == inputDim) {
                throw new InvalidDataException("Input instance has more input attributes than "
                        + "the model can process. Instance is not consistent with the data the model was built for.");
            }
            inputPattern[indexCounter] = inputInstance.value(attrIndex);
            indexCounter++;
        }

        double[] labelConfidences = model.feedForward(inputPattern);
        double threshold = thresholdF.computeThreshold(labelConfidences);
        boolean[] labelPredictions = new boolean[numLabels];

        for (int labelIndex = 0; labelIndex < numLabels; labelIndex++) {
            // the threshold is applied on the raw (bipolar) output of the network
            labelPredictions[labelIndex] = labelConfidences[labelIndex] > threshold;
            // translate from bipolar output to binary
            labelConfidences[labelIndex] = (labelConfidences[labelIndex] + 1) / 2;
        }

        return new MultiLabelOutput(labelPredictions, labelConfidences);
    }
}