/*
 *    OzaBoost.java
 *    @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 *    @author Luke Barnett (luke@barnett.net.nz)
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program. If not, see <http://www.gnu.org/licenses/>.
 *
 */
package distributedRedditAnalyser;

import java.util.ArrayDeque;
import java.util.Random;
import java.util.concurrent.Semaphore;

import weka.core.Instance;

import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.options.ClassOption;
import moa.options.FlagOption;
import moa.options.IntOption;

/**
 * Rewrite of the MOA implementation of OzaBoost to accommodate a distributed
 * setting and the sharing of classifiers.
 *
 * Keeps the latest K classifiers.
 *
 * Largely derived from the MOA implementation:
 * http://code.google.com/p/moa/source/browse/moa/src/main/java/moa/classifiers/meta/OzaBoost.java
 *
 * @author Luke Barnett 1109967
 * @author Tony Chen 1111377
 */
public class OzaBoost extends AbstractClassifier {

    private static final long serialVersionUID = -4456874021287021340L;

    /** Guards all access to the ensemble, which may be shared across threads. */
    private Semaphore lock = new Semaphore(1);

    @Override
    public String getPurposeString() {
        return "Incremental on-line boosting of Oza and Russell.";
    }

    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
            "Classifier to train.", Classifier.class, "trees.HoeffdingTree");

    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
            "The max number of models to boost.", 10, 1, Integer.MAX_VALUE);

    public FlagOption pureBoostOption = new FlagOption("pureBoost", 'p',
            "Boost with weights only; no poisson.");

    protected ArrayDeque<ClassifierInstance> ensemble;

    @Override
    public void resetLearningImpl() {
        try {
            lock.acquire();
            this.ensemble = new ArrayDeque<ClassifierInstance>(ensembleSizeOption.getValue());
            Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
            baseLearner.resetLearning();
        } catch (InterruptedException e) {
            e.printStackTrace();
        } finally {
            lock.release();
        }
    }

    @Override
    public void trainOnInstanceImpl(Instance inst) {
        try {
            lock.acquire();
            // Get a new classifier
            Classifier newClassifier = ((Classifier) getPreparedClassOption(this.baseLearnerOption)).copy();
            ensemble.add(new ClassifierInstance(newClassifier));
            // If we have too many classifiers, drop the oldest ones
            while (ensemble.size() > ensembleSizeOption.getValue())
                ensemble.pollFirst();
            double lambda_d = 1.0;
            for (ClassifierInstance c : ensemble) {
                // Weight the instance by Poisson(lambda), or by lambda itself in pure-boost mode
                double k = this.pureBoostOption.isSet() ? lambda_d
                        : MiscUtils.poisson(lambda_d, this.classifierRandom);
                if (k > 0.0) {
                    Instance weightedInst = (Instance) inst.copy();
                    weightedInst.setWeight(inst.weight() * k);
                    c.getClassifier().trainOnInstance(weightedInst);
                }
                // Update the member's correct/wrong weight sums and rescale lambda for the next member
                if (c.getClassifier().correctlyClassifies(inst)) {
                    c.setScms(c.getScms() + lambda_d);
                    lambda_d *= this.trainingWeightSeenByModel / (2 * c.getScms());
                } else {
                    c.setSwms(c.getSwms() + lambda_d);
                    lambda_d *= this.trainingWeightSeenByModel / (2 * c.getSwms());
                }
            }
        } catch (InterruptedException e) {
            e.printStackTrace();
        } finally {
            lock.release();
        }
    }

    protected double getEnsembleMemberWeight(ClassifierInstance i) {
        // AdaBoost-style member weight: log(1 / beta), with beta = error / (1 - error)
        double em = i.getSwms() / (i.getScms() + i.getSwms());
        if ((em == 0.0) || (em > 0.5)) {
            return 0.0;
        }
        double Bm = em / (1.0 - em);
        return Math.log(1.0 / Bm);
    }

    @Override
    public double[] getVotesForInstance(Instance inst) {
        DoubleVector combinedVote = new DoubleVector();
        try {
            lock.acquire();
            for (ClassifierInstance c : ensemble) {
                double memberWeight = getEnsembleMemberWeight(c);
                if (memberWeight > 0.0) {
                    DoubleVector vote = new DoubleVector(c.getClassifier().getVotesForInstance(inst));
                    if (vote.sumOfValues() > 0.0) {
                        vote.normalize();
                        vote.scaleValues(memberWeight);
                        combinedVote.addValues(vote);
                    }
                } else {
                    break;
                }
            }
        } catch (InterruptedException e) {
            e.printStackTrace();
        } finally {
            lock.release();
        }
        return combinedVote.getArrayRef();
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[] { new Measurement("ensemble size",
                this.ensemble != null ? this.ensemble.size() : 0) };
    }

    @Override
    public Classifier[] getSubClassifiers() {
        Classifier[] classifiers = new Classifier[0];
        try {
            lock.acquire();
            // Size the array inside the lock so it matches the ensemble being copied
            classifiers = new Classifier[ensemble.size()];
            int i = 0;
            for (ClassifierInstance c : ensemble) {
                if (i < classifiers.length) {
                    classifiers[i] = c.getClassifier().copy();
                } else {
                    break;
                }
                i++;
            }
        } catch (InterruptedException e) {
            e.printStackTrace();
        } finally {
            lock.release();
        }
        return classifiers;
    }

    public ClassifierInstance getLatestClassifier() {
        return ensemble.peekLast();
    }

    public void addClassifier(ClassifierInstance c) {
        try {
            lock.acquire();
            ensemble.add(c.clone());
            // If we have too many classifiers, drop the oldest ones
            while (ensemble.size() > ensembleSizeOption.getValue())
                ensemble.pollFirst();
        } catch (InterruptedException e) {
            e.printStackTrace();
        } finally {
            lock.release();
        }
    }
}
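
The class above depends on a ClassifierInstance wrapper that is not shown in this listing. The sketch below is an assumption inferred only from how the wrapper is used here (it pairs a base Classifier with the scms/swms weight sums and supports copying); the real class in the distributedRedditAnalyser package may differ.

package distributedRedditAnalyser;

import java.io.Serializable;
import moa.classifiers.Classifier;

// Hypothetical sketch of the wrapper used by OzaBoost above; not the original class.
public class ClassifierInstance implements Serializable {

    private static final long serialVersionUID = 1L;

    private final Classifier classifier;
    private double scms = 0.0; // sum of lambda over correctly classified instances
    private double swms = 0.0; // sum of lambda over misclassified instances

    public ClassifierInstance(Classifier classifier) {
        this.classifier = classifier;
    }

    public Classifier getClassifier() { return classifier; }

    public double getScms() { return scms; }
    public void setScms(double scms) { this.scms = scms; }

    public double getSwms() { return swms; }
    public void setSwms(double swms) { this.swms = swms; }

    @Override
    public ClassifierInstance clone() {
        // Deep-copy the wrapped classifier so shared copies train independently
        ClassifierInstance copy = new ClassifierInstance(classifier.copy());
        copy.scms = scms;
        copy.swms = swms;
        return copy;
    }
}

For completeness, a minimal driver showing one way the learner could be exercised, assuming the pre-2014 MOA/WEKA API that this file imports (moa.options.*, weka.core.Instance). The RandomTreeGenerator stream, the option values, and the class name OzaBoostDemo are illustrative stand-ins, not part of the original project.

package distributedRedditAnalyser;

import weka.core.Instance;
import weka.core.Utils;
import moa.streams.generators.RandomTreeGenerator;

// Hypothetical driver: prequential (test-then-train) evaluation on a synthetic stream.
public class OzaBoostDemo {

    public static void main(String[] args) {
        OzaBoost boost = new OzaBoost();
        boost.ensembleSizeOption.setValue(10);
        boost.prepareForUse(); // prepares the options and resets learning

        RandomTreeGenerator stream = new RandomTreeGenerator();
        stream.prepareForUse();

        int numInstances = 1000;
        int correct = 0;
        for (int i = 0; i < numInstances; i++) {
            Instance inst = stream.nextInstance();
            double[] votes = boost.getVotesForInstance(inst); // test first...
            if (votes.length > 0 && Utils.maxIndex(votes) == (int) inst.classValue()) {
                correct++;
            }
            boost.trainOnInstance(inst);                      // ...then train
        }
        System.out.println("Prequential accuracy: " + ((double) correct / numInstances));
    }
}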