moa.classifiers.meta.OnlineAccuracyUpdatedEnsemble.java Source code

Java tutorial

Introduction

Below is the source code of the class moa.classifiers.meta.OnlineAccuracyUpdatedEnsemble (file OnlineAccuracyUpdatedEnsemble.java).

Source

/*
 *    OnlineAccuracyUpdatedEnsemble.java
 *    Copyright (C) 2013 Poznan University of Technology, Poznan, Poland
 *    @author Dariusz Brzezinski (dariusz.brzezinski@cs.put.poznan.pl)
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program. If not, see <http://www.gnu.org/licenses/>.
 *    
 */
package moa.classifiers.meta;

import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.trees.HoeffdingTree;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.options.ClassOption;
import moa.options.FlagOption;
import moa.options.FloatOption;
import moa.options.IntOption;
import moa.tasks.TaskMonitor;
import weka.core.Instance;

/**
 * The online version of the Accuracy Updated Ensemble (OAUE) as proposed by
 * Brzezinski and Stefanowski in "Combining block-based and online methods
 * in learning ensembles from concept drifting data streams", Information Sciences, 2014.
 *
 * <p>Every {@code windowSize} examples a candidate classifier, trained on the latest chunk,
 * is added to the ensemble or, once the ensemble holds {@code memberCount} members, replaces
 * the lowest-weighted member if the candidate's weight is higher. Members are re-weighted on
 * every example from their squared error over a sliding window and the reference error
 * {@code mse_r}, and every member is trained on every example.</p>
 */
public class OnlineAccuracyUpdatedEnsemble extends AbstractClassifier {

    private static final long serialVersionUID = 1L;

    /**
     * Type of classifier to use as a component classifier.
     */
    public ClassOption learnerOption = new ClassOption("learner", 'l', "Classifier to train.", Classifier.class,
            "trees.HoeffdingTree -e 2000000 -g 100 -c 0.01");

    /**
     * Number of component classifiers.
     */
    public IntOption memberCountOption = new IntOption("memberCount", 'n',
            "The maximum number of classifiers in an ensemble.", 10, 1, Integer.MAX_VALUE);

    /**
     * Chunk size.
     */
    public FloatOption windowSizeOption = new FloatOption("windowSize", 'w',
            "The window size used for classifier creation and evaluation.", 500, 1, Integer.MAX_VALUE);

    /**
     * Determines the maximum size of model (evaluated after every chunk).
     */
    public IntOption maxByteSizeOption = new IntOption("maxByteSize", 'm', "Maximum memory consumed by ensemble.",
            33554432, 0, Integer.MAX_VALUE);

    /**
     * Determines whether additional information should be sent to the output.
     */
    public FlagOption verboseOption = new FlagOption("verbose", 'v',
            "When checked the algorithm outputs additional information about component classifier weights.");

    /**
     * Determines whether a linear weighting function (mse_r - mse_it) is used instead of
     * the default non-linear one (1 / (mse_r + mse_it)).
     */
    public FlagOption linearOption = new FlagOption("linearFunction", 'f',
            "When checked the algorithm uses a linear weighting function.");

    /**
     * The weights of stored classifiers.
     * weights[x][0] = weight
     * weights[x][1] = index of the classifier in {@code ensemble}
     */
    protected double[][] weights;

    /**
     * Class counts over the current window of examples.
     */
    protected long[] classDistributions;

    /**
     * Ensemble classifiers.
     */
    protected ClassifierWithMemory[] ensemble;

    /**
     * Number of processed examples.
     */
    protected int processedInstances;

    /**
     * Candidate classifier, trained on the current chunk.
     */
    protected ClassifierWithMemory candidate;

    /**
     * Circular buffer with the class indices of the latest {@code windowSize} examples.
     */
    protected int[] currentWindow;

    /**
     * The mean square residual in a given moment, based on a window of latest examples.
     */
    protected double mse_r = 0;

    /**
     * Window size.
     */
    protected int windowSize = 0;

    @Override
    public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
        this.windowSize = (int) this.windowSizeOption.getValue();
        this.candidate = this.createCandidate();

        super.prepareForUseImpl(monitor, repository);
    }

    /**
     * Discards the whole model: window statistics, ensemble members, their weights
     * and the candidate classifier.
     */
    @Override
    public void resetLearningImpl() {
        this.currentWindow = null;
        this.windowSize = (int) this.windowSizeOption.getValue();
        this.classDistributions = null;
        this.processedInstances = 0;
        this.mse_r = 0;
        this.ensemble = new ClassifierWithMemory[0];
        // Keep weights in step with the (now empty) ensemble so no stale weights are reported.
        this.weights = new double[0][2];

        this.candidate = this.createCandidate();
    }

    /**
     * Updates the window statistics with {@code inst}, re-weights the members (or closes
     * the current chunk) and trains every ensemble member on {@code inst}.
     *
     * @param inst the training example
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        this.initVariables();

        int classIndex = (int) inst.classValue();
        int slot = this.processedInstances % this.windowSize;

        // Once the window is full, forget the class of the example being overwritten.
        if (this.processedInstances >= this.windowSize) {
            this.classDistributions[this.currentWindow[slot]]--;
        }
        this.classDistributions[classIndex]++;
        this.currentWindow[slot] = classIndex;

        this.processedInstances++;
        this.computeMseR();

        if (this.processedInstances % this.windowSize == 0) {
            // End of a chunk. The candidate does not learn inst here; if it joins the
            // ensemble it is trained on inst together with the other members below.
            this.createNewClassifier(inst);
        } else {
            this.candidate.classifier.trainOnInstance(inst);

            // Members are evaluated on inst before they are trained on it below.
            for (int i = 0; i < this.ensemble.length; i++) {
                this.weights[i][0] = this.computeWeight(i, inst);
            }
        }

        for (int i = 0; i < this.ensemble.length; i++) {
            this.ensemble[i].classifier.trainOnInstance(inst);
        }
    }

    /**
     * Determines whether the classifier is randomizable.
     */
    @Override
    public boolean isRandomizable() {
        return false;
    }

    /**
     * Predicts a class for an example by summing the normalized votes of all members with
     * a positive weight, each scaled by its weight. The combined vote is not normalized.
     *
     * @param inst the example to classify
     * @return the (unnormalized) combined class votes; empty before any training
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        DoubleVector combinedVote = new DoubleVector();

        if (this.trainingWeightSeenByModel > 0.0) {
            for (int i = 0; i < this.ensemble.length; i++) {
                if (this.weights[i][0] > 0.0) {
                    DoubleVector vote = new DoubleVector(
                            this.ensemble[(int) this.weights[i][1]].classifier.getVotesForInstance(inst));

                    if (vote.sumOfValues() > 0.0) {
                        vote.normalize();
                        // Weights are finite (at most Double.MAX_VALUE), so dividing by
                        // (members + 1) keeps the sum over all members from overflowing.
                        vote.scaleValues(this.weights[i][0] / (1.0 * this.ensemble.length + 1.0));
                        combinedVote.addValues(vote);
                    }
                }
            }
        }

        return combinedVote.getArrayRef();
    }

    /**
     * No textual model description is produced.
     */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }

    @Override
    public Classifier[] getSubClassifiers() {
        Classifier[] subClassifiers = new Classifier[this.ensemble.length];

        for (int i = 0; i < this.ensemble.length; i++) {
            subClassifiers[i] = this.ensemble[i].classifier;
        }

        return subClassifiers;
    }

    /**
     * Processes a chunk: re-weights the current members, adds the candidate to the ensemble
     * (or substitutes the poorest member when the ensemble is full and the candidate's weight
     * is higher), starts a new candidate and enforces the memory limit.
     *
     * @param inst New example
     */
    protected void createNewClassifier(Instance inst) {
        // The candidate's weight equals the member weight formula with mse_it = 0.
        double candidateClassifierWeight;
        if (this.linearOption.isSet()) {
            candidateClassifierWeight = Math.max(this.mse_r, Double.MIN_VALUE);
        } else {
            candidateClassifierWeight = inverseErrorWeight(this.mse_r);
        }

        for (int i = 0; i < this.ensemble.length; i++) {
            this.weights[i][0] = this.computeWeight(i, inst);
        }

        // The candidate classifier has been trained on the last window of examples.
        this.candidate.birthday = this.processedInstances;

        if (this.ensemble.length < this.memberCountOption.getValue()) {
            this.addToStored(this.candidate, candidateClassifierWeight);
        } else {
            int poorestClassifier = this.getPoorestClassifierIndex();

            if (this.weights[poorestClassifier][0] < candidateClassifierWeight) {
                this.weights[poorestClassifier][0] = candidateClassifierWeight;
                this.ensemble[(int) this.weights[poorestClassifier][1]] = this.candidate;
            }
        }

        this.candidate = this.createCandidate();

        this.enforceMemoryLimit();
    }

    /**
     * Splits {@code maxByteSize} evenly among the members (plus one share for the candidate)
     * and applies the per-member limit to every member that is a {@link HoeffdingTree}.
     * Members of other types have no such limit and are left untouched.
     */
    protected void enforceMemoryLimit() {
        double memoryLimit = this.maxByteSizeOption.getValue() / (double) (this.ensemble.length + 1);
        int perMemberLimit = (int) Math.round(memoryLimit);

        for (int i = 0; i < this.ensemble.length; i++) {
            Classifier member = this.ensemble[(int) this.weights[i][1]].classifier;
            // Guard the cast: the learner option accepts any Classifier, not only Hoeffding trees.
            if (member instanceof HoeffdingTree) {
                HoeffdingTree tree = (HoeffdingTree) member;
                tree.maxByteSizeOption.setValue(perMemberLimit);
                tree.enforceTrackerLimit();
            }
        }
    }

    /**
     * Recomputes the reference error {@code mse_r = sum_c p_c * (1 - p_c)^2}, where
     * {@code p_c} is the count of class c in the current window divided by the window size.
     * The result is stored in {@link #mse_r}.
     */
    protected void computeMseR() {
        this.mse_r = 0;

        for (int c = 0; c < this.classDistributions.length; c++) {
            double p_c = (double) this.classDistributions[c] / (double) this.windowSize;
            this.mse_r += p_c * ((1 - p_c) * (1 - p_c));
        }
    }

    /**
     * Computes the weight of a learner before training it on a given example. The member's
     * windowed mean square error {@code mse_it} is updated with the error on {@code example}.
     *
     * @param i the index (in {@code ensemble}) of the classifier whose weight is computed
     * @param example the newest example
     * @return the computed weight; finite and positive
     */
    protected double computeWeight(int i, Instance example) {
        ClassifierWithMemory member = this.ensemble[i];
        int d = this.windowSize;
        int t = this.processedInstances - member.birthday;

        double e_it = squaredError(member.classifier, example);
        double mse_it;

        if (t > d) {
            // Sliding window: add the newest error and drop the one that falls out of the window.
            mse_it = member.mse_it + e_it / (double) d - member.squareErrors[t % d] / (double) d;
        } else {
            // Fewer than d errors seen so far: running mean.
            mse_it = member.mse_it * (t - 1) / t + e_it / (double) t;
        }

        member.squareErrors[t % d] = e_it;
        member.mse_it = mse_it;

        if (this.linearOption.isSet()) {
            return Math.max(this.mse_r - mse_it, Double.MIN_VALUE);
        } else {
            return inverseErrorWeight(this.mse_r + mse_it);
        }
    }

    /**
     * Returns {@code (1 - p)^2}, where p is the classifier's normalized vote for the true class.
     * The error is 1 when the votes sum to zero, do not cover the true class, or the prediction fails.
     */
    private static double squaredError(Classifier classifier, Instance example) {
        double[] votes;
        try {
            votes = classifier.getVotesForInstance(example);
        } catch (Exception ignored) {
            // Best effort: a member that cannot predict is treated as completely wrong.
            return 1.0;
        }
        if (votes == null) {
            return 1.0;
        }

        double voteSum = 0;
        for (double element : votes) {
            voteSum += element;
        }
        if (voteSum <= 0) {
            return 1.0;
        }

        int classIndex = (int) example.classValue();
        // A vote array too short to contain the true class gives it zero support.
        double support = (classIndex >= 0 && classIndex < votes.length) ? votes[classIndex] : 0.0;
        double f_it = 1 - (support / voteSum);
        return f_it * f_it;
    }

    /**
     * Non-linear weighting function {@code 1 / (error + Double.MIN_VALUE)}, capped at
     * {@code Double.MAX_VALUE}. Without the cap a zero error yields +Infinity, and infinite
     * weights turn zero vote entries into NaN when votes are scaled.
     */
    private static double inverseErrorWeight(double error) {
        return Math.min(1.0 / (error + Double.MIN_VALUE), Double.MAX_VALUE);
    }

    /**
     * Adds ensemble weights to the measurements. Slots without a member report -1.
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        if (!this.verboseOption.isSet()) {
            return null;
        }

        int memberCount = (int) this.memberCountOption.getValue();
        Measurement[] measurements = new Measurement[memberCount];

        for (int m = 0; m < memberCount; m++) {
            boolean hasMember = this.weights != null && m < this.weights.length;
            measurements[m] = new Measurement("Member weight " + (m + 1), hasMember ? this.weights[m][0] : -1);
        }

        return measurements;
    }

    /**
     * Adds a classifier to the storage.
     *
     * @param newClassifier
     *            The classifier to add.
     * @param newClassifiersWeight
     *            The new classifiers weight.
     */
    protected void addToStored(ClassifierWithMemory newClassifier, double newClassifiersWeight) {
        ClassifierWithMemory[] newStored = new ClassifierWithMemory[this.ensemble.length + 1];
        double[][] newStoredWeights = new double[newStored.length][2];

        for (int i = 0; i < newStored.length; i++) {
            if (i < this.ensemble.length) {
                newStored[i] = this.ensemble[i];
                newStoredWeights[i][0] = this.weights[i][0];
                newStoredWeights[i][1] = this.weights[i][1];
            } else {
                newStored[i] = newClassifier;
                newStoredWeights[i][0] = newClassifiersWeight;
                newStoredWeights[i][1] = i;
            }
        }
        this.ensemble = newStored;
        this.weights = newStoredWeights;
    }

    /**
     * Finds the index of the classifier with the smallest weight.
     *
     * @return the row of {@code weights} with the smallest weight (the first one on ties)
     */
    private int getPoorestClassifierIndex() {
        int minIndex = 0;

        for (int i = 1; i < this.weights.length; i++) {
            if (this.weights[i][0] < this.weights[minIndex][0]) {
                minIndex = i;
            }
        }

        return minIndex;
    }

    /**
     * Initiates the current chunk and class distribution variables.
     */
    private void initVariables() {
        if (this.currentWindow == null) {
            this.currentWindow = new int[this.windowSize];
        }

        if (this.classDistributions == null) {
            this.classDistributions = new long[this.getModelContext().classAttribute().numValues()];
        }
    }

    /**
     * Creates a fresh, untrained candidate from a copy of the configured base learner.
     */
    private ClassifierWithMemory createCandidate() {
        ClassifierWithMemory fresh = new ClassifierWithMemory(
                ((Classifier) getPreparedClassOption(this.learnerOption)).copy(), this.windowSize);
        fresh.classifier.resetLearning();
        return fresh;
    }

    /**
     * An ensemble member together with its windowed error statistics. The class is serializable
     * because the enclosing classifier (which declares a serialVersionUID) stores instances of it
     * in its fields. Without it, serializing the ensemble would fail.
     */
    protected class ClassifierWithMemory implements java.io.Serializable {
        private static final long serialVersionUID = 1L;

        private Classifier classifier;
        private int birthday;
        private double[] squareErrors;
        private double mse_it;

        protected ClassifierWithMemory(Classifier classifier, int windowSize) {
            this.classifier = classifier;
            this.squareErrors = new double[windowSize];
            this.mse_it = 0;
        }
    }
}