mulan.classifier.meta.thresholding.MetaLabeler.java Source code

Java tutorial

Introduction

Here is the source code for mulan.classifier.meta.thresholding.MetaLabeler.java

Source

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    MetaLabeler.java
 *    Copyright (C) 2009 Aristotle University of Thessaloniki, Thessaloniki, Greece
 */
package mulan.classifier.meta.thresholding;

import java.util.ArrayList;
import java.util.Set;
import java.util.TreeSet;
import java.util.logging.Level;
import java.util.logging.Logger;

import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.data.DataUtils;
import mulan.data.MultiLabelInstances;
import mulan.transformations.RemoveAllLabels;

import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;

/**
 * Thresholding meta-learner that decides, per instance, how many of the
 * top-ranked labels of an underlying multi-label learner are reported as
 * relevant.
 * <p>
 * During training ({@link #transformData(MultiLabelInstances)}) a
 * single-target data set is produced whose class is the number of true labels
 * of each training instance. At prediction time the classifier's output for
 * that target is used as a cut-off on the base learner's ranking: every label
 * whose rank is less than or equal to the predicted count is set to true.
 * <p>
 * The class target is either nominal ({@code "Nominal-Class"}, one value per
 * distinct label count seen in the training data) or numeric
 * ({@code "Numeric-Class"}).
 *
 * @author Marios Ioannou
 * @author George Sakkas
 * @author Grigorios Tsoumakas
 * @version 2010.12.14
 */
public class MetaLabeler extends Meta {

    /** class choice value selecting a nominal target (one value per label count) */
    private static final String NOMINAL_CLASS = "Nominal-Class";
    /** class choice value selecting a numeric target (the label count itself) */
    private static final String NUMERIC_CLASS = "Numeric-Class";
    /** meta-data choice value that copies the original features unchanged */
    private static final String CONTENT_BASED = "Content-Based";

    /** the type of the class*/
    private String classChoice;

    /**
     * Constructor that initializes the learner.
     * <p>
     * For every meta-data choice other than {@code "Content-Based"} a copy of
     * the base learner is kept for the internal cross-validation and the
     * number of folds is initialised to 3. A failure to copy the base learner
     * is logged and not rethrown.
     *
     * @param baseLearner the underlying multi-label learner
     * @param classifier the binary classification
     * @param metaDataChoice the type of meta-data
     * @param aClassChoice the type of the class, {@code "Nominal-Class"} or
     *        {@code "Numeric-Class"}
     */
    public MetaLabeler(MultiLabelLearner baseLearner, Classifier classifier, String metaDataChoice,
            String aClassChoice) {
        super(baseLearner, classifier, metaDataChoice);
        if (!metaDataChoice.equals(CONTENT_BASED)) {
            try {
                foldLearner = baseLearner.makeCopy();
            } catch (Exception ex) {
                Logger.getLogger(MetaLabeler.class.getName()).log(Level.SEVERE, null, ex);
            }
            kFoldsCV = 3;
        }
        classChoice = aClassChoice;
    }

    /**
     * Returns the bibliographic reference of the MetaLabeler paper.
     *
     * @return the technical information about this learner
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(Type.INPROCEEDINGS);
        result.setValue(Field.AUTHOR, "Lei Tang and Suju Rajan and Vijay K. Narayanan");
        result.setValue(Field.TITLE, "Large scale multi-label classification via metalabeler");
        result.setValue(Field.BOOKTITLE, "Proceedings of the 18th international conference on World wide web ");
        result.setValue(Field.PAGES, "211-220");
        result.setValue(Field.LOCATION, "Madrid, Spain");
        result.setValue(Field.YEAR, "2009");
        return result;
    }

    /**
     * Predicts the label set of an instance by cutting the base learner's
     * ranking at the label count predicted by the meta-classifier.
     * <p>
     * If the base learner's output has no ranking, every label is predicted
     * as false.
     *
     * @param instance the instance to predict
     * @return an output holding the thresholded bipartition together with the
     *         confidences reported by the base learner
     * @throws Exception if the base learner or the classifier fails
     */
    @Override
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
        MultiLabelOutput mlo = baseLearner.makePrediction(instance);
        boolean[] predictedLabels = new boolean[numLabels];

        Instance modifiedIns = modifiedInstanceX(instance, metaDatasetChoice);
        // append a slot for the class value before attaching the meta data set
        modifiedIns.insertAttributeAt(modifiedIns.numAttributes());
        modifiedIns.setDataset(classifierInstances);

        // the classifier output encodes the number of labels to keep
        double classifyKey = classifier.classifyInstance(modifiedIns);
        int bipartitionKey;
        if (isNominalClass()) {
            // nominal output is an index into the class values; map it back to the count
            String countText = classifierInstances.attribute(classifierInstances.numAttributes() - 1)
                    .value((int) classifyKey);
            bipartitionKey = Integer.parseInt(countText);
        } else { //Numeric-Class
            bipartitionKey = (int) Math.round(classifyKey);
        }

        if (mlo.hasRanking()) {
            int[] ranking = mlo.getRanking();
            for (int i = 0; i < numLabels; i++) {
                predictedLabels[i] = ranking[i] <= bipartitionKey;
            }
        }
        return new MultiLabelOutput(predictedLabels, mlo.getConfidences());
    }

    /**
     * Counts the labels of an instance whose nominal value is {@code "1"}.
     *
     * @param instance an instance attached to a data set that contains the
     *        label attributes
     * @return the number of true labels
     */
    private int countTrueLabels(Instance instance) {
        int numTrueLabels = 0;
        for (int i = 0; i < numLabels; i++) {
            int labelIndice = labelIndices[i];
            if (instance.dataset().attribute(labelIndice).value((int) instance.value(labelIndice)).equals("1")) {
                numTrueLabels++;
            }
        }
        return numTrueLabels;
    }

    /**
     * Builds the single-target data set whose class is the number of true
     * labels of each instance.
     * <p>
     * With the {@code "Content-Based"} meta-data choice the features are copied
     * from the training instances. Otherwise the features are produced by
     * {@code valuesX} from learners trained in an internal cross-validation;
     * when the number of folds is 1 the base learner itself is used and the
     * whole training set is used to create the meta-data.
     *
     * @param trainingData the multi-label training data
     * @return the meta data set, also stored in {@code classifierInstances}
     * @throws IllegalArgumentException if the class choice is neither
     *         {@code "Nominal-Class"} nor {@code "Numeric-Class"}
     * @throws Exception if a learner cannot be copied, built or queried
     */
    protected Instances transformData(MultiLabelInstances trainingData) throws Exception {
        // build the target first so that an unsupported class choice fails
        // before any state of this learner is modified
        Attribute target = createTargetAttribute(trainingData.getDataSet());

        // initialize  classifier instances
        classifierInstances = RemoveAllLabels.transformInstances(trainingData);
        classifierInstances = new Instances(classifierInstances, 0);
        classifierInstances.insertAttributeAt(target, classifierInstances.numAttributes());
        classifierInstances.setClassIndex(classifierInstances.numAttributes() - 1);

        // create instances
        if (metaDatasetChoice.equals(CONTENT_BASED)) {
            addContentBasedInstances(trainingData);
        } else {
            addCrossValidatedInstances(trainingData);
        }

        return classifierInstances;
    }

    /** Tells whether the target is the nominal label count. */
    private boolean isNominalClass() {
        return NOMINAL_CLASS.equals(classChoice);
    }

    /** Tells whether the target is the numeric label count. */
    private boolean isNumericClass() {
        return NUMERIC_CLASS.equals(classChoice);
    }

    /** Creates the "Class" attribute matching the configured class choice. */
    private Attribute createTargetAttribute(Instances data) {
        if (isNumericClass()) {
            return new Attribute("Class");
        }
        if (!isNominalClass()) {
            throw new IllegalArgumentException("Unsupported class choice: " + classChoice
                    + " (expected " + NOMINAL_CLASS + " or " + NUMERIC_CLASS + ")");
        }
        // sorted set: the nominal values are the distinct label counts in ascending order
        Set<Integer> distinctCounts = new TreeSet<Integer>();
        for (int instanceIndex = 0; instanceIndex < data.numInstances(); instanceIndex++) {
            distinctCounts.add(countTrueLabels(data.instance(instanceIndex)));
        }
        ArrayList<String> classlabel = new ArrayList<String>();
        for (Integer count : distinctCounts) {
            classlabel.add(count.toString());
        }
        return new Attribute("Class", classlabel);
    }

    /** Converts a label count into the internal value of the class attribute. */
    private double encodeClassValue(int numTrueLabels) {
        if (isNominalClass()) {
            return classifierInstances.attribute(classifierInstances.numAttributes() - 1)
                    .indexOfValue("" + numTrueLabels);
        }
        return numTrueLabels;
    }

    /** Adds one meta-instance per training instance, copying the original feature values. */
    private void addContentBasedInstances(MultiLabelInstances trainingData) {
        for (int instanceIndex = 0; instanceIndex < trainingData.getNumInstances(); instanceIndex++) {
            Instance instance = trainingData.getDataSet().instance(instanceIndex);
            double[] values = instance.toDoubleArray();
            double[] newValues = new double[classifierInstances.numAttributes()];
            for (int i = 0; i < featureIndices.length; i++) {
                newValues[i] = values[featureIndices[i]];
            }

            //set the number of true labels of an instance
            newValues[newValues.length - 1] = encodeClassValue(countTrueLabels(instance));
            classifierInstances.add(DataUtils.createInstance(instance, instance.weight(), newValues));
        }
    }

    /** Adds meta-instances whose features come from learners evaluated on held-out folds. */
    private void addCrossValidatedInstances(MultiLabelInstances trainingData) throws Exception {
        for (int k = 0; k < kFoldsCV; k++) {
            //Split data to train and test sets
            MultiLabelLearner tempLearner;
            MultiLabelInstances mlTest;
            if (kFoldsCV == 1) {
                tempLearner = baseLearner;
                mlTest = trainingData;
            } else {
                Instances train = trainingData.getDataSet().trainCV(kFoldsCV, k);
                Instances test = trainingData.getDataSet().testCV(kFoldsCV, k);
                MultiLabelInstances mlTrain = new MultiLabelInstances(train, trainingData.getLabelsMetaData());
                mlTest = new MultiLabelInstances(test, trainingData.getLabelsMetaData());
                tempLearner = foldLearner.makeCopy();
                tempLearner.build(mlTrain);
            }

            // copy features and labels, set metalabels
            Instances testSet = mlTest.getDataSet();
            for (int instanceIndex = 0; instanceIndex < testSet.numInstances(); instanceIndex++) {
                Instance instance = testSet.instance(instanceIndex);

                // initialize new class values
                double[] newValues = new double[classifierInstances.numAttributes()];

                // create features
                valuesX(tempLearner, instance, newValues, metaDatasetChoice);

                //set the number of true labels of an instance
                newValues[newValues.length - 1] = encodeClassValue(countTrueLabels(instance));

                // add the new instance to  classifierInstances
                classifierInstances.add(DataUtils.createInstance(instance, instance.weight(), newValues));
            }
        }
    }

    /**
     * Sets the number of folds for internal cv.
     * <p>
     * The value is only read when the meta-data choice is not
     * {@code "Content-Based"}. A value of 1 skips the split and uses the base
     * learner on the full training data; a value below 1 yields no
     * meta-instances.
     *
     * @param f the number of folds
     */
    public void setFolds(int f) {
        kFoldsCV = f;
    }
}