distributedRedditAnalyser.OzaBoost.java Source code

Java tutorial

Introduction

Here is the source code for distributedRedditAnalyser.OzaBoost.java

Source

/*
 *    OzaBoost.java
 *    @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 *    @author Luke Barnett (luke@barnett.net.nz)
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program. If not, see <http://www.gnu.org/licenses/>.
 *    
 */
package distributedRedditAnalyser;

import java.util.ArrayDeque;
import java.util.Random;
import java.util.concurrent.Semaphore;

import weka.core.Instance;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.options.ClassOption;
import moa.options.FlagOption;
import moa.options.IntOption;

/**
 * Rewrite of the moa implementation of OzaBoost to accommodate a distributed setting and the sharing of classifiers.
 * 
 * Keeps the latest K classifiers (K = {@code ensembleSizeOption}); when the ensemble grows past K the oldest
 * members are evicted first.
 * 
 * Largely derived from the moa implementation:
 * http://code.google.com/p/moa/source/browse/moa/src/main/java/moa/classifiers/meta/OzaBoost.java
 * 
 * <p>All access to the ensemble goes through a single-permit semaphore, so training, voting and classifier
 * sharing may be invoked from different threads. If a thread is interrupted while waiting for the lock, the
 * operation is skipped (except {@link #resetLearningImpl()}, which always completes). In that case the
 * thread's interrupt status is restored.
 * 
 * @author Luke Barnett 1109967
 * @author Tony Chen 1111377
 *
 */
public class OzaBoost extends AbstractClassifier {
    private static final long serialVersionUID = -4456874021287021340L;

    /** Single-permit semaphore guarding {@link #ensemble}. */
    private final Semaphore lock = new Semaphore(1);

    @Override
    public String getPurposeString() {
        return "Incremental on-line boosting of Oza and Russell.";
    }

    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.",
            Classifier.class, "trees.HoeffdingTree");

    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The max number of models to boost.",
            10, 1, Integer.MAX_VALUE);

    public FlagOption pureBoostOption = new FlagOption("pureBoost", 'p', "Boost with weights only; no poisson.");

    /** Ensemble members, oldest first (new members are appended at the tail); guarded by {@link #lock}. */
    protected ArrayDeque<ClassifierInstance> ensemble;

    /**
     * Acquires {@link #lock}, tolerating interruption.
     *
     * @return {@code true} if the permit was obtained and the caller must release it; {@code false} if the
     *         thread was interrupted while waiting (the permit was NOT obtained and the interrupt status has
     *         been restored)
     */
    private boolean acquireLock() {
        try {
            lock.acquire();
            return true;
        } catch (InterruptedException e) {
            // Do not release in this case: no permit was taken. Preserve the interrupt for callers up the stack.
            Thread.currentThread().interrupt();
            return false;
        }
    }

    /** Evicts the oldest members until the ensemble holds at most ensembleSizeOption members. Caller holds lock. */
    private void trimEnsemble() {
        int maxSize = ensembleSizeOption.getValue();
        while (ensemble.size() > maxSize) {
            ensemble.pollFirst();
        }
    }

    /**
     * Discards all ensemble members and resets the prepared base learner.
     *
     * <p>The lock is acquired uninterruptibly: skipping a reset could leave {@link #ensemble} unset, so this
     * always completes. Any interrupt that arrives while waiting remains set on return.
     */
    @Override
    public void resetLearningImpl() {
        lock.acquireUninterruptibly();
        try {
            this.ensemble = new ArrayDeque<ClassifierInstance>(ensembleSizeOption.getValue());
            Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
            baseLearner.resetLearning();
        } finally {
            lock.release();
        }
    }

    /**
     * Adds a fresh base-learner copy to the ensemble, then trains every member on {@code inst} using Oza &amp;
     * Russell's online boosting weight update.
     *
     * @param inst the training instance; it is copied before being re-weighted, so it is never modified
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        if (!acquireLock()) {
            return;
        }
        try {
            // Get a new classifier for every training instance; trimEnsemble() evicts the oldest members.
            // NOTE(review): this means each member is trained locally on at most ensembleSize instances before
            // being evicted — confirm this is the intended design.
            Classifier newClassifier = ((Classifier) getPreparedClassOption(this.baseLearnerOption)).copy();
            ensemble.add(new ClassifierInstance(newClassifier));
            trimEnsemble();

            // lambda_d is the boosting weight passed from each member to the next (oldest to newest).
            double lambda_d = 1.0;
            for (ClassifierInstance c : ensemble) {
                double k = this.pureBoostOption.isSet() ? lambda_d
                        : MiscUtils.poisson(lambda_d, this.classifierRandom);
                if (k > 0.0) {
                    Instance weightedInst = (Instance) inst.copy();
                    weightedInst.setWeight(inst.weight() * k);
                    c.getClassifier().trainOnInstance(weightedInst);
                }
                if (c.getClassifier().correctlyClassifies(inst)) {
                    c.setScms(c.getScms() + lambda_d);
                    lambda_d *= this.trainingWeightSeenByModel / (2 * c.getScms());
                } else {
                    c.setSwms(c.getSwms() + lambda_d);
                    lambda_d *= this.trainingWeightSeenByModel / (2 * c.getSwms());
                }
            }
        } finally {
            lock.release();
        }
    }

    /**
     * Computes a member's voting weight, log(1 / Bm) with Bm = em / (1 - em), where em is its weighted error.
     *
     * @param i the ensemble member
     * @return the voting weight, or 0.0 when the member has no recorded weight yet, zero error, or error
     *         above 0.5
     */
    protected double getEnsembleMemberWeight(ClassifierInstance i) {
        double em = i.getSwms() / (i.getScms() + i.getSwms());
        // NaN arises e.g. from 0/0 for a member with no recorded weight; treat it as having no useful weight.
        if (Double.isNaN(em) || (em == 0.0) || (em > 0.5)) {
            return 0.0;
        }
        double Bm = em / (1.0 - em);
        return Math.log(1.0 / Bm);
    }

    /**
     * Combines the members' normalised votes, each scaled by its member weight.
     *
     * @param inst the instance to classify
     * @return the combined vote array; empty if no member contributed or the thread was interrupted while
     *         waiting for the lock
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        DoubleVector combinedVote = new DoubleVector();
        if (!acquireLock()) {
            return combinedVote.getArrayRef();
        }
        try {
            for (ClassifierInstance c : ensemble) {
                double memberWeight = getEnsembleMemberWeight(c);
                if (memberWeight > 0.0) {
                    DoubleVector vote = new DoubleVector(c.getClassifier().getVotesForInstance(inst));
                    if (vote.sumOfValues() > 0.0) {
                        vote.normalize();
                        vote.scaleValues(memberWeight);
                        combinedVote.addValues(vote);
                    }
                } else {
                    // As in MOA's OzaBoost, stop at the first member without a positive weight.
                    break;
                }
            }
        } finally {
            lock.release();
        }
        return combinedVote.getArrayRef();
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        // Read the field once so a concurrent reset cannot swap it between the null check and size().
        ArrayDeque<ClassifierInstance> members = this.ensemble;
        return new Measurement[] {
                new Measurement("ensemble size", members != null ? members.size() : 0) };
    }

    /**
     * Returns copies of the current members, oldest first.
     *
     * @return an array of classifier copies; empty if the thread was interrupted while waiting for the lock
     */
    @Override
    public Classifier[] getSubClassifiers() {
        if (!acquireLock()) {
            return new Classifier[0];
        }
        try {
            // Size the array under the lock so it always matches the members iterated below.
            Classifier[] classifiers = new Classifier[ensemble.size()];
            int i = 0;
            for (ClassifierInstance c : ensemble) {
                classifiers[i++] = c.getClassifier().copy();
            }
            return classifiers;
        } finally {
            lock.release();
        }
    }

    /**
     * Returns the most recently added member.
     *
     * <p>The returned object is the live ensemble member, not a copy, so it may still be trained by later
     * calls to {@link #trainOnInstanceImpl(Instance)}.
     *
     * @return the newest member, or {@code null} if the ensemble is empty or the thread was interrupted while
     *         waiting for the lock
     */
    public ClassifierInstance getLatestClassifier() {
        if (!acquireLock()) {
            return null;
        }
        try {
            return ensemble.peekLast();
        } finally {
            lock.release();
        }
    }

    /**
     * Appends a clone of a shared classifier as the newest member, evicting the oldest members if necessary.
     * The call is skipped if the thread is interrupted while waiting for the lock.
     *
     * @param c the classifier to add; it is cloned, so later changes to {@code c} do not affect this ensemble
     */
    public void addClassifier(ClassifierInstance c) {
        if (!acquireLock()) {
            return;
        }
        try {
            ensemble.add(c.clone());
            trimEnsemble();
        } finally {
            lock.release();
        }
    }
}