j48.C45Split.java Source code

Java tutorial

Introduction

Here is the source code for j48.C45Split.java

Source

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    C45Split.java
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */

package j48;

import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;

import java.util.Enumeration;

/**
 * Class implementing a C4.5-type split on an attribute.
 * 
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.13 $
 */
public class C45Split extends ClassifierSplitModel {

    /** for serialization */
    private static final long serialVersionUID = 3064079330067903161L;

    /**
     * Standard deviation of the split attribute above which the two scaling
     * factors ({@link #rrrrr} and {@link #lllll}) swap roles.
     */
    private static final double STD_DEV_THRESHOLD = 200;

    /** Divisor applied to log10(stdDev) to obtain the larger-magnitude factor. */
    private static final double LARGE_FACTOR_DIVISOR = 1.2;

    /** Divisor applied to log10(stdDev) to obtain the smaller-magnitude factor. */
    private static final double SMALL_FACTOR_DIVISOR = 8;

    /** Desired number of branches. */
    private int m_complexityIndex;

    /** Attribute to split on. */
    private int m_attIndex;

    /** Minimum number of objects in a split. */
    private int m_minNoObj;

    /** Value of split point. */
    private double m_splitPoint;

    /** InfoGain of split. */
    private double m_infoGain;

    /** GainRatio of split. */
    private double m_gainRatio;

    /** The sum of the weights of the instances. */
    private double m_sumOfWeights;

    /**
     * Number of split points. For a nominal attribute this is the number of
     * attribute values; for a numeric attribute it counts the candidate split
     * points that left enough instances in both subsets.
     */
    private int m_index;

    /**
     * Scaling factor handed to {@code InfoGainSplitCrit.splitCritValue1} while
     * numeric split candidates are evaluated. It is derived from the standard
     * deviation of the split attribute in {@link #buildClassifier(Instances)}.
     * The cryptic name is kept on purpose: this is a non-transient field of a
     * serializable class, so renaming it would change the serialized form.
     */
    private double rrrrr = 0.0;

    /**
     * Scaling factor handed to {@code GainRatioSplitCrit.splitCritValue1} for
     * the best numeric split. Derived together with {@link #rrrrr}; the name
     * is kept for the same serialization reason.
     */
    private double lllll = 0.0;

    /** Static reference to splitting criterion. */
    private static InfoGainSplitCrit infoGainCrit = new InfoGainSplitCrit();

    /** Static reference to splitting criterion. */
    private static GainRatioSplitCrit gainRatioCrit = new GainRatioSplitCrit();

    /**
     * Initializes the split model.
     * 
     * @param attIndex
     *            index of the attribute to split on
     * @param minNoObj
     *            minimum number of objects in a split
     * @param sumOfWeights
     *            sum of the weights of the instances
     */
    public C45Split(int attIndex, int minNoObj, double sumOfWeights) {

        // Get index of attribute to split on.
        m_attIndex = attIndex;

        // Set minimum number of objects.
        m_minNoObj = minNoObj;

        // Set the sum of the weights
        m_sumOfWeights = sumOfWeights;
    }

    /**
     * Creates a C4.5-type split on the given data. Assumes that none of the
     * class values is missing.
     * 
     * For a non-nominal attribute the data is sorted in place on the split
     * attribute and the two scaling factors are recomputed from the
     * attribute's standard deviation before the split is searched.
     * 
     * @param trainInstances
     *            the training data; it is sorted in place when the split
     *            attribute is not nominal
     * @throws Exception
     *             if something goes wrong
     */
    public void buildClassifier(Instances trainInstances) throws Exception {

        // Initialize the remaining instance variables.
        m_numSubsets = 0;
        m_splitPoint = Double.MAX_VALUE;
        m_infoGain = 0;
        m_gainRatio = 0;

        // Different treatment for enumerated and numeric
        // attributes.
        if (trainInstances.attribute(m_attIndex).isNominal()) {
            m_complexityIndex = trainInstances.attribute(m_attIndex).numValues();
            m_index = m_complexityIndex;
            handleEnumeratedAttribute(trainInstances);
            return;
        }

        m_complexityIndex = 2;
        m_index = 0;
        trainInstances.sort(trainInstances.attribute(m_attIndex));
        setScalingFactors(trainInstances.attributeStats(m_attIndex).numericStats.stdDev);
        handleNumericAttribute(trainInstances);
    }

    /**
     * Derives the two scaling factors from the standard deviation of the
     * split attribute. Both factors are log10(stdDev) divided by a constant;
     * which field receives the larger-magnitude one depends on whether the
     * standard deviation exceeds {@link #STD_DEV_THRESHOLD}.
     * 
     * NOTE(review): for stdDev <= 1 the logarithm is zero or negative (and it
     * is not finite for 0, NaN or infinity), so the factors are as well.
     * Whether the splitting criteria cope with such factors cannot be seen
     * from this class - confirm against splitCritValue1.
     * 
     * @param stdDev
     *            standard deviation of the split attribute
     */
    private void setScalingFactors(double stdDev) {

        double logStdDev = Math.log10(stdDev);
        double largeFactor = logStdDev / LARGE_FACTOR_DIVISOR;
        double smallFactor = logStdDev / SMALL_FACTOR_DIVISOR;

        if (stdDev > STD_DEV_THRESHOLD) {
            rrrrr = largeFactor;
            lllll = smallFactor;
        } else {
            lllll = largeFactor;
            rrrrr = smallFactor;
        }
    }

    /**
     * Returns index of attribute for which split was generated.
     * 
     * @return the index of the split attribute
     */
    public final int attIndex() {

        return m_attIndex;
    }

    /**
     * Gets class probability for instance.
     * 
     * If {@code theSubset} is negative the probability is the sum of the
     * per-subset class probabilities weighted by {@link #weights(Instance)},
     * or the overall class probability when no weights are available.
     * Otherwise the probability of the class within the given subset is
     * returned, falling back to the overall class probability when that
     * subset is empty.
     * 
     * @param classIndex
     *            index of the class
     * @param instance
     *            the instance to get the probability for
     * @param theSubset
     *            index of the subset, or a negative value if the instance is
     *            not assigned to a single subset
     * @return the class probability
     * @throws Exception
     *             if something goes wrong
     */
    public final double classProb(int classIndex, Instance instance, int theSubset) throws Exception {

        if (theSubset > -1) {
            if (Utils.gr(m_distribution.perBag(theSubset), 0)) {
                return m_distribution.prob(classIndex, theSubset);
            }
            return m_distribution.prob(classIndex);
        }

        double[] weights = weights(instance);
        if (weights == null) {
            return m_distribution.prob(classIndex);
        }
        double prob = 0;
        for (int i = 0; i < weights.length; i++) {
            prob += weights[i] * m_distribution.prob(classIndex, i);
        }
        return prob;
    }

    /**
     * Returns coding cost for split (used in rule learner).
     * 
     * @return the base-2 logarithm of the number of split points
     */
    public final double codingCost() {

        return Utils.log2(m_index);
    }

    /**
     * Returns (C4.5-type) gain ratio for the generated split.
     * 
     * @return the gain ratio
     */
    public final double gainRatio() {
        return m_gainRatio;
    }

    /**
     * Creates split on enumerated attribute.
     * 
     * @param trainInstances
     *            the training data
     * @throws Exception
     *             if something goes wrong
     */
    private void handleEnumeratedAttribute(Instances trainInstances) throws Exception {

        m_distribution = new Distribution(m_complexityIndex, trainInstances.numClasses());

        // Only Instances with known values are relevant.
        for (int i = 0; i < trainInstances.numInstances(); i++) {
            Instance instance = trainInstances.instance(i);
            if (!instance.isMissing(m_attIndex)) {
                m_distribution.add((int) instance.value(m_attIndex), instance);
            }
        }

        // Check if minimum number of Instances in at least two
        // subsets.
        if (m_distribution.check(m_minNoObj)) {
            m_numSubsets = m_complexityIndex;
            m_infoGain = infoGainCrit.splitCritValue(m_distribution, m_sumOfWeights);
            m_gainRatio = gainRatioCrit.splitCritValue(m_distribution, m_sumOfWeights, m_infoGain);
        }
    }

    /**
     * Creates split on numeric attribute.
     * 
     * The method walks the instances in their current order and stops
     * counting at the first one whose split attribute is missing, so it
     * relies on the caller having sorted the data on that attribute with the
     * missing values last.
     * 
     * @param trainInstances
     *            the training data, sorted on the split attribute
     * @throws Exception
     *             if something goes wrong
     */
    private void handleNumericAttribute(Instances trainInstances) throws Exception {

        int next = 1;
        int last = 0;
        int splitIndex = -1;
        double currentInfoGain;
        double defaultEnt;
        double minSplit;

        // Current attribute is a numeric attribute.
        m_distribution = new Distribution(2, trainInstances.numClasses());

        // Only Instances with known values are relevant. They all start out
        // in the second bag; firstMiss ends up as the number of them.
        int firstMiss = 0;
        while (firstMiss < trainInstances.numInstances()
                && !trainInstances.instance(firstMiss).isMissing(m_attIndex)) {
            m_distribution.add(1, trainInstances.instance(firstMiss));
            firstMiss++;
        }

        // Compute minimum number of Instances required in each
        // subset.
        minSplit = 0.1 * (m_distribution.total()) / ((double) trainInstances.numClasses());
        if (Utils.smOrEq(minSplit, m_minNoObj)) {
            minSplit = m_minNoObj;
        } else if (Utils.gr(minSplit, 25)) {
            minSplit = 25;
        }

        // Enough Instances with known values?
        if (Utils.sm((double) firstMiss, 2 * minSplit)) {
            return;
        }

        // Compute values of criteria for all possible split
        // indices.
        defaultEnt = infoGainCrit.oldEnt(m_distribution);
        while (next < firstMiss) {

            // A split is only possible between two neighbours whose values
            // differ by more than 1e-5.
            if (trainInstances.instance(next - 1).value(m_attIndex) + 1e-5 < trainInstances.instance(next)
                    .value(m_attIndex)) {

                // Move class values for all Instances up to next
                // possible split point.
                m_distribution.shiftRange(1, 0, trainInstances, last, next);

                // Check if enough Instances in each subset and compute
                // values for criteria.
                if (Utils.grOrEq(m_distribution.perBag(0), minSplit)
                        && Utils.grOrEq(m_distribution.perBag(1), minSplit)) {
                    currentInfoGain = infoGainCrit.splitCritValue1(m_distribution, m_sumOfWeights, defaultEnt,
                            rrrrr);
                    if (Utils.gr(currentInfoGain, m_infoGain)) {
                        m_infoGain = currentInfoGain;
                        splitIndex = next - 1;
                    }
                    m_index++;
                }
                last = next;
            }
            next++;
        }

        // Was there any useful split?
        if (m_index == 0) {
            return;
        }

        // Compute modified information gain for best split.
        m_infoGain = m_infoGain - (Utils.log2(m_index) / m_sumOfWeights);
        if (Utils.smOrEq(m_infoGain, 0)) {
            return;
        }

        // Set instance variables' values to values for
        // best split.
        m_numSubsets = 2;
        m_splitPoint = (trainInstances.instance(splitIndex + 1).value(m_attIndex)
                + trainInstances.instance(splitIndex).value(m_attIndex)) / 2;

        // In case we have a numerical precision problem we need to choose the
        // smaller value
        if (m_splitPoint == trainInstances.instance(splitIndex + 1).value(m_attIndex)) {
            m_splitPoint = trainInstances.instance(splitIndex).value(m_attIndex);
        }

        // Restore distribution for best split.
        m_distribution = new Distribution(2, trainInstances.numClasses());
        m_distribution.addRange(0, trainInstances, 0, splitIndex + 1);
        m_distribution.addRange(1, trainInstances, splitIndex + 1, firstMiss);

        // Compute modified gain ratio for best split.
        m_gainRatio = gainRatioCrit.splitCritValue1(m_distribution, m_sumOfWeights, m_infoGain, lllll);
    }

    /**
     * Returns (C4.5-type) information gain for the generated split.
     * 
     * @return the information gain
     */
    public final double infoGain() {

        return m_infoGain;
    }

    /**
     * Prints left side of condition.
     * 
     * @param data
     *            training set.
     * @return the name of the split attribute
     */
    public final String leftSide(Instances data) {

        return data.attribute(m_attIndex).name();
    }

    /**
     * Prints the condition satisfied by instances in a subset.
     * 
     * @param index
     *            of subset
     * @param data
     *            training set.
     * @return " = value" for a nominal attribute, otherwise a "<=" (subset
     *         0) or ">" (any other subset) comparison with the split point
     */
    public final String rightSide(int index, Instances data) {

        if (data.attribute(m_attIndex).isNominal()) {
            return " = " + data.attribute(m_attIndex).value(index);
        }
        if (index == 0) {
            return " <= " + Utils.doubleToString(m_splitPoint, 6);
        }
        return " > " + Utils.doubleToString(m_splitPoint, 6);
    }

    /**
     * Returns a string containing java source code equivalent to the test made
     * at this node. The instance being tested is called "i".
     * 
     * NOTE(review): a nominal value is placed between double quotes without
     * any escaping, so a value containing a quote or backslash would yield
     * source that does not compile.
     * 
     * @param index
     *            index of the nominal value tested
     * @param data
     *            the data containing instance structure info
     * @return a value of type 'String'
     */
    public final String sourceExpression(int index, Instances data) {

        if (index < 0) {
            return "i[" + m_attIndex + "] == null";
        }

        StringBuilder expr;
        if (data.attribute(m_attIndex).isNominal()) {
            expr = new StringBuilder("i[");
            expr.append(m_attIndex).append("]");
            expr.append(".equals(\"").append(data.attribute(m_attIndex).value(index)).append("\")");
        } else {
            expr = new StringBuilder("((Double) i[");
            expr.append(m_attIndex).append("])");
            if (index == 0) {
                expr.append(".doubleValue() <= ").append(m_splitPoint);
            } else {
                expr.append(".doubleValue() > ").append(m_splitPoint);
            }
        }
        return expr.toString();
    }

    /**
     * Sets split point to greatest value in given data smaller or equal to old
     * split point. (C4.5 does this for some strange reason).
     * 
     * Nothing happens unless the split attribute is numeric and the split has
     * more than one subset. If no known value is smaller or equal to the old
     * split point, the split point becomes {@code -Double.MAX_VALUE}.
     * 
     * @param allInstances
     *            the data whose attribute values are searched
     */
    public final void setSplitPoint(Instances allInstances) {

        if (!allInstances.attribute(m_attIndex).isNumeric() || m_numSubsets <= 1) {
            return;
        }

        double newSplitPoint = -Double.MAX_VALUE;
        for (int i = 0; i < allInstances.numInstances(); i++) {
            Instance instance = allInstances.instance(i);
            if (instance.isMissing(m_attIndex)) {
                continue;
            }
            double tempValue = instance.value(m_attIndex);
            if (Utils.gr(tempValue, newSplitPoint) && Utils.smOrEq(tempValue, m_splitPoint)) {
                newSplitPoint = tempValue;
            }
        }
        m_splitPoint = newSplitPoint;
    }

    /**
     * Returns the minsAndMaxs of the index.th subset.
     * 
     * The given bounds are copied into a new array. For the split attribute
     * the upper bound is set to 1 if the attribute is nominal; otherwise the
     * split point replaces the upper bound for subset 0 and the lower bound
     * for subset 1.
     * 
     * @param data
     *            the data providing the attribute information
     * @param minsAndMaxs
     *            current lower ([i][0]) and upper ([i][1]) bound per attribute
     * @param index
     *            index of the subset
     * @return a new array with the bounds for the subset
     */
    public final double[][] minsAndMaxs(Instances data, double[][] minsAndMaxs, int index) {

        double[][] newMinsAndMaxs = new double[data.numAttributes()][2];

        for (int i = 0; i < data.numAttributes(); i++) {
            newMinsAndMaxs[i][0] = minsAndMaxs[i][0];
            newMinsAndMaxs[i][1] = minsAndMaxs[i][1];
            if (i == m_attIndex) {
                if (data.attribute(m_attIndex).isNominal()) {
                    newMinsAndMaxs[m_attIndex][1] = 1;
                } else {
                    newMinsAndMaxs[m_attIndex][1 - index] = m_splitPoint;
                }
            }
        }

        return newMinsAndMaxs;
    }

    /**
     * Sets distribution associated with model.
     * 
     * The new distribution is built from the instances that this split
     * assigns to exactly one subset; the instances with an unknown value of
     * the split attribute are then added via
     * {@code Distribution.addInstWithUnknown}.
     * 
     * @param data
     *            the data to compute the distribution from
     * @throws Exception
     *             if something goes wrong
     */
    public void resetDistribution(Instances data) throws Exception {

        Instances insts = new Instances(data, data.numInstances());
        for (int i = 0; i < data.numInstances(); i++) {
            if (whichSubset(data.instance(i)) > -1) {
                insts.add(data.instance(i));
            }
        }
        Distribution newD = new Distribution(insts, this);
        newD.addInstWithUnknown(data, m_attIndex);
        m_distribution = newD;
    }

    /**
     * Returns weights if instance is assigned to more than one subset. Returns
     * null if instance is only assigned to one subset.
     * 
     * An instance is assigned to more than one subset when its value for the
     * split attribute is missing; each weight is then the share of the
     * distribution's total weight that falls into the corresponding subset.
     * 
     * @param instance
     *            the instance to get the weights for
     * @return one weight per subset, or null
     */
    public final double[] weights(Instance instance) {

        if (!instance.isMissing(m_attIndex)) {
            return null;
        }

        double[] weights = new double[m_numSubsets];
        for (int i = 0; i < m_numSubsets; i++) {
            weights[i] = m_distribution.perBag(i) / m_distribution.total();
        }
        return weights;
    }

    /**
     * Returns index of subset instance is assigned to. Returns -1 if instance
     * is assigned to more than one subset.
     * 
     * @param instance
     *            the instance to assign
     * @return -1 if the split attribute is missing; the attribute value cast
     *         to int for a nominal attribute; otherwise 0 if the value is
     *         smaller or equal to the split point and 1 if it is greater
     * @throws Exception
     *             if something goes wrong
     */
    public final int whichSubset(Instance instance) throws Exception {

        if (instance.isMissing(m_attIndex)) {
            return -1;
        }
        if (instance.attribute(m_attIndex).isNominal()) {
            return (int) instance.value(m_attIndex);
        }
        if (Utils.smOrEq(instance.value(m_attIndex), m_splitPoint)) {
            return 0;
        }
        return 1;
    }

    /**
     * Returns the revision string.
     * 
     * @return the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.13 $");
    }
}