tr.gov.ulakbim.jDenetX.classifiers.AbstractClassifier.java Source code

Java tutorial

Introduction

Here is the source code for tr.gov.ulakbim.jDenetX.classifiers.AbstractClassifier.java

Source

/*
 *    AbstractClassifier.java
 *    Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
 *    @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
package tr.gov.ulakbim.jDenetX.classifiers;

import tr.gov.ulakbim.jDenetX.core.InstancesHeader;
import tr.gov.ulakbim.jDenetX.core.Measurement;
import tr.gov.ulakbim.jDenetX.core.ObjectRepository;
import tr.gov.ulakbim.jDenetX.core.StringUtils;
import tr.gov.ulakbim.jDenetX.gui.AWTRenderer;
import tr.gov.ulakbim.jDenetX.options.AbstractOptionHandler;
import tr.gov.ulakbim.jDenetX.options.IntOption;
import tr.gov.ulakbim.jDenetX.tasks.TaskMonitor;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;

/**
 * Base class for classifiers that handles the bookkeeping shared by all implementations.
 * That bookkeeping covers:
 * <ul>
 * <li>the model context (instances header);</li>
 * <li>the total training weight seen;</li>
 * <li>the random seed and random number generator;</li>
 * <li>the standard model measurements.</li>
 * </ul>
 * <p>
 * Subclasses implement the {@code ...Impl} hooks rather than overriding the public methods. That
 * way they never need to remember to call {@code super}, and a missing hook is a compile error.
 */
public abstract class AbstractClassifier extends AbstractOptionHandler implements Classifier {

    /**
     * Returns a short, human-readable description naming this classifier's class.
     * <p>
     * {@link Class#getCanonicalName()} is null for anonymous and local classes. In that case the
     * binary class name is used, so the text never ends in "null".
     */
    @Override
    public String getPurposeString() {
        String className = getClass().getCanonicalName();
        if (className == null) {
            className = getClass().getName();
        }
        return "MOA Classifier: " + className;
    }

    /** Header describing the data this model learns from; null until {@link #setModelContext} is called. */
    protected InstancesHeader modelContext;

    /** Sum of the (positive) weights of all instances passed to {@link #trainOnInstanceImpl}. */
    protected double trainingWeightSeenByModel = 0.0;

    /** Seed used to recreate {@link #classifierRandom} on every {@link #resetLearning()}. */
    protected int randomSeed = 1;

    /** The {@code randomSeed} option; only created when {@code isRandomizable()} is true, otherwise null. */
    protected IntOption randomSeedOption;

    /** Random source for randomizable classifiers; (re)created in {@link #resetLearning()}. */
    protected Random classifierRandom;

    /**
     * Creates the classifier. For randomizable classifiers it also creates the
     * {@code randomSeed} option (flag {@code -r}, default 1).
     * <p>
     * NOTE: this calls the overridable {@code isRandomizable()} during construction. Subclass
     * overrides therefore run before subclass fields are initialized and must not depend on them.
     */
    public AbstractClassifier() {
        if (isRandomizable()) {
            this.randomSeedOption = new IntOption("randomSeed", 'r', "Seed for random behaviour of the classifier.",
                    1);
        }
    }

    /**
     * Prepares the classifier for use.
     * <p>
     * Copies the configured seed from {@link #randomSeedOption} (when present). Then it resets the
     * model, but only if no training has happened yet, so an already-trained model is preserved.
     */
    @Override
    public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
        if (this.randomSeedOption != null) {
            this.randomSeed = this.randomSeedOption.getValue();
        }
        if (!trainingHasStarted()) {
            resetLearning();
        }
    }

    /**
     * Sets the header describing the data this model learns from.
     *
     * @param ih the new context; may be null
     * @throws IllegalArgumentException if {@code ih} has no class attribute set; or if training has
     *         started, a context already exists, and {@code ih} is null or fails
     *         {@link #contextIsCompatible}
     */
    public void setModelContext(InstancesHeader ih) {
        if ((ih != null) && (ih.classIndex() < 0)) {
            throw new IllegalArgumentException("Context for a classifier must include a class to learn");
        }
        if (trainingHasStarted() && (this.modelContext != null)
                && ((ih == null) || !contextIsCompatible(this.modelContext, ih))) {
            throw new IllegalArgumentException("New context is not compatible with existing model");
        }
        this.modelContext = ih;
    }

    /** Returns the current model context, or null if none has been set. */
    public InstancesHeader getModelContext() {
        return this.modelContext;
    }

    /**
     * Sets the random seed used by the next {@link #resetLearning()}. The {@code randomSeed}
     * option, if present, is updated too so both stay consistent.
     */
    public void setRandomSeed(int s) {
        this.randomSeed = s;
        if (this.randomSeedOption != null) {
            // keep option consistent
            this.randomSeedOption.setValue(s);
        }
    }

    /** Returns true once at least one positively-weighted instance has been trained on. */
    public boolean trainingHasStarted() {
        return this.trainingWeightSeenByModel > 0.0;
    }

    /** Returns the total weight of the instances trained on since the last reset. */
    public double trainingWeightSeenByModel() {
        return this.trainingWeightSeenByModel;
    }

    /**
     * Discards everything learned.
     * <p>
     * Zeroes the training weight and, for randomizable classifiers, recreates
     * {@link #classifierRandom} from {@link #randomSeed}. It then calls
     * {@link #resetLearningImpl()}.
     */
    public void resetLearning() {
        this.trainingWeightSeenByModel = 0.0;
        if (isRandomizable()) {
            this.classifierRandom = new Random(this.randomSeed);
        }
        resetLearningImpl();
    }

    /**
     * Trains on one instance.
     * <p>
     * Instances with weight &lt;= 0 are ignored entirely: they neither count toward the training
     * weight nor reach {@link #trainOnInstanceImpl}.
     */
    public void trainOnInstance(Instance inst) {
        if (inst.weight() > 0.0) {
            this.trainingWeightSeenByModel += inst.weight();
            trainOnInstanceImpl(inst);
        }
    }

    /**
     * Collects measurements describing the model, in this order:
     * <ol>
     * <li>training weight seen;</li>
     * <li>serialized size;</li>
     * <li>the subclass's own measurements;</li>
     * <li>the average of the measurements of all non-null sub-classifiers.</li>
     * </ol>
     *
     * @return a new array of measurements, never null
     */
    public Measurement[] getModelMeasurements() {
        List<Measurement> measurementList = new ArrayList<Measurement>();
        measurementList.add(new Measurement("model training instances", trainingWeightSeenByModel()));
        measurementList.add(new Measurement("model serialized size (bytes)", measureByteSize()));
        Measurement[] modelMeasurements = getModelMeasurementsImpl();
        if (modelMeasurements != null) {
            Collections.addAll(measurementList, modelMeasurements);
        }
        // add average of sub-model measurements
        Classifier[] subModels = getSubClassifiers();
        if ((subModels != null) && (subModels.length > 0)) {
            List<Measurement[]> subMeasurements = new ArrayList<Measurement[]>(subModels.length);
            for (Classifier subModel : subModels) {
                if (subModel != null) {
                    subMeasurements.add(subModel.getModelMeasurements());
                }
            }
            // Every sub-classifier slot may be null; don't ask for an average of nothing.
            if (!subMeasurements.isEmpty()) {
                Measurement[] avgMeasurements = Measurement
                        .averageMeasurements(subMeasurements.toArray(new Measurement[subMeasurements.size()][]));
                Collections.addAll(measurementList, avgMeasurements);
            }
        }
        return measurementList.toArray(new Measurement[measurementList.size()]);
    }

    /**
     * Appends a textual description of this model to {@code out}.
     * <p>
     * The description contains the model type and the model measurements. It ends with either the
     * model description or a "not trained" notice.
     *
     * @param out    buffer to append to
     * @param indent indentation passed to the {@code StringUtils} helpers
     */
    public void getDescription(StringBuilder out, int indent) {
        StringUtils.appendIndented(out, indent, "Model type: ");
        out.append(this.getClass().getName());
        StringUtils.appendNewline(out);
        Measurement.getMeasurementsDescription(getModelMeasurements(), out, indent);
        StringUtils.appendNewlineIndented(out, indent, "Model description:");
        StringUtils.appendNewline(out);
        if (trainingHasStarted()) {
            getModelDescription(out, indent);
        } else {
            StringUtils.appendIndented(out, indent, "Model has not been trained.");
        }
    }

    /**
     * Returns the sub-classifiers of an ensemble-style model. The default returns null, meaning
     * there are none; override to expose members whose measurements should be averaged.
     */
    public Classifier[] getSubClassifiers() {
        return null;
    }

    /** Returns a copy of this classifier (delegates to the superclass copy mechanism). */
    @Override
    public Classifier copy() {
        return (Classifier) super.copy();
    }

    /**
     * Returns true when the index of the highest vote from {@code getVotesForInstance(inst)} equals
     * the instance's class value (truncated to int).
     */
    public boolean correctlyClassifies(Instance inst) {
        return Utils.maxIndex(getVotesForInstance(inst)) == (int) inst.classValue();
    }

    /** Returns the class attribute's name, as reported by {@link InstancesHeader} for the current context. */
    public String getClassNameString() {
        return InstancesHeader.getClassNameString(this.modelContext);
    }

    /** Returns the label of class value {@code classLabelIndex} in the current context. */
    public String getClassLabelString(int classLabelIndex) {
        return InstancesHeader.getClassLabelString(this.modelContext, classLabelIndex);
    }

    /** Returns the name of attribute {@code attIndex} in the current context. */
    public String getAttributeNameString(int attIndex) {
        return InstancesHeader.getAttributeNameString(this.modelContext, attIndex);
    }

    /** Returns the name of value {@code valIndex} of nominal attribute {@code attIndex} in the current context. */
    public String getNominalValueString(int attIndex, int valIndex) {
        return InstancesHeader.getNominalValueString(this.modelContext, attIndex, valIndex);
    }

    /**
     * Checks whether a model trained on {@code originalContext} can continue with
     * {@code newContext}.
     * <p>
     * The rules are:
     * <ol>
     * <li>the number of classes may increase but never decrease;</li>
     * <li>the number of attributes may increase but never decrease;</li>
     * <li>the number of values of each nominal attribute may increase but never decrease;</li>
     * <li>the non-class attributes must keep their types in the same order.</li>
     * </ol>
     * For rule 4, the class attribute may move position; it is skipped in both headers. Attribute
     * names are not compared; they may change but should still represent the original attributes.
     *
     * @param originalContext the context the model was trained with; must not be null
     * @param newContext      the proposed context; must not be null
     * @return true if every rule holds
     */
    public static boolean contextIsCompatible(InstancesHeader originalContext, InstancesHeader newContext) {
        if (newContext.numClasses() < originalContext.numClasses()) {
            return false; // rule 1
        }
        if (newContext.numAttributes() < originalContext.numAttributes()) {
            return false; // rule 2
        }
        // Walk the non-class attributes of both headers in parallel.
        int oPos = 0;
        int nPos = 0;
        while (oPos < originalContext.numAttributes()) {
            // Skip the class attribute in the original header; stop if it was the last one.
            if (oPos == originalContext.classIndex()) {
                oPos++;
                if (!(oPos < originalContext.numAttributes())) {
                    break;
                }
            }
            // Skip the class attribute in the new header (it may sit at a different position).
            if (nPos == newContext.classIndex()) {
                nPos++;
            }
            if (originalContext.attribute(oPos).isNominal()) {
                if (!newContext.attribute(nPos).isNominal()) {
                    return false; // rule 4
                }
                if (newContext.attribute(nPos).numValues() < originalContext.attribute(oPos).numValues()) {
                    return false; // rule 3
                }
            } else {
                // Only nominal and numeric attributes are expected here.
                assert (originalContext.attribute(oPos).isNumeric());
                if (!newContext.attribute(nPos).isNumeric()) {
                    return false; // rule 4
                }
            }
            oPos++;
            nPos++;
        }
        return true; // all checks clear
    }

    /** Returns the renderer for visualising this model; currently always null. */
    public AWTRenderer getAWTRenderer() {
        // TODO should return a default renderer here
        // - or should null be interpreted as the default?
        return null;
    }

    // reason for ...Impl methods:
    // ease programmer burden by not requiring them to remember calls to super
    // in overridden methods & will produce compiler errors if not overridden

    /** Subclass hook called by {@link #resetLearning()} after the shared state has been reset. */
    public abstract void resetLearningImpl();

    /** Subclass hook called by {@link #trainOnInstance} for each instance with positive weight. */
    public abstract void trainOnInstanceImpl(Instance inst);

    /** Subclass hook returning model-specific measurements; may return null for none. */
    protected abstract Measurement[] getModelMeasurementsImpl();

    /** Appends a model-specific description; called by {@link #getDescription} once training has started. */
    public abstract void getModelDescription(StringBuilder out, int indent);

    /**
     * Maps a model attribute index to an index into {@code inst}. Model indices exclude the class
     * attribute; instance indices include it.
     * <p>
     * Indices before the class index are unchanged; indices at or after it are shifted up by one.
     * If {@code inst.classIndex()} is negative, every index is shifted.
     */
    protected static int modelAttIndexToInstanceAttIndex(int index, Instance inst) {
        return inst.classIndex() > index ? index : index + 1;
    }

    /**
     * Maps a model attribute index to an index into {@code insts}. This applies the same mapping as
     * {@link #modelAttIndexToInstanceAttIndex(int, Instance)}, using the dataset's class index.
     */
    protected static int modelAttIndexToInstanceAttIndex(int index, Instances insts) {
        return insts.classIndex() > index ? index : index + 1;
    }
}