myclassifier.myC45Pack.MyClassifierTree.java Source code

Java tutorial

Introduction

Here is the source code for myclassifier.myC45Pack.MyClassifierTree.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package myclassifier.myC45Pack;

import weka.classifiers.trees.j48.ClassifierSplitModel;
import weka.classifiers.trees.j48.Distribution;
import weka.classifiers.trees.j48.ModelSelection;
import weka.core.Capabilities;
import weka.core.CapabilitiesHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;

/**
 *
 * @author Fahmi
 */
public class MyClassifierTree implements CapabilitiesHandler {
    /**
     * Serialization version identifier.
     * NOTE(review): the class header visible in this file only implements
     * CapabilitiesHandler, not java.io.Serializable - confirm whether
     * serialization is actually intended.
     */
    static final long serialVersionUID = -8722249377542734193L;

    /** The model selection method; used to choose the local model at every node. */
    protected ModelSelection toSelectModel;

    /** Local model (split or no-split) chosen for this node. */
    protected ClassifierSplitModel localModel;

    /** References to the child subtrees; null until a split is built. */
    protected MyClassifierTree[] childTree;

    /** True if this node is a leaf. */
    protected boolean isLeaf;

    /** True if this node is a leaf whose training data has zero total weight. */
    protected boolean isEmpty;

    /** The training instances; only stored when requested, or set by cleanup. */
    protected Instances train;

    /**
     * Distribution of the hold-out (pruning) data over the local model;
     * null when the tree is built without a hold-out set.
     */
    protected Distribution test;

    /** The id for the node, as assigned by assignIDs. */
    protected int id;

    /** 
     * Counter used to obtain a unique ID when outputting the tree (hash code
     * isn't guaranteed unique). Shared by all trees because it is static.
     */
    private static long PRINTED_NODES = 0;

    /**
     * Hands out the current value of the shared node counter and then
     * advances the counter by one.
     *
     * @return the ID that was current before the counter was advanced
     */
    protected static long nextID() {
        long issued = PRINTED_NODES;
        PRINTED_NODES = issued + 1;
        return issued;
    }

    /**
     * Sets the shared node ID counter back to zero (e.g. between repeated,
     * separate print runs) so numbering starts over.
     */
    protected static void resetID() {
        PRINTED_NODES = 0L;
    }

    /**
     * Creates an unbuilt tree node.
     *
     * @param toSelectLocModel the model selection method used to pick the
     *        local model when the tree is built
     */
    public MyClassifierTree(ModelSelection toSelectLocModel) {
        this.toSelectModel = toSelectLocModel;
    }

    /**
     * Reports what kind of data this classifier tree accepts. Every
     * capability is enabled.
     *
     * @return a capabilities object owned by this tree with all
     *         capabilities switched on
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities capabilities = new Capabilities(this);
        capabilities.enableAll();
        return capabilities;
    }

    /**
     * Builds the classifier tree from the given data.
     *
     * <p>The data is first checked against {@link #getCapabilities()}, then
     * copied so the caller's dataset is not modified, and instances with a
     * missing class value are removed from the copy before the tree is grown.
     * After growing, every node is given an empty, header-only copy of the
     * dataset so that the tree can be printed.
     *
     * @param data the data to build the tree from
     * @throws Exception if the data cannot be handled or building fails
     */
    public void buildClassifier(Instances data) throws Exception {

        // can classifier tree handle the data?
        getCapabilities().testWithFail(data);

        // work on a copy and remove instances with missing class
        data = new Instances(data);
        data.deleteWithMissingClass();

        buildTree(data, false);

        // buildTree(data, false) leaves 'train' null on every node, yet
        // toString()/dumpTree pass 'train' to the local model to look up
        // attribute and class names (Weka split models such as C45Split
        // dereference it), so printing would always end in the
        // "Can't print classification tree." fallback. Attaching a
        // header-only dataset keeps memory use low while making the tree
        // printable.
        cleanup(new Instances(data, 0));
    }

    /**
     * Grows the tree below this node from the given data.
     *
     * <p>A local model is selected for the data. If it produces more than
     * one subset, the data is split and one child tree is built per subset;
     * otherwise this node becomes a leaf, and is flagged empty when the
     * data's total weight equals zero.
     *
     * @param data the data for which the tree structure is to be generated
     * @param keepData whether this node should keep a reference to its
     *        training data
     * @throws Exception if something goes wrong
     */
    public void buildTree(Instances data, boolean keepData) throws Exception {
        if (keepData) {
            train = data;
        }

        // Reset node state before (re)building.
        test = null;
        isLeaf = false;
        isEmpty = false;
        childTree = null;

        localModel = toSelectModel.selectModel(data);

        if (localModel.numSubsets() <= 1) {
            // No useful split: this node is a leaf.
            isLeaf = true;
            isEmpty = Utils.eq(data.sumOfWeights(), 0);
            return;
        }

        Instances[] partitions = localModel.split(data);
        // Drop the local reference so the full set is not pinned by this
        // frame while the children are built recursively.
        data = null;
        childTree = new MyClassifierTree[localModel.numSubsets()];
        int branch = 0;
        while (branch < childTree.length) {
            childTree[branch] = getNewTree(partitions[branch]);
            partitions[branch] = null; // release the subset once consumed
            branch++;
        }
    }

    /**
     * Grows the tree below this node using a separate hold-out set.
     *
     * <p>A local model is selected from the training and hold-out data, and
     * the hold-out data's distribution over that model is stored in
     * {@code this.test}. If the model produces more than one subset, both
     * datasets are split and one child tree is built per subset pair;
     * otherwise this node becomes a leaf, and is flagged empty when the
     * training data's total weight equals zero.
     *
     * @param train the data for which the tree structure is to be generated
     * @param test the hold-out data for potential pruning
     * @param keepData whether this node should keep a reference to its
     *        training data
     * @throws Exception if something goes wrong
     */
    public void buildTree(Instances train, Instances test, boolean keepData) throws Exception {
        if (keepData) {
            this.train = train;
        }

        // Reset node state before (re)building.
        isLeaf = false;
        isEmpty = false;
        childTree = null;

        localModel = toSelectModel.selectModel(train, test);
        this.test = new Distribution(test, localModel);

        if (localModel.numSubsets() <= 1) {
            // No split available: this node is a leaf.
            isLeaf = true;
            isEmpty = Utils.eq(train.sumOfWeights(), 0);
            return;
        }

        Instances[] trainParts = localModel.split(train);
        Instances[] testParts = localModel.split(test);
        // Drop the local references so the full sets are not pinned by this
        // frame while the children are built recursively.
        train = test = null;
        childTree = new MyClassifierTree[localModel.numSubsets()];
        int branch = 0;
        while (branch < childTree.length) {
            childTree[branch] = getNewTree(trainParts[branch], testParts[branch]);
            trainParts[branch] = null; // release the subsets once consumed
            testParts[branch] = null;
            branch++;
        }
    }

    /**
     * Creates a subtree that shares this tree's model selection method and
     * builds it from the given data, without keeping the data.
     *
     * @param data the training data for the subtree
     * @return the built subtree
     * @throws Exception if something goes wrong
     */
    protected MyClassifierTree getNewTree(Instances data) throws Exception {
        MyClassifierTree subtree = new MyClassifierTree(toSelectModel);
        subtree.buildTree(data, false);
        return subtree;
    }

    /**
     * Creates a subtree that shares this tree's model selection method and
     * builds it from the given training and hold-out data, without keeping
     * the data.
     *
     * @param train the training data for the subtree
     * @param test the pruning (hold-out) data for the subtree
     * @return the built subtree
     * @throws Exception if something goes wrong
     */
    protected MyClassifierTree getNewTree(Instances train, Instances test) throws Exception {
        MyClassifierTree subtree = new MyClassifierTree(toSelectModel);
        subtree.buildTree(train, test, false);
        return subtree;
    }

    /**
     * Classifies an instance by picking the class with the highest
     * probability. On ties the class with the lowest index wins, because a
     * later class only replaces the current best when it is strictly
     * greater according to {@code Utils.gr}.
     *
     * @param instance the instance to classify
     * @return the index of the predicted class, as a double
     * @throws Exception if something goes wrong
     */
    public double classifyInstance(Instance instance) throws Exception {
        int bestClass = 0;
        double bestProb = -1;

        for (int classIdx = 0; classIdx < instance.numClasses(); classIdx++) {
            double candidate = getProbs(classIdx, instance, 1);
            if (Utils.gr(candidate, bestProb)) {
                bestProb = candidate;
                bestClass = classIdx;
            }
        }

        return bestClass;
    }

    /**
     * Frees memory by replacing the stored training data of this node and
     * all of its descendants with the given dataset and dropping the stored
     * hold-out distribution.
     *
     * @param justHeaderInfo the dataset to store in place of the training
     *        data (intended to carry header information only)
     */
    public final void cleanup(Instances justHeaderInfo) {
        train = justHeaderInfo;
        test = null;

        if (isLeaf) {
            return;
        }
        for (MyClassifierTree subtree : childTree) {
            subtree.cleanup(justHeaderInfo);
        }
    }

    /**
     * Computes the class probabilities for an instance, one entry per class.
     *
     * @param instance the instance to get the distribution for
     * @param useLaplace whether to use the Laplace-corrected probabilities
     * @return an array with one probability per class
     * @throws Exception if something goes wrong
     */
    public final double[] distributionForInstance(Instance instance, boolean useLaplace) throws Exception {
        double[] distribution = new double[instance.numClasses()];

        for (int classIdx = 0; classIdx < distribution.length; classIdx++) {
            distribution[classIdx] = useLaplace
                    ? getProbsLaplace(classIdx, instance, 1)
                    : getProbs(classIdx, instance, 1);
        }
        return distribution;
    }

    /**
     * Assigns a unique id to every node in the tree. This node takes
     * {@code lastID + 1}; its children are then numbered in order, each
     * continuing from the last id used by the previous subtree.
     *
     * @param lastID the last ID that was assigned
     * @return the last ID used within this subtree
     */
    public int assignIDs(int lastID) {
        int latest = lastID + 1;
        id = latest;

        if (childTree == null) {
            return latest;
        }
        for (MyClassifierTree subtree : childTree) {
            latest = subtree.assignIDs(latest);
        }
        return latest;
    }

    /**
     * Returns graph describing the tree.
     *
     * @throws Exception if something goes wrong
     * @return the tree as graph
     */
    /*public String graph() throws Exception {
        
      StringBuffer text = new StringBuffer();
        
      assignIDs(-1);
      text.append("digraph J48Tree {\n");
      if (isLeaf) {
    text.append("N" + id 
                + " [label=\"" + 
                Utils.quote(localModel.dumpLabel(0,train)) + "\" " + 
                "shape=box style=filled ");
    if (train != null && train.numInstances() > 0) {
      text.append("data =\n" + train + "\n");
      text.append(",\n");
        
    }
    text.append("]\n");
      }else {
    text.append("N" + id 
                + " [label=\"" + 
                Utils.quote(localModel.leftSide(train)) + "\" ");
    if (train != null && train.numInstances() > 0) {
      text.append("data =\n" + train + "\n");
      text.append(",\n");
       }
    text.append("]\n");
    graphTree(text);
      }
        
      return text.toString() +"}\n";
    }*/

    /**
     * Returns tree in prefix order.
     *
     * @throws Exception if something goes wrong
     * @return the prefix order
     */
    /*public String prefix() throws Exception {
        
      StringBuffer text;
        
      text = new StringBuffer();
      if (isLeaf) {
    text.append("["+localModel.dumpLabel(0,train)+"]");
      }else {
    prefixTree(text);
      }
        
      return text.toString();
    }*/

    /**
     * Returns source code for the tree as an if-then statement. The 
     * class is assigned to variable "p", and assumes the tested 
     * instance is named "i". The results are returned as two string buffers: 
     * a section of code for assignment of the class, and a section of
     * code containing support code (eg: other support methods).
     *
     * @param className the class name that this static classifier has
     * @return an array containing two string buffers, the first string containing
     * assignment code, and the second containing source for support code.
     * @throws Exception if something goes wrong
     */
    /*public StringBuffer [] toSource(String className) throws Exception {
        
      StringBuffer [] result = new StringBuffer [2];
      if (isLeaf) {
    result[0] = new StringBuffer("    p = " 
      + localModel.distribution().maxClass(0) + ";\n");
    result[1] = new StringBuffer("");
      } else {
    StringBuffer text = new StringBuffer();
    StringBuffer atEnd = new StringBuffer();
        
    long printID = MyClassifierTree.nextID();
        
    text.append("  static double N") 
      .append(Integer.toHexString(localModel.hashCode()) + printID)
      .append("(Object []i) {\n")
      .append("    double p = Double.NaN;\n");
        
    text.append("    if (")
      .append(localModel.sourceExpression(-1, train))
      .append(") {\n");
    text.append("      p = ")
      .append(localModel.distribution().maxClass(0))
      .append(";\n");
    text.append("    } ");
    for (int i = 0; i < childTree.length; i++) {
      text.append("else if (" + localModel.sourceExpression(i, train) 
                  + ") {\n");
      if (childTree[i].isLeaf) {
        text.append("      p = " 
                    + localModel.distribution().maxClass(i) + ";\n");
      } else {
        StringBuffer [] sub = childTree[i].toSource(className);
        text.append(sub[0]);
        atEnd.append(sub[1]);
      }
      text.append("    } ");
      if (i == childTree.length - 1) {
        text.append('\n');
      }
    }
        
    text.append("    return p;\n  }\n");
        
    result[0] = new StringBuffer("    p = " + className + ".N");
    result[0].append(Integer.toHexString(localModel.hashCode()) +  printID)
      .append("(i);\n");
    result[1] = text.append(atEnd);
      }
      return result;
    }*/

    /**
     * Counts the leaves in the subtree rooted at this node.
     *
     * @return 1 if this node is a leaf, otherwise the sum of the leaf
     *         counts of all children
     */
    public int numLeaves() {
        if (isLeaf) {
            return 1;
        }

        int leaves = 0;
        for (MyClassifierTree subtree : childTree) {
            leaves += subtree.numLeaves();
        }
        return leaves;
    }

    /**
     * Counts the nodes in the subtree rooted at this node, including this
     * node itself.
     *
     * @return the number of nodes
     */
    public int numNodes() {
        int total = 1; // this node

        if (isLeaf) {
            return total;
        }
        for (MyClassifierTree subtree : childTree) {
            total += subtree.numNodes();
        }
        return total;
    }

    /**
     * Renders the tree structure as text, followed by the number of leaves
     * and the size of the tree. If anything fails while rendering, a fixed
     * fallback message is returned instead.
     *
     * @return the textual tree structure, or the fallback message
     */
    @Override
    public String toString() {
        try {
            StringBuffer out = new StringBuffer();

            if (!isLeaf) {
                dumpTree(0, out);
            } else {
                out.append(": ").append(localModel.dumpLabel(0, train));
            }

            String leafLine = "\n\nNumber of Leaves  : \t" + numLeaves() + "\n";
            out.append(leafLine);
            String sizeLine = "\nSize of the tree : \t" + numNodes() + "\n";
            out.append(sizeLine);

            return out.toString();
        } catch (Exception e) {
            return "Can't print classification tree.";
        }
    }

    /**
     * Helper for printing the tree structure. Writes one line per branch of
     * this node, indented by one {@code "|   "} per depth level, and
     * recurses into non-leaf children.
     *
     * @param depth the current depth
     * @param text the buffer the structure is appended to
     * @throws Exception if something goes wrong
     */
    private void dumpTree(int depth, StringBuffer text) throws Exception {
        // The indentation is the same for every branch of this node.
        StringBuilder indentBuilder = new StringBuilder();
        for (int level = 0; level < depth; level++) {
            indentBuilder.append("|   ");
        }
        String indent = indentBuilder.toString();

        for (int branch = 0; branch < childTree.length; branch++) {
            text.append("\n");
            text.append(indent);
            text.append(localModel.leftSide(train));
            text.append(localModel.rightSide(branch, train));

            MyClassifierTree subtree = childTree[branch];
            if (!subtree.isLeaf) {
                subtree.dumpTree(depth + 1, text);
            } else {
                text.append(": ");
                text.append(localModel.dumpLabel(branch, train));
            }
        }
    }

    /**
     * Help method for printing tree structure as a graph.
     *
     * @param text for outputting the tree
     * @throws Exception if something goes wrong
     */
    /*private void graphTree(StringBuffer text) throws Exception {
        
      for (int i = 0; i < childTree.length; i++) {
    text.append("N" + id  
                + "->" + 
                "N" + childTree[i].id +
                " [label=\"" + Utils.quote(localModel.rightSide(i,train).trim()) + 
                "\"]\n");
    if (childTree[i].isLeaf) {
      text.append("N" + childTree[i].id +
                  " [label=\""+ Utils.quote(localModel.dumpLabel(i,train))+"\" "+ 
                  "shape=box style=filled ");
      if (train != null && train.numInstances() > 0) {
        text.append("data =\n" + childTree[i].train + "\n");
        text.append(",\n");
      }
      text.append("]\n");
    } else {
      text.append("N" + childTree[i].id +
                  " [label=\""+ Utils.quote(childTree[i].localModel.leftSide(train))+ 
                  "\" ");
      if (train != null && train.numInstances() > 0) {
        text.append("data =\n" + childTree[i].train + "\n");
        text.append(",\n");
      }
      text.append("]\n");
      childTree[i].graphTree(text);
    }
      }
    }*/

    /**
     * Prints the tree in prefix form
     * 
     * @param text the buffer to output the prefix form to
     * @throws Exception if something goes wrong
     */
    /*private void prefixTree(StringBuffer text) throws Exception {
        
      text.append("[");
      text.append(localModel.leftSide(train)+":");
      for (int i = 0; i < childTree.length; i++) {
    if (i > 0) {
      text.append(",\n");
    }
    text.append(localModel.rightSide(i, train));
      }
      for (int i = 0; i < childTree.length; i++) {
    if (childTree[i].isLeaf) {
      text.append("[");
      text.append(localModel.dumpLabel(i,train));
      text.append("]");
    } else {
      childTree[i].prefixTree(text);
    }
      }
      text.append("]");
    }*/

    /**
     * Helper for computing the Laplace-corrected probability of one class
     * for a given instance.
     *
     * <p>At a leaf the local model's probability is scaled by the weight.
     * When the local model cannot assign the instance to a single subset
     * (subset index -1), the result is the sum over all non-empty children,
     * each weighted by the local model's weight for that branch. Otherwise
     * the instance follows its branch, falling back to the local model's
     * estimate for that branch when the child is empty.
     *
     * @param classIndex the class index
     * @param instance the instance to compute the probabilities for
     * @param weight the weight to use
     * @return the weighted Laplace probability
     * @throws Exception if something goes wrong
     */
    private double getProbsLaplace(int classIndex, Instance instance, double weight) throws Exception {
        if (isLeaf) {
            return weight * localModel.classProbLaplace(classIndex, instance, -1);
        }

        int branch = localModel.whichSubset(instance);

        if (branch != -1) {
            MyClassifierTree target = childTree[branch];
            if (target.isEmpty) {
                return weight * localModel.classProbLaplace(classIndex, instance, branch);
            }
            return target.getProbsLaplace(classIndex, instance, weight);
        }

        // Instance cannot be routed to one branch: combine all non-empty children.
        double[] branchWeights = localModel.weights(instance);
        double total = 0;
        for (int k = 0; k < childTree.length; k++) {
            MyClassifierTree subtree = childTree[k];
            if (subtree.isEmpty) {
                continue;
            }
            total += subtree.getProbsLaplace(classIndex, instance, branchWeights[k] * weight);
        }
        return total;
    }

    /**
     * Helper for computing the probability of one class for a given
     * instance.
     *
     * <p>At a leaf the local model's probability is scaled by the weight.
     * When the local model cannot assign the instance to a single subset
     * (subset index -1), the result is the sum over all non-empty children,
     * each weighted by the local model's weight for that branch. Otherwise
     * the instance follows its branch, falling back to the local model's
     * estimate for that branch when the child is empty.
     *
     * @param classIndex the class index
     * @param instance the instance to compute the probabilities for
     * @param weight the weight to use
     * @return the weighted probability
     * @throws Exception if something goes wrong
     */
    private double getProbs(int classIndex, Instance instance, double weight) throws Exception {
        if (isLeaf) {
            return weight * localModel.classProb(classIndex, instance, -1);
        }

        int branch = localModel.whichSubset(instance);

        if (branch != -1) {
            MyClassifierTree target = childTree[branch];
            if (target.isEmpty) {
                return weight * localModel.classProb(classIndex, instance, branch);
            }
            return target.getProbs(classIndex, instance, weight);
        }

        // Instance cannot be routed to one branch: combine all non-empty children.
        double[] branchWeights = localModel.weights(instance);
        double total = 0;
        for (int k = 0; k < childTree.length; k++) {
            MyClassifierTree subtree = childTree[k];
            if (subtree.isEmpty) {
                continue;
            }
            total += subtree.getProbs(classIndex, instance, branchWeights[k] * weight);
        }
        return total;
    }

    /**
     * Method just exists to make program easier to read.
     */
    /*private ClassifierSplitModel localModel() {
    return (ClassifierSplitModel)localModel;
    }*/

    /**
     * Readability helper that returns the child subtree at the given
     * position.
     *
     * @param index position of the child in the child array
     * @return the child subtree at that position
     */
    private MyClassifierTree child(int index) {
        MyClassifierTree subtree = childTree[index];
        return subtree;
    }
}