moa.tasks.EvaluatePeriodicHeldOutTestLA.java Source code

Java tutorial

Introduction

Here is the source code for moa.tasks.EvaluatePeriodicHeldOutTestLA.java

Source

/*
 *    EvaluatePeriodicHeldOutTestLA.java
 *    Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
 *    @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 *    @author Ammar Shaker (shaker@mathematik.uni-marburg.de)
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program. If not, see <http://www.gnu.org/licenses/>.
 *    
 */
package moa.tasks;

import java.io.File;
import java.io.FileOutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;

import moa.classifiers.Classifier;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.core.StringUtils;
import moa.core.TimingUtils;
import moa.evaluation.ClassificationPerformanceEvaluator;
import moa.evaluation.LearningCurve;
import moa.evaluation.LearningEvaluation;
import moa.options.ClassOption;
import moa.options.FileOption;
import moa.options.FlagOption;
import moa.options.IntOption;
import moa.streams.CachedInstancesStream;
import moa.streams.InstanceStream;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Task for evaluating a classifier on a stream by periodically testing on a heldout set,
 * additionally recording prediction latency.
 *
 * <p>After every {@code sampleFrequency} training instances the learner is evaluated on
 * {@code testSize} test instances. Besides the evaluator's performance measurements and the
 * learner's model measurements, each learning-curve entry reports:
 * <ul>
 * <li>{@code test latency}: mean time in milliseconds of one {@code getVotesForInstance} call,
 * obtained by repeating each test prediction in batches that last at least 30 ms;</li>
 * <li>{@code model build time}: time in milliseconds of the first prediction made after each
 * training phase;</li>
 * <li>{@code trainTime}: elapsed ({@code System.nanoTime}) training seconds per training
 * instance, excluding the time spent reading instances from the stream.</li>
 * </ul>
 *
 * <p>Without {@code cacheTest}, test instances are drawn from the training stream itself, so
 * every test phase consumes instances that are never trained on.
 *
 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 * @author Ammar Shaker (shaker@mathematik.uni-marburg.de)
 * @version $Revision: 7 $
 */
public class EvaluatePeriodicHeldOutTestLA extends MainTask {

    @Override
    public String getPurposeString() {
        return "Evaluates a classifier on a stream by periodically testing on a heldout set.";
    }

    private static final long serialVersionUID = 1L;

    public ClassOption learnerOption = new ClassOption("learner", 'l', "Classifier to train.", Classifier.class,
            "trees.HoeffdingTree");

    public ClassOption streamOption = new ClassOption("stream", 's', "Stream to learn from.", InstanceStream.class,
            "generators.RandomTreeGenerator");

    public ClassOption evaluatorOption = new ClassOption("evaluator", 'e',
            "Classification performance evaluation method.", ClassificationPerformanceEvaluator.class,
            "BasicClassificationPerformanceEvaluator");

    public IntOption testSizeOption = new IntOption("testSize", 'n', "Number of testing examples.", 1000000, 0,
            Integer.MAX_VALUE);

    public IntOption trainSizeOption = new IntOption("trainSize", 'i',
            "Number of training examples, <1 = unlimited.", 0, 0, Integer.MAX_VALUE);

    public IntOption trainTimeOption = new IntOption("trainTime", 't', "Number of training seconds.", 10 * 60 * 60,
            0, Integer.MAX_VALUE);

    public IntOption sampleFrequencyOption = new IntOption("sampleFrequency", 'f',
            "Number of training examples between samples of learning performance.", 100000, 0, Integer.MAX_VALUE);

    public FileOption dumpFileOption = new FileOption("dumpFile", 'd',
            "File to append intermediate csv results to.", null, "csv", true);

    public FlagOption cacheTestOption = new FlagOption("cacheTest", 'c', "Cache test instances in memory.");

    /**
     * Alternates training phases of {@code sampleFrequency} instances with held-out test phases,
     * recording one learning-curve entry per completed test phase.
     *
     * @param monitor    receives progress updates and is polled for abort requests
     * @param repository unused
     * @return the learning curve, or {@code null} if the task was aborted
     * @throws RuntimeException if the dump file cannot be opened
     */
    @Override
    protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
        Classifier learner = (Classifier) getPreparedClassOption(this.learnerOption);
        InstanceStream stream = (InstanceStream) getPreparedClassOption(this.streamOption);
        ClassificationPerformanceEvaluator evaluator = (ClassificationPerformanceEvaluator) getPreparedClassOption(
                this.evaluatorOption);
        learner.setModelContext(stream.getHeader());
        LearningCurve learningCurve = new LearningCurve("evaluation instances") {

            // Header and entries are both tab-separated so the dumped header lines up with its rows.
            @Override
            public String headerToString() {
                return super.headerToString().replace(',', '\t');
            }

            @Override
            public String entryToString(int entryIndex) {
                return super.entryToString(entryIndex).replace(',', '\t');
            }

        };
        PrintStream immediateResultStream = openImmediateResultStream(this.dumpFileOption.getFile());
        try {
            int testSize = this.testSizeOption.getValue();
            InstanceStream testStream;
            if (this.cacheTestOption.isSet()) {
                testStream = cacheTestInstances(monitor, stream, testSize);
                if (testStream == null) {
                    return null; // aborted while caching
                }
            } else {
                testStream = stream;
            }

            TimingUtils.enablePreciseTiming();
            int trainSize = this.trainSizeOption.getValue();
            boolean firstDump = true;
            // Batch size for latency measurement; kept across test instances and phases, only ever grows.
            int numberOfCheckCalls = 0;
            double modelTime = 0;
            double totalTrainTime = 0.0;
            long instancesProcessed = 0;
            while ((trainSize < 1 || instancesProcessed < trainSize) && stream.hasMoreInstances()) {
                // ---- training phase ----
                monitor.setCurrentActivityDescription("Training...");
                long instancesTarget = instancesProcessed + this.sampleFrequencyOption.getValue();
                if (trainSize >= 1 && instancesTarget > trainSize) {
                    // Do not train past the requested number of training instances.
                    instancesTarget = trainSize;
                }
                long streamReadNanos = 0; // time spent fetching instances, excluded from training time
                long trainStartTime = System.nanoTime();
                while (instancesProcessed < instancesTarget && stream.hasMoreInstances()) {
                    long readStart = System.nanoTime();
                    Instance trainInst = stream.nextInstance();
                    streamReadNanos += System.nanoTime() - readStart;
                    learner.trainOnInstance(trainInst);
                    instancesProcessed++;
                    if (instancesProcessed % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) {
                        if (monitor.taskShouldAbort()) {
                            return null;
                        }
                        monitor.setCurrentActivityFractionComplete(trainingFraction(instancesProcessed));
                    }
                }
                totalTrainTime += TimingUtils.nanoTimeToSeconds(System.nanoTime() - trainStartTime - streamReadNanos);
                if (totalTrainTime > this.trainTimeOption.getValue()) {
                    break;
                }

                // ---- test phase ----
                if (this.cacheTestOption.isSet()) {
                    testStream.restart();
                }
                evaluator.reset();
                monitor.setCurrentActivityDescription(testingDescription(instancesProcessed));
                long testInstancesProcessed = 0;
                long checkCallsTotal = 0;  // number of timed predictions in this phase
                long checkNanosTotal = 0;  // total time of those predictions
                int instCount;
                for (instCount = 0; instCount < testSize; instCount++) {
                    if (!testStream.hasMoreInstances()) {
                        break;
                    }
                    Instance testInst = (Instance) testStream.nextInstance().copy();
                    if (instCount == 0) {
                        // The first prediction after training is timed on its own ("model build time").
                        long modelBuildStart = System.nanoTime();
                        learner.getVotesForInstance(testInst);
                        modelTime = (System.nanoTime() - modelBuildStart) / 1000000.0;
                    }
                    double trueClass = testInst.classValue();
                    testInst.setClassMissing();

                    double[] prediction = learner.getVotesForInstance(testInst);

                    // Repeat the prediction in batches of numberOfCheckCalls, growing the batch by 10
                    // until one batch takes at least 30 ms, so latency is averaged over a measurable span.
                    do {
                        long start = System.nanoTime();
                        for (int i = 0; i < numberOfCheckCalls; ++i) {
                            prediction = learner.getVotesForInstance(testInst);
                        }
                        long elapsed = System.nanoTime() - start;
                        if (elapsed / 1000000.0 < 30) {
                            numberOfCheckCalls += 10;
                        } else {
                            checkCallsTotal += numberOfCheckCalls;
                            checkNanosTotal += elapsed;
                            break;
                        }
                    } while (true);

                    testInst.setClassValue(trueClass);
                    evaluator.addResult(testInst, prediction);
                    testInstancesProcessed++;
                    if (testInstancesProcessed % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) {
                        if (monitor.taskShouldAbort()) {
                            return null;
                        }
                        monitor.setCurrentActivityFractionComplete(
                                (double) testInstancesProcessed / (double) testSize);
                    }
                }
                if (instCount != testSize) {
                    // The test stream ran out before a full held-out set was evaluated; drop this phase.
                    break;
                }

                // ---- record results ----
                double testLatency = ((double) checkNanosTotal / checkCallsTotal) / 1000000.0; // msec
                List<Measurement> measurements = new ArrayList<Measurement>();
                measurements.add(new Measurement("number of attributes", testStream.getHeader().numAttributes() - 1));
                measurements.add(new Measurement("evaluation instances", instancesProcessed));
                measurements.add(new Measurement("test latency", testLatency));
                measurements.add(new Measurement("model build time", modelTime));
                measurements.add(new Measurement("trainTime", (totalTrainTime / instancesProcessed)));
                for (Measurement measurement : evaluator.getPerformanceMeasurements()) {
                    measurements.add(measurement);
                }
                for (Measurement measurement : learner.getModelMeasurements()) {
                    measurements.add(measurement);
                }
                learningCurve.insertEntry(
                        new LearningEvaluation(measurements.toArray(new Measurement[measurements.size()])));
                if (immediateResultStream != null) {
                    if (firstDump) {
                        immediateResultStream.println(learningCurve.headerToString());
                        firstDump = false;
                    }
                    immediateResultStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1));
                    immediateResultStream.flush();
                }
            }
            return learningCurve;
        } finally {
            // Close the dump file on every exit path, including aborts and exceptions.
            if (immediateResultStream != null) {
                immediateResultStream.close();
            }
        }
    }

    /**
     * Opens the dump file for appending; append mode also creates the file if it is missing.
     *
     * @param dumpFile the file to open, or {@code null} when no dump was requested
     * @return an auto-flushing stream, or {@code null} when {@code dumpFile} is {@code null}
     * @throws RuntimeException if the file cannot be opened
     */
    private PrintStream openImmediateResultStream(File dumpFile) {
        if (dumpFile == null) {
            return null;
        }
        try {
            return new PrintStream(new FileOutputStream(dumpFile, true), true);
        } catch (Exception ex) {
            throw new RuntimeException("Unable to open immediate result file: " + dumpFile, ex);
        }
    }

    /**
     * Reads up to {@code testSize} instances from {@code stream} into memory, stopping early if the
     * stream runs out.
     *
     * @return a restartable stream over the cached instances, or {@code null} if the task was aborted
     */
    private InstanceStream cacheTestInstances(TaskMonitor monitor, InstanceStream stream, int testSize) {
        monitor.setCurrentActivity("Caching test examples...", -1.0);
        Instances testInstances = new Instances(stream.getHeader(), testSize);
        while (testInstances.numInstances() < testSize && stream.hasMoreInstances()) {
            testInstances.add(stream.nextInstance());
            if (testInstances.numInstances() % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) {
                if (monitor.taskShouldAbort()) {
                    return null;
                }
                monitor.setCurrentActivityFractionComplete(
                        (double) testInstances.numInstances() / (double) testSize);
            }
        }
        return new CachedInstancesStream(testInstances);
    }

    /**
     * Returns the completed fraction of training, or -1.0 (the value this task also uses for
     * unknown progress) when the training size is unlimited.
     */
    private double trainingFraction(long instancesProcessed) {
        int trainSize = this.trainSizeOption.getValue();
        return trainSize < 1 ? -1.0 : (double) instancesProcessed / (double) trainSize;
    }

    /** Builds the monitor description shown while testing. */
    private String testingDescription(long instancesProcessed) {
        double fraction = trainingFraction(instancesProcessed);
        if (fraction < 0) {
            return "Testing (after " + instancesProcessed + " training instances)...";
        }
        return "Testing (after " + StringUtils.doubleToString(fraction * 100.0, 2) + "% training)...";
    }

    @Override
    public Class<?> getTaskResultType() {
        return LearningCurve.class;
    }
}