fantail.algorithms.RankingWithBinaryPCT.java Source code

Java tutorial

Introduction

Here is the source code for fantail.algorithms.RankingWithBinaryPCT.java

Source

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package fantail.algorithms;

import fantail.core.Tools;
import java.util.Arrays;
import java.util.Random;
import org.apache.commons.math.stat.descriptive.DescriptiveStatistics;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;

/**
 *
 * @author Quan Sun quan.sun.nz@gmail.com
 */
public class RankingWithBinaryPCT extends AbstractRanker {

    private RankingWithBinaryPCT[] m_Successors;
    private Attribute m_Attribute;
    private double[] m_Prototype;
    private double m_SplitPoint;
    private int m_MiniLeaf = 2;
    private double m_MinVariancea = 0.001;
    private int m_K = 0;
    private int m_Seed = 1;
    private boolean m_UseMedian = true;
    private int m_MaxDepth = 0;
    private int m_NumTargetLabels;

    public void setUseMedian(boolean m) {
        m_UseMedian = m;
    }

    @Override
    public String rankerName() {
        return "BinaryPCT";
    }

    public void setNumTargetLabels(int t) {
        m_NumTargetLabels = t;
    }

    @Override
    public void buildRanker(Instances data) throws Exception {

        setNumTargetLabels(Tools.getNumberTargets(data));
        Random rnd = new Random(m_Seed);
        Instances workingData = new Instances(data);

        int depth = 1;
        makeTree(workingData, rnd, depth);
    }

    private double computeVariance(Instances data) throws Exception {
        double[][] targets = new double[data.numInstances()][];
        for (int i = 0; i < data.numInstances(); i++) {
            targets[i] = Tools.getTargetVector(data.instance(i));
        }
        double sumVar = 0;
        for (int i = 0; i < m_NumTargetLabels; i++) {
            double[] target_i = new double[data.numInstances()];

            for (int j = 0; j < data.numInstances(); j++) {
                Instance metaInst = (Instance) data.instance(j);
                target_i[j] = targets[j][i] * metaInst.weight();
            }
            sumVar += weka.core.Utils.variance(target_i);
        }
        return sumVar / m_NumTargetLabels;
    }

    private double computeVarianceReduction(Instances data, int attIndex, double splitPoint) throws Exception {
        //double varianceaE = computeVariance(data); // doesn't make sense to compute this
        Instances[] P = splitData(data, attIndex, splitPoint);
        double variancePk = 0;
        for (int k = 0; k < P.length; k++) {
            variancePk += (1.0 * P[k].numInstances() / data.numInstances() * computeVariance(P[k]));
        }
        return -variancePk;
    }

    private Instances[] splitData(Instances data, int attIndex, double splitPoint) throws Exception {
        Instances[] subsets = new Instances[2];
        subsets[0] = new Instances(data, 0);
        subsets[1] = new Instances(data, 0);

        for (int i = 0; i < data.numInstances(); i++) {
            Instance inst = data.instance(i);
            if (inst.value(attIndex) <= splitPoint && (subsets[0].numInstances() <= 0.5 * data.numInstances())) {
                subsets[0].add(inst);
            } else {
                subsets[1].add(inst);
            }
        }
        // TODO: 
        if (subsets[1].numInstances() == 0) {
            subsets[1].add(subsets[0].instance(0));
        }
        if (subsets[0].numInstances() == 0) {
            subsets[0].add(subsets[1].instance(0));
        }
        return subsets;
    }

    private Instances[] splitData2(Instances data, int attIndex, double splitPoint) throws Exception {
        Instances[] subsets = new Instances[2];
        subsets[0] = new Instances(data, 0);
        subsets[1] = new Instances(data, 0);
        for (int i = 0; i < data.numInstances(); i++) {
            Instance inst = data.instance(i);
            if (inst.value(attIndex) <= splitPoint) {
                subsets[0].add(inst);
            } else {
                subsets[1].add(inst);
            }
        }
        // TODO: 
        if (subsets[1].numInstances() == 0) {
            subsets[1].add(subsets[0].instance(0));
        }
        if (subsets[0].numInstances() == 0) {
            subsets[0].add(subsets[1].instance(0));
        }
        return subsets;
    }

    private double getMedian(Instances data, int attIndex) throws Exception {
        DescriptiveStatistics stats = new DescriptiveStatistics();
        for (int i = 0; i < data.numInstances(); i++) {
            Instance inst = (Instance) data.instance(i);
            stats.addValue(inst.value(attIndex));
        }
        double median = stats.getPercentile(50);
        return median;
    }

    private int getRandomPosition(Random r, int[] attIndice) {
        int randP = -1;
        int index = -1;
        while (randP == -1) {
            index = r.nextInt(attIndice.length);
            randP = attIndice[index];

            if (randP != -1) {
                return index;
            }
        }
        return index;
    }

    private void makeTree(Instances data, java.util.Random r, int depth) throws Exception {

        if (data.numInstances() <= m_MiniLeaf || (depth >= m_MaxDepth && m_MaxDepth != 0)
                || computeVariance(data) <= m_MinVariancea) {
            //|| maxVarianceaReduction <= 0
            //|| data.variance(bestAttIndex) <= 0) { // || data.variance(bestAttIndex) <= 0 ) {   copied from ART, 

            m_Attribute = null;
            m_Prototype = AbstractRanker.getAvgRanking(data);
            return;
        }

        //
        if (m_K > data.numAttributes()) {
            m_K = data.numAttributes();
        }
        if (m_K < 1) {
            m_K = (int) weka.core.Utils.log2(data.numAttributes()) + 1;
        }

        // TODO:
        int[] attIndice = new int[data.numAttributes() - 1];
        for (int i = 0; i < attIndice.length; i++) {
            attIndice[i] = i;
        }
        for (int i = 0; i < attIndice.length; i++) {
            //int randomPosition = getRandomPosition(r, attIndice);
            int randomPosition = r.nextInt(attIndice.length);
            int temp = attIndice[i];
            attIndice[i] = attIndice[randomPosition];
            attIndice[randomPosition] = temp;
        }

        RankingWithBinaryPCT.AttScorePair[] attScorePair = new RankingWithBinaryPCT.AttScorePair[m_K];

        for (int i = 0; i < m_K; i++) {
            int attIndex = attIndice[i];

            double splitPoint = Double.NaN;
            if (!m_UseMedian) {
                splitPoint = data.meanOrMode(attIndex);
            } else {
                splitPoint = getMedian(data, attIndex);
            }
            double varianceReduction = computeVarianceReduction(data, attIndex, splitPoint);
            attScorePair[i] = new RankingWithBinaryPCT.AttScorePair(attIndex, varianceReduction);

        }

        Arrays.sort(attScorePair);
        int randAttIndex = 0;
        int bestAttIndex = attScorePair[randAttIndex].index;

        double maxVarianceaReduction = attScorePair[randAttIndex].score;

        //        if (data.numInstances() <= 1 * m_MiniLeaf
        //                || (depth >= m_MaxDepth && m_MaxDepth != 0)
        //                || computeVariance(data) <= m_MinVariancea) {
        //                //|| maxVarianceaReduction <= 0
        //                //|| data.variance(bestAttIndex) <= 0) { // || data.variance(bestAttIndex) <= 0 ) {   copied from ART, 
        //
        //            m_Attribute = null;
        //            m_Prototype = AbstractRanker.getAvgRanking(data);
        //            return;
        //        }
        m_Attribute = data.attribute(bestAttIndex);

        if (!m_UseMedian) {
            m_SplitPoint = data.meanOrMode(bestAttIndex);
        } else {
            m_SplitPoint = getMedian(data, bestAttIndex);
        }

        //m_SplitPoint = data.meanOrMode(m_Attribute);
        Instances[] splitData = splitData(data, bestAttIndex, m_SplitPoint);

        //System.out.println(splitData[0].numInstances());
        //System.out.println(splitData[1].numInstances());
        //System.out.println();

        m_Successors = new RankingWithBinaryPCT[2];

        for (int j = 0; j < 2; j++) {
            m_Successors[j] = new RankingWithBinaryPCT();
            m_Successors[j].setMiniLeaf(m_MiniLeaf);
            m_Successors[j].setK(m_K);
            m_Successors[j].setUseMedian(m_UseMedian);
            m_Successors[j].setNumTargetLabels(m_NumTargetLabels);
            m_Successors[j].makeTree(splitData[j], r, depth + 1);
        }
    }

    public void setMiniLeaf(int n) {
        if (n < 1) {
            m_MiniLeaf = 1;
        }
        m_MiniLeaf = n;
    }

    public int getMiniLeaf() {
        return m_MiniLeaf;
    }

    public void setK(int n) {
        m_K = n;
    }

    public int getK() {
        return m_K;
    }

    public void setRandomSeed(int s) {
        m_Seed = s;
    }

    @Override
    public double[] recommendRanking(Instance metaInst) throws Exception {
        if (m_Attribute == null) {
            return m_Prototype;
        } else {
            if (metaInst.value(m_Attribute) <= m_SplitPoint) {
                return m_Successors[0].recommendRanking(metaInst);
            } else {
                return m_Successors[1].recommendRanking(metaInst);
            }
        }
    }

    private class AttScorePair implements Comparable {

        double score = Double.NaN;
        int index = -1;

        public AttScorePair(int i, double s) {
            score = s;
            index = i;
        }

        @Override
        public int compareTo(Object o) {
            RankingWithBinaryPCT.AttScorePair other = (RankingWithBinaryPCT.AttScorePair) o;
            if (other.score > score) {
                return 1;
            } else if (other.score == score) {
                return 0;
            } else {
                return -1;
            }
        }
    }
}