meka.classifiers.multilabel.BCC.java Source code

Java tutorial

Introduction

Here is the source code for meka.classifiers.multilabel.BCC.java

Source

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package meka.classifiers.multilabel;

import meka.classifiers.multilabel.cc.CNode;
import meka.core.A;
import meka.core.MatrixUtils;
import meka.core.OptionUtils;
import meka.core.StatUtils;
import mst.Edge;
import mst.EdgeWeightedGraph;
import mst.KruskalMST;
import weka.core.Instances;
import weka.core.Option;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.Utils;

import java.util.*;

/**
 * BCC.java - Bayesian Classifier Chains.
 * Probably would be more aptly called Bayesian Classifier Tree.
 * Creates a maximum spanning tree based on marginal label dependence; then employs a CC classifier.
 * The original paper used Naive Bayes as a base classifier, hence the name.
 * <br>
 * See Zaragoza et al. "Bayesian Chain Classifiers for Multidimensional Classification". IJCAI 2011.
 * <br>
 * @author   Jesse Read
 * @version June 2013
 */
public class BCC extends CC {

    private static final long serialVersionUID = 585507197229071545L;

    /** Dependency type used when no <code>-X</code> option is supplied. */
    private static final String DEFAULT_DEPENDENCY_TYPE = "Ibf";

    /**
     * Description to display in the GUI.
     *
     * @return      the description
     */
    @Override
    public String globalInfo() {
        return "Bayesian Classifier Chains (BCC).\n"
                + "Creates a maximum spanning tree based on marginal label dependence. Then employs CC.\n"
                + "For more information see:\n" + getTechnicalInformation().toString();
    }

    /**
     * Bibliographic reference for the method implemented by this class.
     *
     * @return      the technical information (an INPROCEEDINGS entry)
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(Type.INPROCEEDINGS);
        result.setValue(Field.AUTHOR, "Julio H. Zaragoza et al.");
        result.setValue(Field.TITLE, "Bayesian Chain Classifiers for Multidimensional Classification");
        result.setValue(Field.BOOKTITLE, "IJCAI'11: International Joint Conference on Artificial Intelligence.");
        result.setValue(Field.YEAR, "2011");
        return result;
    }

    /**
     * Builds the classifier tree: measures pairwise label dependence, extracts a
     * maximum spanning tree over the labels, roots that tree at the node selected
     * by the seed, and trains one node classifier per label conditioned on the
     * label's parent in the tree.
     *
     * @param D     the training data; the first <code>D.classIndex()</code> attributes are treated as labels
     * @throws IllegalArgumentException if the data has fewer than one label
     * @throws Exception if the capability test, the dependency measurement or a node build fails
     */
    @Override
    public void buildClassifier(Instances D) throws Exception {
        testCapabilities(D);

        m_R = new Random(getSeed());
        int L = D.classIndex();
        int d = D.numAttributes() - L;
        if (L < 1) {
            // Without this guard the method fails later with an opaque index error.
            throw new IllegalArgumentException("BCC requires at least one label, but the class index is " + L);
        }

        /*
         * Measure [un]conditional label dependencies (frequencies).
         */
        if (getDebug())
            System.out.println("Get unconditional dependencies ...");
        double CD[][] = null;
        if (m_DependencyType.equals("L")) {
            // New Option
            if (getDebug())
                System.out.println("The 'LEAD' method for finding conditional dependence.");
            CD = StatUtils.LEAD(D, getClassifier(), m_R);
        } else {
            // Old/default Option
            if (getDebug())
                System.out.println("The Frequency method for finding marginal dependence.");
            CD = StatUtils.margDepMatrix(D, m_DependencyType);
        }

        if (getDebug())
            System.out.println(MatrixUtils.toString(CD));

        /*
         * Make a fully connected graph, each edge represents the
         * dependence measured between the pair of labels.
         */
        CD = MatrixUtils.multiply(CD, -1); // because we want a *maximum* spanning tree
        if (getDebug())
            System.out.println("Make a graph ...");
        EdgeWeightedGraph G = new EdgeWeightedGraph(L);
        for (int i = 0; i < L; i++) {
            // only the upper triangle CD[i][j], j > i, is read: one edge per label pair
            for (int j = i + 1; j < L; j++) {
                G.addEdge(new Edge(i, j, CD[i][j]));
            }
        }

        /*
         * Run an off-the-shelf MST algorithm to get a MST.
         */
        if (getDebug())
            System.out.println("Get an MST ...");
        KruskalMST mst = new KruskalMST(G);

        /*
         * Define graph connections based on the MST
         * (symmetric adjacency matrix: paM[j][k] == 1 iff j and k share an MST edge).
         */
        int paM[][] = new int[L][L];
        for (Edge e : mst.edges()) {
            int j = e.either();
            int k = e.other(j);
            paM[j][k] = 1;
            paM[k][j] = 1;
        }
        if (getDebug())
            System.out.println(MatrixUtils.toString(paM));

        /*
         * Turn the undirected spanning tree into a rooted tree. The seed selects
         * the root; it is wrapped into [0, L) so that any seed value (including
         * seeds >= L or negative seeds) yields a valid label index. Seeds that are
         * already in range select the same root as before.
         */
        int root = ((getSeed() % L) + L) % L;
        if (getDebug())
            System.out.println("Make a Tree from Root " + root);
        int paL[][] = new int[L][0];   // paL[j] = parent label(s) of label j (empty for the root)
        int visited[] = new int[L];    // visited[j] = rank at which label j was reached, -1 = not yet
        Arrays.fill(visited, -1);
        visited[root] = 0;
        treeify(root, paM, paL, visited);
        if (getDebug()) {
            for (int i = 0; i < L; i++) {
                System.out.println("pa_" + i + " = " + Arrays.toString(paL[i]));
            }
        }
        // Order the labels by visit rank, so each label comes after its parent.
        m_Chain = Utils.sort(visited);
        if (getDebug())
            System.out.println("sequence: " + Arrays.toString(m_Chain));

        /*
         * Build a classifier 'tree' based on the Tree
         */
        if (getDebug())
            System.out.println("Build Classifier Tree ...");
        nodes = new CNode[L];
        for (int j : m_Chain) {
            if (getDebug())
                System.out.println("\t node h_" + j + " : P(y_" + j + " | x_[1:" + d + "], y_"
                        + Arrays.toString(paL[j]) + ")");
            nodes[j] = new CNode(j, null, paL[j]);
            nodes[j].build(D, m_Classifier);
        }

        if (getDebug())
            System.out.println(" * DONE * ");

        /*
         * Notes ...
         *    paL[j] = new int[]{};            // <-- BR !!
         *    paL[j] = MLUtils.gen_indices(j); // <-- CC !!
         */
    }

    /**
     * Treeify - make a tree given the structure defined in paM[][], using the root-th node as root.
     * Every not-yet-visited neighbour of <code>root</code> becomes its child: <code>root</code> is
     * appended to the child's parent list and the child receives the next visit rank. All children
     * of a node are ranked before the method recurses into the first of them.
     *
     * @param root      the node whose neighbours are to be attached as children
     * @param paM       symmetric adjacency matrix (1 = connected); read only
     * @param paL       parent lists, updated in place (paL[j] gains <code>root</code> for each new child j)
     * @param visited   visit ranks, updated in place; a negative entry means "not yet visited"
     */
    private void treeify(int root, int paM[][], int paL[][], int visited[]) {
        int children[] = new int[] {};
        for (int j = 0; j < paM[root].length; j++) {
            if (paM[root][j] == 1 && visited[j] < 0) {
                children = A.append(children, j);
                paL[j] = A.append(paL[j], root);
                // next rank = highest rank assigned so far + 1
                visited[j] = visited[Utils.maxIndex(visited)] + 1;
            }
        }
        // go through again, descending into each newly attached child
        for (int child : children) {
            treeify(child, paM, paL, visited);
        }
    }

    /*
     * TODO: Make a generic abstract -dependency_user- class that has this option, and extend it here
     */

    /** How label dependence is measured: "L" selects LEAD, anything else is handed to the marginal measure. */
    String m_DependencyType = DEFAULT_DEPENDENCY_TYPE;

    /**
     * Sets the way label dependencies are measured.
     *
     * @param value     the dependency type
     */
    public void setDependencyType(String value) {
        m_DependencyType = value;
    }

    /**
     * Returns the way label dependencies are measured.
     *
     * @return      the dependency type
     */
    public String getDependencyType() {
        return m_DependencyType;
    }

    /**
     * Tip text for the dependency type property, for display in the GUI.
     *
     * @return      the tip text
     */
    public String dependencyTypeTipText() {
        return "The way to measure label dependencies: 'L' uses the LEAD method (conditional dependence); "
                + "any other value is passed to the frequency-based marginal dependence measure.";
    }

    /**
     * Lists the options of this classifier: <code>-X</code> followed by the options of the superclass.
     *
     * @return      an enumeration of the available options
     */
    @Override
    public Enumeration listOptions() {
        Vector result = new Vector();
        // Report the fixed default, not the currently configured value.
        result.addElement(new Option(
                "\tThe way to measure dependencies.\n\tdefault: " + DEFAULT_DEPENDENCY_TYPE + " (frequencies only)",
                "X", 1, "-X <value>"));
        OptionUtils.add(result, super.listOptions());
        return OptionUtils.toEnumeration(result);
    }

    /**
     * Parses the <code>-X</code> option (falling back to the default dependency type)
     * and passes the options on to the superclass.
     *
     * @param options   the options to parse
     * @throws Exception if parsing fails
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        setDependencyType(OptionUtils.parse(options, 'X', DEFAULT_DEPENDENCY_TYPE));
        super.setOptions(options);
    }

    /**
     * Returns the current option settings: <code>-X</code> followed by the options of the superclass.
     *
     * @return      the options as a string array
     */
    @Override
    public String[] getOptions() {
        List<String> result = new ArrayList<>();
        OptionUtils.add(result, 'X', getDependencyType());
        OptionUtils.add(result, super.getOptions());
        return OptionUtils.toArray(result);
    }

    /**
     * Command-line entry point; runs the evaluation of a new BCC instance with the given arguments.
     *
     * @param args      the command-line arguments
     */
    public static void main(String[] args) {
        ProblemTransformationMethod.evaluation(new BCC(), args);
    }
}