edu.illinois.cs.cogcomp.lbjava.learn.WekaWrapper.java Source code

Java tutorial

Introduction

Here is the source code for edu.illinois.cs.cogcomp.lbjava.learn.WekaWrapper.java

Source

/**
 * This software is released under the University of Illinois/Research and Academic Use License. See
 * the LICENSE file in the root folder for details. Copyright (c) 2016
 *
 * Developed by: The Cognitive Computations Group, University of Illinois at Urbana-Champaign
 * http://cogcomp.cs.illinois.edu/
 */
package edu.illinois.cs.cogcomp.lbjava.learn;

import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.util.Enumeration;

import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import edu.illinois.cs.cogcomp.core.datastructures.vectors.ExceptionlessInputStream;
import edu.illinois.cs.cogcomp.core.datastructures.vectors.ExceptionlessOutputStream;
import edu.illinois.cs.cogcomp.lbjava.classify.Classifier;
import edu.illinois.cs.cogcomp.lbjava.classify.DiscretePrimitiveStringFeature;
import edu.illinois.cs.cogcomp.lbjava.classify.Feature;
import edu.illinois.cs.cogcomp.lbjava.classify.FeatureVector;
import edu.illinois.cs.cogcomp.lbjava.classify.RealPrimitiveStringFeature;
import edu.illinois.cs.cogcomp.lbjava.classify.ScoreSet;

/**
 * Translates LBJava's internal problem representation into that which can be handled by WEKA
 * learning algorithms. This translation involves storing all examples in memory so they can be
 * passed to WEKA at one time.
 *
 * <p>
 * WEKA must be available on your <code>CLASSPATH</code> in order to use this class. WEKA source
 * code and pre-compiled jar distributions are available at: <a
 * href="http://www.cs.waikato.ac.nz/ml/weka/">http://www.cs.waikato.ac.nz/ml/weka/</a>
 *
 * <p>
 * To use this class in the <code>with</code> clause of a learning classifier expression, the
 * following restrictions must be recognized:
 * <ul>
 * <li>Feature pre-extraction must be enabled.
 * <li>No hard-coded feature generators may be referenced in the <code>using</code> clause.
 * <li>No array producing classifiers may be referenced in the <code>using</code> clause.
 * <li>The names of classifiers referenced in the <code>using</code> clause may not contain the
 * underscore character ('<code>_</code>').
 * <li>The values produced by discrete classifiers referenced in the <code>using</code> clause may
 * not contain the underscore, colon, or comma characters ('<code>_</code>', '<code>:</code>', or '
 * <code>,</code>').
 * </ul>
 *
 * <p>
 * To use this class in a Java application, the following restrictions must be recognized:
 * <ul>
 * <li> {@link #doneLearning()} must be called before calls to {@link #classify(Object)} can be made.
 * <li>After {@link #doneLearning()} is called, {@link #learn(Object)} may not be called without
 * first calling {@link #forget()}.
 * </ul>
 *
 * @author Dan Muriello
 **/
public class WekaWrapper extends Learner {
    /** Default for the {@link #attributeString} field. */
    public static final String defaultAttributeString = "";
    /** Default for the {@link #baseClassifier} field. */
    public static final weka.classifiers.Classifier defaultBaseClassifier = new weka.classifiers.bayes.NaiveBayes();

    /** A string encoding of the attributes used by this learner. */
    protected String attributeString;
    /**
     * Stores the instance of the WEKA classifier which we are training; default is
     * <code>weka.classifiers.bayes.NaiveBayes</code>.
     **/
    protected weka.classifiers.Classifier baseClassifier;
    /**
     * Stores a fresh instance of the WEKA classifier for the purposes of forgetting.
     **/
    protected weka.classifiers.Classifier freshClassifier;
    /**
     * Information about the features this learner takes as input is parsed from an attribute string
     * and stored here. This information is crucial in the task of interfacing with the WEKA
     * algorithms, and must be present before the {@link #learn(Object)} method can be called.
     *
     * <p>
     * Here is an example of a valid attribute string:
     * <code>nom_SimpleLabel_1,2,3,:str_First:nom_Second_a,b,c,d,r,t,:num_Third:</code>
     *
     * <p>
     * <code>nom</code> stands for "Nominal", i.e. the feature <code>SimpleLabel</code> was declared
     * as <code>discrete</code>, and had the value list <code>{"1","2","3"}</code>.
     *
     * <p>
     * <code>str</code> stands for "String", i.e. the feature <code>First</code> was declared to be
     * <code>discrete</code>, but was not provided with a value list. When using the
     * <code>WekaWrapper</code>, it is best to provide value lists whenever possible, because very
     * few WEKA classifiers can handle string attributes.
     *
     * <p>
     * <code>num</code> stands for "Numerical", i.e. the feature <code>Third</code> was declared to
     * be <code>real</code>.
     **/
    protected FastVector attributeInfo = new FastVector();
    /** The main collection of Instance objects. */
    protected Instances instances;
    /**
     * Indicates whether the {@link #doneLearning()} method has been called and the
     * {@link #forget()} method has not yet been called.
     **/
    protected boolean trained = false;
    /** The label producing classifier's allowable values. */
    protected String[] allowableValues;

    /**
     * No-argument constructor. Delegates to {@link #WekaWrapper(String)} with an empty name, so
     * the wrapper is configured from a default-constructed {@link WekaWrapper.Parameters} object
     * (default base classifier, default attribute string). Attribute information must be provided
     * before any learning can occur.
     **/
    public WekaWrapper() {
        this("");
    }

    /**
     * Partial constructor. Delegates to {@link #WekaWrapper(String, weka.classifiers.Classifier)}
     * with an empty name; attribute information must be provided before any learning can occur.
     *
     * @param base The classifier to be used in this system.
     **/
    public WekaWrapper(weka.classifiers.Classifier base) {
        this("", base);
    }

    /**
     * Redirecting constructor. Delegates to
     * {@link #WekaWrapper(String, weka.classifiers.Classifier, String)} with an empty name.
     *
     * @param base The classifier to be used in this system.
     * @param attributeString The string describing the types of attributes example objects will
     *        have.
     **/
    public WekaWrapper(weka.classifiers.Classifier base, String attributeString) {
        this("", base, attributeString);
    }

    /**
     * Initializing constructor. Delegates to {@link #WekaWrapper(String, Parameters)} with an
     * empty name, which sets all member variables to their associated settings in the
     * {@link WekaWrapper.Parameters} object.
     *
     * @param p The settings of all parameters.
     **/
    public WekaWrapper(Parameters p) {
        this("", p);
    }

    /**
     * Named constructor. Delegates to {@link #WekaWrapper(String, Parameters)} with a
     * default-constructed {@link WekaWrapper.Parameters} object, i.e. the default learning
     * algorithm (<code>weka.classifiers.bayes.NaiveBayes</code>) and the default attribute string.
     * Attribute information must be provided before any learning can occur.
     *
     * @param n The name of the classifier.
     **/
    public WekaWrapper(String n) {
        this(n, new Parameters());
    }

    /**
     * Partial constructor. Delegates to
     * {@link #WekaWrapper(String, weka.classifiers.Classifier, String)} with
     * {@link #defaultAttributeString}; attribute information must be provided before any learning
     * can occur.
     *
     * @param n The name of the classifier.
     * @param base The classifier to be used in this system.
     **/
    public WekaWrapper(String n, weka.classifiers.Classifier base) {
        this(n, base, defaultAttributeString);
    }

    /**
     * Name-and-attributes constructor. Instantiates this wrapper with the default learning
     * algorithm: <code>weka.classifiers.bayes.NaiveBayes</code>.
     *
     * <p>
     * The wrapper receives its own copy of {@link #defaultBaseClassifier}. Passing the static
     * instance itself would make every learner built through this constructor train (and
     * overwrite) one and the same WEKA model.
     *
     * @param n The name of the classifier.
     * @param attributeString The string describing the types of attributes example objects will
     *        have.
     **/
    public WekaWrapper(String n, String attributeString) {
        this(n, copyOfDefaultBaseClassifier(), attributeString);
    }

    /**
     * Produces an independent copy of {@link #defaultBaseClassifier}. Failure to copy is handled
     * the same way {@link #forget()} handles it: the error is reported and the JVM exits.
     *
     * @return A copy of the default base classifier.
     **/
    private static weka.classifiers.Classifier copyOfDefaultBaseClassifier() {
        try {
            return weka.classifiers.Classifier.makeCopy(defaultBaseClassifier);
        } catch (Exception e) {
            System.err.println("LBJava ERROR: WekaWrapper: Can't copy default classifier:");
            e.printStackTrace();
            System.exit(1);
            return null; // never reached; required by the compiler after System.exit
        }
    }

    /**
     * Initializing constructor. Sets all member variables to their associated settings in the
     * {@link WekaWrapper.Parameters} object.
     *
     * <p>
     * {@link #freshClassifier} is initialized with a <i>copy</i> of the base classifier taken
     * before any training happens. Aliasing the two fields would mean that the "fresh" classifier
     * is the very object that gets trained, so {@link #forget()} would later copy a trained model
     * and the binary <code>write</code> method would serialize that model twice.
     *
     * @param n The name of the classifier.
     * @param p The settings of all parameters.
     **/
    public WekaWrapper(String n, Parameters p) {
        super(n);
        setParameters(p);

        try {
            // A null base classifier is passed through untouched, as before.
            freshClassifier =
                    baseClassifier == null ? null : weka.classifiers.Classifier.makeCopy(baseClassifier);
        } catch (Exception e) {
            System.err.println("LBJava ERROR: WekaWrapper: Can't copy classifier:");
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Full Constructor.
     *
     * <p>
     * <code>base</code> itself becomes the classifier that is trained, while
     * {@link #freshClassifier} is initialized with a <i>copy</i> of it taken before any training
     * happens. Aliasing the two would mean that {@link #forget()} later copies a trained model
     * instead of a pristine one.
     *
     * @param n The name of the classifier
     * @param base The classifier to be used in this system.
     * @param attributeString The string describing the types of attributes example objects will
     *        have.
     **/
    public WekaWrapper(String n, weka.classifiers.Classifier base, String attributeString) {
        super(n);
        Parameters p = new Parameters();
        p.baseClassifier = base;
        p.attributeString = attributeString;
        setParameters(p);

        try {
            // A null base classifier is passed through untouched, as before.
            freshClassifier = base == null ? null : weka.classifiers.Classifier.makeCopy(base);
        } catch (Exception e) {
            System.err.println("LBJava ERROR: WekaWrapper: Can't copy classifier:");
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Installs the given parameter values in this learner and rebuilds the attribute information
     * from the new attribute string.
     *
     * @param p The parameters.
     **/
    public void setParameters(Parameters p) {
        this.baseClassifier = p.baseClassifier;
        this.attributeString = p.attributeString;

        // The attribute string just changed, so the WEKA attribute list must be re-derived.
        this.initializeAttributes();
    }

    /**
     * Packages this learner's current settings into a parameters object.
     *
     * @return A {@link WekaWrapper.Parameters} object seeded from the parent's parameters and
     *         carrying this learner's base classifier and attribute string.
     **/
    public Learner.Parameters getParameters() {
        Learner.Parameters inherited = super.getParameters();
        Parameters current = new Parameters(inherited);
        current.baseClassifier = this.baseClassifier;
        current.attributeString = this.attributeString;
        return current;
    }

    /**
     * This learner's output type is <code>"mixed%"</code>.
     *
     * @return The string <code>"mixed%"</code>.
     **/
    public String getOutputType() {
        return "mixed%";
    }

    /**
     * Takes <code>attributeString</code> and initializes this wrapper's {@link #instances}
     * collection to take those attributes.
     *
     * <p>
     * The attribute string is a <code>:</code>-separated list of entries of the form
     * <code>str_Name</code>, <code>num_Name</code> or <code>nom_Name_v1,v2,...</code>. Any
     * previously parsed attribute information is discarded first, so calling this method again
     * (e.g. through {@link #setParameters(Parameters)}) does not append duplicate attributes.
     *
     * <p>
     * An empty attribute string is legal and yields an empty attribute list; in that case no
     * class index is set, since there is no attribute to designate. Any other entry that does
     * not match one of the three forms above is reported as malformed and terminates the JVM.
     **/
    public void initializeAttributes() {
        // Always parse into a fresh vector: Instances objects created earlier keep a reference
        // to the vector they were built with, and must not see it change underneath them.
        attributeInfo = new FastVector();

        // "".split(":") yields a single empty entry, which is not malformed input but simply
        // "no attributes yet" (the documented default).
        if (attributeString.length() > 0) {
            String[] atts = attributeString.split(":");

            for (int i = 0; i < atts.length; ++i) {
                String[] parts = atts[i].split("_");
                String type = parts[0];

                if (type.equals("str") && parts.length >= 2) {
                    // A null value vector tells WEKA this is a string attribute.
                    attributeInfo.addElement(new Attribute(parts[1], (FastVector) null));
                } else if (type.equals("nom") && parts.length >= 3) {
                    String[] valueStrings = parts[2].split(",");
                    FastVector valueVector = new FastVector(valueStrings.length);
                    for (int j = 0; j < valueStrings.length; ++j) {
                        valueVector.addElement(valueStrings[j]);
                    }

                    attributeInfo.addElement(new Attribute(parts[1], valueVector));
                } else if (type.equals("num") && parts.length >= 2) {
                    attributeInfo.addElement(new Attribute(parts[1]));
                } else {
                    // Unknown type tag, or a known tag that is missing its name / value list.
                    System.err
                            .println("WekaWrapper: Error - Malformed attribute information string: " + attributeString);
                    new Exception().printStackTrace();
                    System.exit(1);
                }
            }
        }

        instances = new Instances(name, attributeInfo, 0);
        // The first attribute is the class attribute; with no attributes there is none to set.
        if (attributeInfo.size() > 0) {
            instances.setClassIndex(0);
        }
    }

    /**
     * Installs the labeling classifier and caches its allowable values.
     *
     * @param l A labeling classifier; may be <code>null</code>, in which case the cached
     *        allowable values are cleared.
     **/
    public void setLabeler(Classifier l) {
        super.setLabeler(l);

        if (l != null) {
            allowableValues = l.allowableValues();
        } else {
            allowableValues = null;
        }
    }

    /**
     * Reports the values that a feature produced by this classifier may take.
     *
     * @return The allowable values cached from this learner's labeler, or a new array of length
     *         zero if none have been cached.
     **/
    public String[] allowableValues() {
        return allowableValues != null ? allowableValues : new String[0];
    }

    /**
     * Records one training example. WEKA classifiers cannot learn online, so nothing is learned
     * here: the example is converted to a WEKA <code>Instance</code> and appended to
     * {@link #instances}, from which the classifier is built when {@link #doneLearning()} is
     * called.
     *
     * @param exampleFeatures The example's array of feature indices.
     * @param exampleValues The example's array of feature values.
     * @param exampleLabels The example's array of label indices.
     * @param labelValues The example's array of label values.
     **/
    public void learn(int[] exampleFeatures, double[] exampleValues, int[] exampleLabels, double[] labelValues) {
        Instance example = makeInstance(exampleFeatures, exampleValues, exampleLabels, labelValues);
        instances.add(example);
    }

    /**
     * This method makes one or more decisions about a single object, returning those decisions as
     * Features in a vector.
     *
     * <p>
     * For a nominal or string class attribute, the prediction is the class value with the highest
     * probability in the distribution returned by the WEKA classifier. That distribution is
     * indexed by the class attribute's value indexes, so the winning index is translated back to
     * a value through the class attribute itself (exactly as {@link #scores(int[], double[])}
     * does). For a numeric class attribute, the single predicted number is returned.
     *
     * @param exampleFeatures The example's array of feature indices.
     * @param exampleValues The example's array of feature values.
     * @return A feature vector with a single feature containing the prediction for this example.
     **/
    public FeatureVector classify(int[] exampleFeatures, double[] exampleValues) {
        if (!trained) {
            System.err.println(
                    "WekaWrapper: Error - Cannot make a classification with an " + "untrained classifier.");
            new Exception().printStackTrace();
            System.exit(1);
        }

        /*
         * Assuming that the first Attribute in our attributeInfo vector is the class attribute,
         * decide which case we are in
         */
        Attribute classAtt = (Attribute) attributeInfo.elementAt(0);

        if (classAtt.isNominal() || classAtt.isString()) {
            double[] dist = getDistribution(exampleFeatures, exampleValues);
            int best = 0;
            for (int i = 1; i < dist.length; ++i) {
                if (dist[i] > dist[best]) {
                    best = i;
                }
            }

            // The distribution positions correspond to the class attribute's values, not to
            // label lexicon keys, so the lexicon must not be used to decode the winner.
            if (best >= classAtt.numValues()) {
                return new FeatureVector();
            }
            String value = classAtt.value(best);

            return new FeatureVector(new DiscretePrimitiveStringFeature(containingPackage, name, "", value,
                    valueIndexOf(value), (short) allowableValues().length));
        } else if (classAtt.isNumeric()) {
            return new FeatureVector(new RealPrimitiveStringFeature(containingPackage, name, "",
                    getDistribution(exampleFeatures, exampleValues)[0]));
        } else {
            System.err.println("WekaWrapper: Error - illegal class type.");
            new Exception().printStackTrace();
            System.exit(1);
        }

        return new FeatureVector();
    }

    /**
     * Queries the trained WEKA classifier for its output distribution on a single, unlabeled
     * example.
     *
     * @param exampleFeatures The example's array of feature indices.
     * @param exampleValues The example's array of feature values.
     * @return The non-empty array produced by the base classifier's
     *         <code>distributionForInstance</code> method.
     **/
    protected double[] getDistribution(int[] exampleFeatures, double[] exampleValues) {
        if (!trained) {
            System.err.println("WekaWrapper: Error - Cannot make a classification with an untrained classifier.");
            new Exception().printStackTrace();
            System.exit(1);
        }

        // No label indices and no label values: the instance is built with a missing class.
        final int[] noLabels = new int[0];
        final double[] noLabelValues = new double[0];
        Instance query = makeInstance(exampleFeatures, exampleValues, noLabels, noLabelValues);

        /*
         * Numerical class: one element holding the predicted value. Nominal class: one
         * probability per class value. String classes: not established by the original author.
         */
        double[] distribution = null;
        try {
            distribution = baseClassifier.distributionForInstance(query);
        } catch (Exception e) {
            System.err.println("WekaWrapper: Error while computing distribution.");
            e.printStackTrace();
            System.exit(1);
        }

        boolean nothingReturned = distribution.length == 0;
        if (nothingReturned) {
            System.err.println("WekaWrapper: Error - The base classifier returned an empty probability "
                    + "distribution when attempting to classify an example.");
            new Exception().printStackTrace();
            System.exit(1);
        }

        return distribution;
    }

    /**
     * Resets this learner: the base classifier is replaced by a copy of
     * {@link #freshClassifier}, the {@link #instances} collection is replaced by an empty one, and
     * the learner is marked as untrained.
     **/
    public void forget() {
        super.forget();

        try {
            weka.classifiers.Classifier replacement = weka.classifiers.Classifier.makeCopy(freshClassifier);
            baseClassifier = replacement;
        } catch (Exception e) {
            System.err.println("LBJava ERROR: WekaWrapper.forget: Can't copy classifier:");
            e.printStackTrace();
            System.exit(1);
        }

        // Start over with an empty data set whose first attribute is the class.
        Instances emptied = new Instances(name, attributeInfo, 0);
        emptied.setClassIndex(0);
        instances = emptied;

        trained = false;
    }

    /**
     * Builds a WEKA <code>Instance</code> from the array representation of an example.
     *
     * <p>
     * Attribute 0 is treated as the class attribute; the example's features are written to
     * attributes 1, 2, ... in the order in which they appear in <code>exampleFeatures</code>.
     * Each feature's identifier must equal the name of the attribute it is written to.
     *
     * @param exampleFeatures The example's array of feature indices.
     * @param exampleValues The example's array of feature values.
     * @param exampleLabels The example's array of label indices; empty for an unlabeled example.
     * @param labelValues The example's array of label values.
     * @return The new instance, attached to the {@link #instances} data set.
     **/
    private Instance makeInstance(int[] exampleFeatures, double[] exampleValues, int[] exampleLabels,
            double[] labelValues) {
        // Attribute information is a precondition for building anything.
        final int attributeCount = attributeInfo.size();
        if (attributeCount == 0) {
            System.err.println("WekaWrapper: Error - makeInstance was called while attributeInfo was empty.");
            new Exception().printStackTrace();
            System.exit(1);
        }

        // The new instance is tied to our data set so that its attributes can be resolved.
        Instance result = new Instance(attributeCount);
        result.setDataset(instances);

        /*
         * The feature list does not contain the label, whereas the attribute list has the label
         * in slot 0; hence feature k is written to attribute slot k + 1.
         */
        for (int k = 0; k < exampleFeatures.length; ++k) {
            final int slot = k + 1;
            Feature feature = (Feature) lexicon.lookupKey(exampleFeatures[k]);
            Attribute attribute = (Attribute) attributeInfo.elementAt(slot);

            // The attribute's name and the feature's identifier have to agree.
            String attributeName = attribute.name();
            String featureName = feature.getStringIdentifier();
            if (!attributeName.equals(featureName)) {
                System.err.println("WekaWrapper: Error - makeInstance encountered a misaligned attribute-feature pair.");
                System.err.println("  " + attributeName + " and " + featureName + " should have been identical.");
                new Exception().printStackTrace();
                System.exit(1);
            }

            if (feature.isDiscrete()) {
                // Discrete (or conjunctive) feature: a two-valued feature's value is taken from
                // the attribute using the example value as an index, any other from the feature.
                String text;
                if (feature.totalValues() == 2) {
                    text = attribute.value((int) exampleValues[k]);
                } else {
                    text = feature.getStringValue();
                }
                result.setValue(slot, text);
            } else {
                result.setValue(slot, exampleValues[k]);
            }
        }

        // An example without labels is unlabeled; more than one label cannot be represented.
        switch (exampleLabels.length) {
            case 0:
                result.setClassMissing();
                break;

            case 1:
                Feature label = labelLexicon.lookupKey(exampleLabels[0]);
                Attribute classAttribute = (Attribute) attributeInfo.elementAt(0);

                // The label feature must carry the name of the class attribute.
                if (!label.getStringIdentifier().equals(classAttribute.name())) {
                    System.err.println("WekaWrapper: Error - makeInstance found the wrong label name.");
                    new Exception().printStackTrace();
                    System.exit(1);
                }

                if (label.isDiscrete()) {
                    result.setValue(0, label.getStringValue());
                } else {
                    result.setValue(0, labelValues[0]);
                }
                break;

            default:
                System.err.println("WekaWrapper: Error - Weka Instances may only take a single class value, ");
                new Exception().printStackTrace();
                System.exit(1);
                break;
        }

        return result;
    }

    /**
     * Pairs every value of the (nominal or string) class attribute with the corresponding entry
     * of the distribution computed for the given example.
     *
     * @param exampleFeatures The example's array of feature indices.
     * @param exampleValues The example's array of feature values.
     * @return The set of scores, one per class attribute value.
     **/
    public ScoreSet scores(int[] exampleFeatures, double[] exampleValues) {
        double[] distribution = getDistribution(exampleFeatures, exampleValues);

        // By convention of this class, the first attribute is the class attribute.
        Attribute classAtt = (Attribute) attributeInfo.elementAt(0);
        ScoreSet result = new ScoreSet();

        if (!(classAtt.isNominal() || classAtt.isString())) {
            // Scores only make sense for discrete class values.
            String complaint;
            if (classAtt.isNumeric()) {
                complaint = "WekaWrapper: Error - The 'scores' function should not be called when the class "
                        + "attribute is numeric.";
            } else {
                complaint = "WekaWrapper: Error - ScoreSet: Class Types must be either Nominal, String, or Numeric.";
            }
            System.err.println(complaint);
            new Exception().printStackTrace();
            System.exit(1);
            return result;
        }

        Enumeration classValues = classAtt.enumerateValues();
        for (int position = 0; classValues.hasMoreElements(); ++position) {
            if (position >= distribution.length) {
                System.err.println("WekaWrapper: Error - scores found more possible values than probabilities.");
                new Exception().printStackTrace();
                System.exit(1);
            }
            double probability = distribution[position];
            String classValue = (String) classValues.nextElement();
            result.put(classValue, probability);
        }

        return result;
    }

    /**
     * Signals the end of training. Because WEKA classifiers cannot learn online, this is the
     * point at which the collected examples are handed to the WEKA classifier's
     * <code>buildClassifier(Instances)</code> method; it <I>must</I> be called if the classifier
     * is to learn anything. Afterwards the learner is marked as trained and the collection of
     * examples is replaced by an empty one.
     **/
    public void doneLearning() {
        if (trained) {
            System.err.println("WekaWrapper: Error - Cannot call 'doneLearning()' again without first calling "
                    + "'forget()'");
            new Exception().printStackTrace();
            System.exit(1);
        }

        // Debugging aid: instances.toSummaryString() describes the data about to be used.

        try {
            baseClassifier.buildClassifier(instances);
        } catch (Exception e) {
            System.err.println("WekaWrapper: Error - There was a problem building the classifier");
            if (baseClassifier == null) {
                System.out.println("WekaWrapper: baseClassifier was null.");
            }
            e.printStackTrace();
            System.exit(1);
        }

        trained = true;

        // The examples have been consumed; continue with an empty data set.
        Instances emptied = new Instances(name, attributeInfo, 0);
        emptied.setClassIndex(0);
        instances = emptied;
    }

    /**
     * Prints this learner's name, followed by the options of the classifier in use (one per
     * line) and the classifier's string representation.
     *
     * @param out The stream to print to.
     **/
    public void write(PrintStream out) {
        out.print(name + ": ");

        for (String option : baseClassifier.getOptions()) {
            out.println(option);
        }

        out.println(baseClassifier);
    }

    /**
     * Writes the learned function's internal representation in binary form.
     *
     * <p>
     * After the parent's data, the following is written: the <code>trained</code> flag, the
     * number of allowable values followed by the values themselves (a count of 0 stands for "no
     * values"), and finally the base classifier, the fresh classifier, the attribute information
     * and the instances, the latter four via Java object serialization.
     *
     * @param out The output stream.
     **/
    public void write(ExceptionlessOutputStream out) {
        super.write(out);
        out.writeBoolean(trained);

        if (allowableValues == null)
            out.writeInt(0);
        else {
            out.writeInt(allowableValues.length);
            for (int i = 0; i < allowableValues.length; ++i)
                out.writeString(allowableValues[i]);
        }

        ObjectOutputStream oos = null;
        try {
            oos = new ObjectOutputStream(out);
        } catch (Exception e) {
            System.err.println("Can't create object stream for '" + name + "': " + e);
            System.exit(1);
        }

        try {
            oos.writeObject(baseClassifier);
            oos.writeObject(freshClassifier);
            oos.writeObject(attributeInfo);
            oos.writeObject(instances);
            // ObjectOutputStream buffers internally; push everything through to 'out' before
            // the wrapper goes out of scope. It is deliberately not closed, since closing it
            // would also close the caller's stream.
            oos.flush();
        } catch (Exception e) {
            System.err.println("Can't write to object stream for '" + name + "': " + e);
            System.exit(1);
        }
    }

    /**
     * Restores a learner of this run-time type from its binary representation. After the parent's
     * data, the <code>trained</code> flag and the allowable values are read, followed by the base
     * classifier, the fresh classifier, the attribute information and the instances, which are
     * read via Java object deserialization.
     *
     * @param in The input stream.
     **/
    public void read(ExceptionlessInputStream in) {
        super.read(in);
        trained = in.readBoolean();

        final int valueCount = in.readInt();
        String[] restoredValues = new String[valueCount];
        allowableValues = restoredValues;
        for (int k = 0; k < restoredValues.length; ++k) {
            restoredValues[k] = in.readString();
        }

        // NOTE(review): native Java deserialization; presumably only trusted model files are
        // read here - confirm against callers.
        ObjectInputStream objects = null;
        try {
            objects = new ObjectInputStream(in);
        } catch (Exception e) {
            System.err.println("Can't create object stream for '" + name + "': " + e);
            System.exit(1);
        }

        try {
            // The order mirrors the order in which the objects were written.
            Object base = objects.readObject();
            baseClassifier = (weka.classifiers.Classifier) base;
            Object fresh = objects.readObject();
            freshClassifier = (weka.classifiers.Classifier) fresh;
            Object attributes = objects.readObject();
            attributeInfo = (FastVector) attributes;
            Object examples = objects.readObject();
            instances = (Instances) examples;
        } catch (Exception e) {
            System.err.println("Can't read from object stream for '" + name + "': " + e);
            System.exit(1);
        }
    }

    /**
     * Simply a container for all of {@link WekaWrapper}'s configurable parameters. Using instances
     * of this class should make code more readable and constructors less complicated.
     *
     * <p>
     * Parameters objects that are given default values receive their own copy of
     * {@link WekaWrapper#defaultBaseClassifier}. The static default instance itself is never
     * handed out, because a learner trains the classifier found in its parameters, and learners
     * sharing one classifier object would overwrite each other's models.
     *
     * @author Nick Rizzolo
     **/
    public static class Parameters extends Learner.Parameters {
        /**
         * Stores the instance of the WEKA classifier which we are training; default is a copy of
         * {@link WekaWrapper#defaultBaseClassifier}.
         **/
        public weka.classifiers.Classifier baseClassifier;
        /**
         * A string encoding of the return types of each of the feature extractors in use; default
         * {@link WekaWrapper#defaultAttributeString}.
         **/
        public String attributeString;

        /** Sets all the default values. */
        public Parameters() {
            baseClassifier = copyOfDefaultClassifier();
            attributeString = defaultAttributeString;
        }

        /**
         * Sets the parameters from the parent's parameters object, giving defaults to all
         * parameters declared in this object.
         *
         * @param p The parent's parameters.
         **/
        public Parameters(Learner.Parameters p) {
            super(p);
            baseClassifier = copyOfDefaultClassifier();
            attributeString = defaultAttributeString;
        }

        /**
         * Copy constructor. The base classifier is copied by reference.
         *
         * @param p The parameters object to copy.
         **/
        public Parameters(Parameters p) {
            super(p);
            baseClassifier = p.baseClassifier;
            attributeString = p.attributeString;
        }

        /**
         * Calls the appropriate <code>Learner.setParameters(Parameters)</code> method for this
         * <code>Parameters</code> object.
         *
         * @param l The learner whose parameters will be set; must be a {@link WekaWrapper}.
         **/
        public void setParameters(Learner l) {
            ((WekaWrapper) l).setParameters(this);
        }

        /**
         * Creates a string representation of these parameters in which only those parameters that
         * differ from their default values are mentioned. The base classifier is not part of this
         * representation.
         *
         * @return The string representation described above.
         **/
        public String nonDefaultString() {
            String result = super.nonDefaultString();

            if (!attributeString.equals(WekaWrapper.defaultAttributeString))
                result += ", attributeString = \"" + attributeString + "\"";

            if (result.startsWith(", "))
                result = result.substring(2);
            return result;
        }

        /**
         * Produces an independent copy of {@link WekaWrapper#defaultBaseClassifier}. A failure to
         * copy is reported and terminates the JVM, in line with the error handling used
         * throughout {@link WekaWrapper}.
         *
         * @return A copy of the default base classifier.
         **/
        private static weka.classifiers.Classifier copyOfDefaultClassifier() {
            try {
                return weka.classifiers.Classifier.makeCopy(defaultBaseClassifier);
            } catch (Exception e) {
                System.err.println("LBJava ERROR: WekaWrapper.Parameters: Can't copy default classifier:");
                e.printStackTrace();
                System.exit(1);
                return null; // never reached; required by the compiler after System.exit
            }
        }
    }
}