mulan.classifier.meta.HierarchyBuilder.java Source code

Introduction

Here is the source code for mulan.classifier.meta.HierarchyBuilder.java, a Mulan class that builds a label hierarchy on top of the flat labels of a multi-label dataset.

Source

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    HierarchyBuilder.java
 *    Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
 */
package mulan.classifier.meta;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.transform.OutputKeys;
import javax.xml.transform.Source;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;

import mulan.data.InvalidDataFormatException;
import mulan.data.LabelNode;
import mulan.data.LabelNodeImpl;
import mulan.data.LabelsMetaData;
import mulan.data.LabelsMetaDataImpl;
import mulan.data.MultiLabelInstances;
import mulan.data.DataUtils;

import org.w3c.dom.Document;
import org.w3c.dom.Element;

import weka.clusterers.EM;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffSaver;

/**
 * Class that builds a hierarchy on the flat labels of given multi-label data.
 * The hierarchy may be built with one of three methods: random partitioning,
 * clustering, or balanced clustering.
 *
 * @author George Saridis
 * @author Grigorios Tsoumakas
 * @version 0.1
 */
public class HierarchyBuilder implements Serializable {

    private int numPartitions;
    private Document labelsXMLDoc;
    private Method method;

    /**
     * Constructs a new instance with the given number of partitions and
     * partitioning method.
     *
     * @param partitions the number of partitions created at each node of the hierarchy
     * @param method the partitioning method (random, clustering or balanced clustering)
     */
    public HierarchyBuilder(int partitions, Method method) {
        numPartitions = partitions;
        this.method = method;
    }

    /**
     * Builds a hierarchical multi-label dataset. First a hierarchy is built on
     * top of the labels of a flat multi-label dataset, by recursively
     * partitioning the labels into a specified number of clusters (randomly or
     * via clustering, depending on the selected method). Then the values of
     * the new "meta-labels" are set so that the hierarchy is respected.
     *
     * @param mlData the multi-label data on which the new hierarchy will be built
     * @return the new multi-label data
     * @throws Exception if the hierarchy or the hierarchical dataset cannot be built
     */
    public MultiLabelInstances buildHierarchy(MultiLabelInstances mlData) throws Exception {
        LabelsMetaData labelsMetaData = buildLabelHierarchy(mlData);
        return HierarchyBuilder.createHierarchicalDataset(mlData, labelsMetaData);
    }

    /**
     * Builds a hierarchy of labels on top of the labels of a flat multi-label
     * dataset, by recursively partitioning the labels into a specified number
     * of partitions.
     *
     * @param mlData the multi-label data on which the new hierarchy will be built
     * @return a hierarchy of labels
     * @throws Exception if the hierarchy cannot be built
     */
    public LabelsMetaData buildLabelHierarchy(MultiLabelInstances mlData) throws Exception {
        if (numPartitions > mlData.getNumLabels()) {
            throw new IllegalArgumentException("Number of labels is smaller than the number of partitions");
        }

        Set<String> setOfLabels = mlData.getLabelsMetaData().getLabelNames();
        List<String> listOfLabels = new ArrayList<String>();
        for (String label : setOfLabels) {
            listOfLabels.add(label);
        }

        ArrayList<String>[] childrenLabels = null;
        switch (method) {
        case Random:
            childrenLabels = randomPartitioning(numPartitions, listOfLabels);
            break;
        case Clustering:
            childrenLabels = clustering(numPartitions, listOfLabels, mlData, false);
            break;
        case BalancedClustering:
            childrenLabels = clustering(numPartitions, listOfLabels, mlData, true);
            break;
        }

        // if the partitioning is degenerate (one cluster holds all the labels),
        // fall back to random partitioning; another idea would be to add leaves here
        for (int i = 0; i < numPartitions; i++) {
            if (childrenLabels[i].size() == listOfLabels.size()) {
                childrenLabels = randomPartitioning(numPartitions, listOfLabels);
                break;
            }
        }

        LabelsMetaDataImpl metaData = new LabelsMetaDataImpl();
        for (int i = 0; i < numPartitions; i++) {
            if (childrenLabels[i].size() == 0) {
                continue;
            }
            if (childrenLabels[i].size() == 1) {
                metaData.addRootNode(new LabelNodeImpl(childrenLabels[i].get(0)));
                continue;
            }
            if (childrenLabels[i].size() > 1) {
                LabelNodeImpl metaLabel = new LabelNodeImpl("MetaLabel " + (i + 1));
                createLabelsMetaDataRecursive(metaLabel, childrenLabels[i], mlData);
                metaData.addRootNode(metaLabel);
            }
        }

        return metaData;
    }

    /**
     * Builds the hierarchical multi-label dataset and saves it to disk as an
     * ARFF file accompanied by an XML labels file describing the new hierarchy.
     *
     * @param mlData the flat multi-label data on which the new hierarchy will be built
     * @param arffName the name of the ARFF file to create
     * @param xmlName the name of the XML labels file to create
     * @return the new multi-label data
     * @throws Exception if building the hierarchy or writing the files fails
     */
    public MultiLabelInstances buildHierarchyAndSaveFiles(MultiLabelInstances mlData, String arffName,
            String xmlName) throws Exception {
        MultiLabelInstances newData = buildHierarchy(mlData);
        saveToArffFile(newData.getDataSet(), new File(arffName));
        // write the metadata of the newly built hierarchy (not the original flat labels)
        createXMLFile(newData.getLabelsMetaData());
        saveToXMLFile(xmlName);
        return newData;
    }

    private void createLabelsMetaDataRecursive(LabelNodeImpl node, List<String> labels,
            MultiLabelInstances mlData) {
        if (labels.size() <= numPartitions) {
            for (int i = 0; i < labels.size(); i++) {
                LabelNodeImpl child = new LabelNodeImpl(labels.get(i));
                node.addChildNode(child);
            }
            return;
        }

        ArrayList<String>[] childrenLabels = null;
        switch (method) {
        case Random:
            childrenLabels = randomPartitioning(numPartitions, labels);
            break;
        case Clustering:
            childrenLabels = clustering(numPartitions, labels, mlData, false);
            break;
        case BalancedClustering:
            childrenLabels = clustering(numPartitions, labels, mlData, true);
            break;
        }

        // if the partitioning is degenerate (one cluster holds all the labels),
        // fall back to random partitioning; another idea would be to add leaves here
        for (int i = 0; i < numPartitions; i++) {
            if (childrenLabels[i].size() == labels.size()) {
                childrenLabels = randomPartitioning(numPartitions, labels);
                break;
            }
        }

        for (int i = 0; i < numPartitions; i++) {
            if (childrenLabels[i].size() == 0) {
                continue;
            }
            if (childrenLabels[i].size() == 1) {
                LabelNodeImpl child = new LabelNodeImpl(childrenLabels[i].get(0));
                node.addChildNode(child);
                continue;
            }
            if (childrenLabels[i].size() > 1) {
                LabelNodeImpl child = new LabelNodeImpl(node.getName() + "." + (i + 1));
                node.addChildNode(child);
                createLabelsMetaDataRecursive(child, childrenLabels[i], mlData);
            }
        }
    }

    private ArrayList<String>[] clustering(int clusters, List<String> labels, MultiLabelInstances mlData,
            boolean balanced) {
        ArrayList<String>[] childrenLabels = new ArrayList[clusters];
        for (int i = 0; i < clusters; i++) {
            childrenLabels[i] = new ArrayList<String>();
        }

        // transpose data and keep only labels in the parameter list
        int numInstances = mlData.getDataSet().numInstances();
        ArrayList<Attribute> attInfo = new ArrayList<Attribute>(numInstances);
        for (int i = 0; i < numInstances; i++) {
            Attribute att = new Attribute("instance" + (i + 1));
            attInfo.add(att);
        }
        System.out.println("constructing instances");
        Instances transposed = new Instances("transposed", attInfo, 0);
        for (int i = 0; i < labels.size(); i++) {
            double[] values = new double[numInstances];
            for (int j = 0; j < numInstances; j++) {
                values[j] = mlData.getDataSet().instance(j).value(mlData.getDataSet().attribute(labels.get(i)));
            }
            Instance newInstance = DataUtils.createInstance(mlData.getDataSet().instance(0), 1, values);
            transposed.add(newInstance);
        }

        if (!balanced) {
            EM clusterer = new EM();
            try {
                // cluster the labels
                clusterer.setNumClusters(clusters);
                System.out.println("clustering");
                clusterer.buildClusterer(transposed);
                // return the clustering
                for (int i = 0; i < labels.size(); i++) {
                    childrenLabels[clusterer.clusterInstance(transposed.instance(i))].add(labels.get(i));
                }
            } catch (Exception ex) {
                Logger.getLogger(HierarchyBuilder.class.getName()).log(Level.SEVERE, null, ex);
            }
        } else {
            ConstrainedKMeans clusterer = new ConstrainedKMeans();
            try {
                // cluster the labels
                clusterer.setMaxIterations(20);
                clusterer.setNumClusters(clusters);
                System.out.println("balanced clustering");
                clusterer.buildClusterer(transposed);
                // return the clustering
                for (int i = 0; i < labels.size(); i++) {
                    childrenLabels[clusterer.clusterInstance(transposed.instance(i))].add(labels.get(i));
                }
            } catch (Exception ex) {
                Logger.getLogger(HierarchyBuilder.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        // debug output: print the resulting label clusters
        for (int i = 0; i < childrenLabels.length; i++) {
            System.out.println(childrenLabels[i]);
        }
        return childrenLabels;
    }

    private ArrayList<String>[] randomPartitioning(int partitions, List<String> labels) {
        ArrayList<String>[] childrenLabels = new ArrayList[partitions];
        for (int i = 0; i < partitions; i++) {
            childrenLabels[i] = new ArrayList<String>();
        }

        Random rnd = new Random();
        while (!labels.isEmpty()) {
            for (int i = 0; i < partitions; i++) {
                if (labels.isEmpty()) {
                    break;
                }
                String rndLabel = labels.remove(rnd.nextInt(labels.size()));
                childrenLabels[i].add(rndLabel);
            }
        }
        return childrenLabels;
    }

    /**
     * Creates the hierarchical dataset according to the original multi-label
     * instances object and the constructed label hierarchy.
     *
     * @param mlData the original multi-label instances
     * @param metaData the metadata of the constructed label hierarchy
     * @return the produced dataset
     * @throws InvalidDataFormatException if the produced dataset is not valid multi-label data
     */
    public static MultiLabelInstances createHierarchicalDataset(MultiLabelInstances mlData, LabelsMetaData metaData)
            throws InvalidDataFormatException {
        Set<String> leafLabels = mlData.getLabelsMetaData().getLabelNames();
        Set<String> metaLabels = metaData.getLabelNames();
        for (String string : leafLabels) {
            metaLabels.remove(string);
        }
        Instances dataSet = mlData.getDataSet();
        int numMetaLabels = metaLabels.size();

        // copy existing attributes
        ArrayList<Attribute> atts = new ArrayList<Attribute>(dataSet.numAttributes() + numMetaLabels);
        for (int i = 0; i < dataSet.numAttributes(); i++) {
            atts.add(dataSet.attribute(i));
        }

        ArrayList<String> labelValues = new ArrayList<String>();
        labelValues.add("0");
        labelValues.add("1");

        // add metalabel attributes
        for (String metaLabel : metaLabels) {
            atts.add(new Attribute(metaLabel, labelValues));
        }

        // initialize dataset
        Instances newDataSet = new Instances("hierarchical", atts, dataSet.numInstances());

        // copy features and labels, set metalabels
        for (int i = 0; i < dataSet.numInstances(); i++) {
            //System.out.println("Constructing instance " + (i+1) + "/"  + dataSet.numInstances());
            // initialize new values
            double[] newValues = new double[newDataSet.numAttributes()];
            Arrays.fill(newValues, 0);

            // copy features and labels
            double[] values = dataSet.instance(i).toDoubleArray();
            System.arraycopy(values, 0, newValues, 0, values.length);

            // set metalabels
            for (String label : leafLabels) {
                Attribute att = dataSet.attribute(label);
                if (att.value((int) dataSet.instance(i).value(att)).equals("1")) {
                    //System.out.println(label);
                    //System.out.println(Arrays.toString(metaData.getLabelNames().toArray()));
                    LabelNode currentNode = metaData.getLabelNode(label);
                    // set 1 all the way up to the root, stopping early once an ancestor is already set
                    while (currentNode.hasParent()) {
                        currentNode = currentNode.getParent();
                        Attribute currentAtt = newDataSet.attribute(currentNode.getName());
                        int attIndex = atts.indexOf(currentAtt);
                        if (newValues[attIndex] == 1) {
                            // an ancestor already has a 1, no need to go further up
                            break;
                        }
                        newValues[attIndex] = 1;
                    }
                }
            }
            Instance instance = dataSet.instance(i);
            newDataSet.add(DataUtils.createInstance(instance, instance.weight(), newValues));
        }
        return new MultiLabelInstances(newDataSet, metaData);
    }

    private void saveToArffFile(Instances dataSet, File file) throws IOException {
        ArffSaver saver = new ArffSaver();
        saver.setInstances(dataSet);
        saver.setFile(file);
        saver.writeBatch();
    }

    private void createXMLFile(LabelsMetaData metaData) throws Exception {
        DocumentBuilderFactory docBF = DocumentBuilderFactory.newInstance();
        DocumentBuilder docBuilder = docBF.newDocumentBuilder();
        labelsXMLDoc = docBuilder.newDocument();

        Element rootElement = labelsXMLDoc.createElement("labels");
        rootElement.setAttribute("xmlns", "http://mulan.sourceforge.net/labels");
        labelsXMLDoc.appendChild(rootElement);
        for (LabelNode rootLabel : metaData.getRootLabels()) {
            Element newLabelElem = labelsXMLDoc.createElement("label");
            newLabelElem.setAttribute("name", rootLabel.getName());
            appendElement(newLabelElem, rootLabel);
            rootElement.appendChild(newLabelElem);
        }
    }

    private void saveToXMLFile(String fileName) {
        Source source = new DOMSource(labelsXMLDoc);
        File xmlFile = new File(fileName);
        StreamResult result = new StreamResult(xmlFile);
        try {
            Transformer transformer = TransformerFactory.newInstance().newTransformer();
            transformer.setOutputProperty(OutputKeys.INDENT, "yes");
            transformer.setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "4");
            transformer.setOutputProperty(OutputKeys.METHOD, "xml");
            transformer.transform(source, result);
        } catch (Exception e) {
            e.printStackTrace();
        }

    }

    private void appendElement(Element labelElem, LabelNode labelNode) {
        for (LabelNode childNode : labelNode.getChildren()) {
            Element newLabelElem = labelsXMLDoc.createElement("label");
            newLabelElem.setAttribute("name", childNode.getName());
            appendElement(newLabelElem, childNode);
            labelElem.appendChild(newLabelElem);
        }
    }

    /**
     * The available methods for building the label hierarchy.
     */
    public enum Method {

        Random, Clustering, BalancedClustering
    }
}
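
Example

For context, here is a minimal usage sketch; it is not part of the class above. It assumes a flat multi-label dataset loaded through Mulan's MultiLabelInstances constructor that takes an ARFF file and an XML labels file, and the file names used here are placeholders.

import mulan.classifier.meta.HierarchyBuilder;
import mulan.data.MultiLabelInstances;

public class HierarchyBuilderExample {

    public static void main(String[] args) throws Exception {
        // load a flat multi-label dataset (file names are placeholders)
        MultiLabelInstances flatData = new MultiLabelInstances("data.arff", "data.xml");

        // build a hierarchy with 2 partitions per node, using balanced clustering
        HierarchyBuilder builder = new HierarchyBuilder(2, HierarchyBuilder.Method.BalancedClustering);

        // create the hierarchical dataset and save it as an ARFF file plus an XML labels file
        MultiLabelInstances hierarchicalData = builder.buildHierarchyAndSaveFiles(
                flatData, "data-hierarchical.arff", "data-hierarchical.xml");

        System.out.println("Labels (including meta-labels): " + hierarchicalData.getNumLabels());
    }
}

Alternatively, buildHierarchy(flatData) returns the hierarchical dataset without writing any files, and buildLabelHierarchy(flatData) returns only the constructed LabelsMetaData hierarchy.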