mulan.classifier.transformation.LabelsetPruning.java Source code

Java tutorial

Introduction

Here is the source code for mulan.classifier.transformation.LabelsetPruning.java

Source

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    LabelsetPruning.java
 *    Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
 */
package mulan.classifier.transformation;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;

import mulan.data.LabelSet;
import mulan.data.MultiLabelInstances;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Shared base class for the PPT and PS algorithms. <p>
 *
 * Training instances are grouped by their labelset; groups with more than
 * {@code p} members are passed on unchanged, while every other group is
 * handed to {@link #processRejected(LabelSet)}, which decides what (if
 * anything) is put back into the training data.
 *
 * @author Grigorios Tsoumakas 
 * @version June 4, 2010
 */
public abstract class LabelsetPruning extends LabelPowerset {

    /** maps each distinct labelset to the training instances carrying it */
    HashMap<LabelSet, ArrayList<Instance>> ListInstancePerLabel;
    /** occurrence threshold: a labelset is kept as-is only if it occurs more than p times */
    protected int p;
    /** header-only (zero instance) copy of the training data */
    Instances format;

    /**
     * Creates the learner from a base algorithm and the pruning parameter.
     *
     * @param classifier base single-label classification algorithm
     * @param aP occurrence threshold for labelsets; must be strictly positive
     * @throws IllegalArgumentException if {@code aP} is zero or negative
     */
    public LabelsetPruning(Classifier classifier, int aP) {
        super(classifier);
        if (aP < 1) {
            throw new IllegalArgumentException("p should be larger than 0!");
        }
        p = aP;
        setConfidenceCalculationMethod(2);
        setMakePredictionsBasedOnConfidences(true);
        threshold = 0.21;
    }

    /**
     * Strategy hook for a labelset that did not pass the occurrence threshold.
     *
     * @param ls the rejected labelset (a key of {@code ListInstancePerLabel})
     * @return the instances to add to the training data on behalf of that labelset
     */
    abstract ArrayList<Instance> processRejected(LabelSet ls);

    /**
     * Groups the training instances by labelset, prunes the infrequent
     * labelsets and trains the parent learner on the resulting data.
     *
     * @param mlDataSet the multi-label training data
     * @throws Exception if the parent learner fails to build
     */
    @Override
    protected void buildInternal(MultiLabelInstances mlDataSet) throws Exception {
        Instances data = mlDataSet.getDataSet();
        format = new Instances(data, 0);

        // Phase 1: bucket every training instance under its labelset.
        ListInstancePerLabel = new HashMap<LabelSet, ArrayList<Instance>>();
        final int total = data.numInstances();
        for (int row = 0; row < total; row++) {
            Instance current = data.instance(row);

            double[] labelValues = new double[numLabels];
            for (int k = 0; k < numLabels; k++) {
                int attrIndex = labelIndices[k];
                // The stored value is an index into the attribute's value list;
                // the text found there is parsed as the numeric label value.
                String valueText = data.attribute(attrIndex).value((int) current.value(attrIndex));
                labelValues[k] = Double.parseDouble(valueText);
            }

            LabelSet key = new LabelSet(labelValues);
            ArrayList<Instance> bucket = ListInstancePerLabel.get(key);
            if (bucket == null) {
                bucket = new ArrayList<Instance>();
                ListInstancePerLabel.put(key, bucket);
            }
            bucket.add(current);
        }

        // Phase 2: labelsets occurring more than p times keep all of their
        // instances; the remaining ones are delegated to processRejected,
        // which may discard or reintroduce them depending on the strategy.
        Instances pruned = new Instances(data, 0);
        for (LabelSet ls : ListInstancePerLabel.keySet()) {
            ArrayList<Instance> members = ListInstancePerLabel.get(ls);
            if (members.size() <= p) {
                pruned.addAll(processRejected(ls));
                continue;
            }
            for (Instance member : members) {
                pruned.add(member);
            }
        }

        super.buildInternal(new MultiLabelInstances(pruned, mlDataSet.getLabelsMetaData()));
    }
}