machinelearningq2.ExtendedNaiveBayes.java Source code

Introduction

Here is the source code for machinelearningq2.ExtendedNaiveBayes.java, a Weka classifier that extends BasicNaiveBayesV1 and supports either a Gaussian model for numeric attributes or a discretised model, selected by a constructor flag.

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package machinelearningq2;

import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import weka.classifiers.Classifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Discretize;

/**
 *
 * @author Luke
 */
public class ExtendedNaiveBayes extends BasicNaiveBayesV1 {

    private int[] classValueCounts;

    private double[][] attributeMeans;
    private double[][] attributeVariance;
    private ArrayList<DataFound> data = new ArrayList<>();
    private double countData;

    private double[] binCount;
    private double testCount = 0;
    private double correctCount = 0;
    private final String gausianOrDiscretise;

    public ExtendedNaiveBayes(boolean laplace, String gausianOrDiscretise) {
        super(laplace);
        this.gausianOrDiscretise = gausianOrDiscretise;
    }

    /**
     *
     * Builds either a Gaussian classifier or a discretised classifier,
     * depending on the flag passed to the constructor
     *
     * @param ins the training instances
     * @throws Exception if the classifier cannot be built
     */
    @Override
    public void buildClassifier(Instances ins) throws Exception {
        if ("d".equals(gausianOrDiscretise)) {
            buildDiscreteClassifier(ins);
        } else {
            countData = ins.size();
            // set the last attribute as the class attribute
            ins.setClassIndex(ins.numAttributes() - 1);
            classValueCounts = new int[ins.numClasses()];
            attributeMeans = new double[ins.numClasses()][ins.numAttributes() - 1];
            attributeVariance = new double[ins.numClasses()][ins.numAttributes() - 1];

            // store the values
            for (Instance line : ins) {
                double classValue = line.classValue();
                classValueCounts[(int) classValue]++;
                for (int i = 0; i < line.numAttributes() - 1; i++) {
                    double attributeValue = line.value(i);
                    attributeMeans[(int) classValue][i] += attributeValue;
                    DataFound d = new DataFound(attributeValue, classValue, i);

                    int index = data.indexOf(d);
                    // then it doesn't exist
                    if (index == -1) {
                        data.add(d);
                    } else {
                        data.get(index).incrementCount();
                    }
                }
            }
            System.out.println("Attribute Totals: " + Arrays.deepToString(attributeMeans));
            // computes the means
            for (int j = 0; j < classValueCounts.length; j++) {
                for (int i = 0; i < ins.numAttributes() - 1; i++) {
                    attributeMeans[j][i] = attributeMeans[j][i] / classValueCounts[j];
                }
            }

            // calculate the variance
            for (int i = 0; i < data.size(); i++) {
                double cv = data.get(i).getClassValue();
                double atIn = data.get(i).getAttributeIndex();
                double squareDifference = Math
                        .pow(data.get(i).getAttributeValue() - attributeMeans[(int) cv][(int) atIn], 2);
                attributeVariance[(int) cv][(int) atIn] += squareDifference;
            }
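            // divide by (n - 1) and take the square root, so from here on
            // attributeVariance actually holds the per-class standard deviation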
            for (int j = 0; j < classValueCounts.length; j++) {
                for (int i = 0; i < ins.numAttributes() - 1; i++) {
                    attributeVariance[j][i] = attributeVariance[j][i] / (classValueCounts[j] - 1);
                    attributeVariance[j][i] = Math.sqrt(attributeVariance[j][i]);
                }
            }
            System.out.println("Attribute Means: " + Arrays.deepToString(attributeMeans));
            System.out.println("Variance: " + Arrays.deepToString(attributeVariance));
        }
    }

    /**
     * The method buildDiscreteClassifier discretises the data and then builds
     * a classifier from the discretised instances
     *
     * @param ins the training instances
     * @throws Exception if the data cannot be discretised
     */
    public void buildDiscreteClassifier(Instances ins) throws Exception {
        ins = discretize(ins);
        ins.setClassIndex(ins.numAttributes() - 1);
        countData = ins.size();
        // one count of training instances per class value
        classValueCounts = new int[ins.numClasses()];
        // store the values
        for (Instance line : ins) {
            double classValue = line.classValue();
            classValueCounts[(int) classValue]++;
            for (int i = 0; i < line.numAttributes() - 1; i++) {
                double attributeValue = line.value(i);
                DataFound d = new DataFound(attributeValue, classValue, i);
                int index = data.indexOf(d);
                // then it doesn't exist
                if (index == -1) {
                    data.add(d);
                } else {
                    data.get(index).incrementCount();
                }
            }
        }
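        // the value/class counts accumulated in DataFound are turned into
        // conditional probabilities later, in distributionForDiscrete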

    }

    /**
     * The method classifyInstance calls the appropriate distribution method
     * and returns the prediction as the class with the largest probability.
     * It also keeps a running count of correct predictions so that
     * getAccuracy can report a percentage
     *
     * @param instnc the instance to classify
     * @return the predicted class value
     * @throws Exception if the distribution cannot be computed
     */
    @Override
    public double classifyInstance(Instance instnc) throws Exception {
        testCount++;
        double[] bayesCalculations;
        double actualClassValue = instnc.classValue();
        if ("d".equals(gausianOrDiscretise)) {
            bayesCalculations = distributionForDiscrete(instnc);
        } else {
            bayesCalculations = distributionForInstance(instnc);
        }
        double largest = 0;
        double largestIndex = 0;

        for (int i = 0; i < bayesCalculations.length; i++) {
            if (bayesCalculations[i] > largest) {
                largest = bayesCalculations[i];
                largestIndex = i;
            }
        }
        if (largestIndex == actualClassValue) {
            correctCount++;
        }

        return largestIndex;
    }

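    /**
     * Discretises the numeric attributes using Weka's unsupervised Discretize
     * filter and counts how many attribute values fall into each bin
     *
     * @param instnc the instances to discretise
     * @return the discretised instances
     * @throws Exception if the filter cannot be applied
     */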
    public Instances discretize(Instances instnc) throws Exception {
        Discretize d = new Discretize();
        d.setInputFormat(instnc);
        Instances newData = Filter.useFilter(instnc, d);

        binCount = new double[d.getBins()];

        for (Instance line : newData) {
            for (int j = 0; j < newData.numAttributes() - 1; j++) {
                binCount[(int) line.value(j)]++;
            }
        }
        return newData;
    }

    /**
     *
     * The method distributionForInstance works out the probabilities of class
     * membership for a single instance using the Gaussian model. If the
     * discrete model was selected, it defers to the superclass implementation
     *
     * @param instnc the instance to evaluate
     * @return the unnormalised probability of each class
     * @throws Exception if the distribution cannot be computed
     */
    @Override
    public double[] distributionForInstance(Instance instnc) throws Exception {

        if ("d".equals(gausianOrDiscretise)) {
            return super.distributionForInstance(instnc);
        }
        // creates a double array for storing the naive calculations for each class
        double[] prediction = new double[classValueCounts.length];
        for (int c = 0; c < classValueCounts.length; c++) {
            ArrayList<Double> likelihoods = new ArrayList<>();
            double priorProbability = classValueCounts[c] / countData;
            likelihoods.add(priorProbability);
            for (int i = 0; i < instnc.numAttributes() - 1; i++) {
                double currentMean = attributeMeans[c][i];
                double currentVariance = attributeVariance[c][i];
                double attributeValue = instnc.value(i);

                // Gaussian density: attributeVariance holds the standard deviation sigma,
                // so this is (1 / (sigma * sqrt(2*pi))) * exp(-(x - mu)^2 / (2 * sigma^2))
                double likelihood = 1 / (Math.sqrt(2 * Math.PI) * currentVariance)
                        * Math.exp(-Math.pow(attributeValue - currentMean, 2) / (2 * Math.pow(currentVariance, 2)));
                likelihoods.add(likelihood);
            }
            double total = 1;
            for (Double x : likelihoods) {
                total *= x;
            }
            prediction[c] = total;
        }
        return prediction;
    }

    /**
     *
     * The method distributionForDiscrete works out the probabilities of class
     * membership for a single discretised instance.
     *
     * @param instnc the instance to evaluate
     * @return the unnormalised probability of each class
     * @throws Exception if the distribution cannot be computed
     */
    public double[] distributionForDiscrete(Instance instnc) throws Exception {

        // creates a double array for storing the naive calculations for each class
        double[] naiveBayes = new double[classValueCounts.length];

        // loops through each class and computes the naive bayes 
        for (int c = 0; c < naiveBayes.length; c++) {

            // stores all conditional probabilities for class membership such:
            // P(struct=0|crime=1), P(security=1|crime=1), P(area=1|crime=1)
            // and also it stores the prior probability: P(crime=1)
            ArrayList<Double> conditionalProbs = new ArrayList<>();
            double priorProbability = classValueCounts[c] / countData;
            conditionalProbs.add(priorProbability);
            for (int i = 0; i < instnc.numAttributes() - 1; i++) {
                double attributeValue = instnc.value(i);
                DataFound d = new DataFound(attributeValue, c, i);

                int index = data.indexOf(d);
                // value/class combinations never seen in training contribute no factor
                if (index != -1) {
                    double classValueCount = classValueCounts[(int) d.getClassValue()];
                    conditionalProbs.add(data.get(index).getConditionalProbability((int) classValueCount));
                }
            }
            // compute the naive bayes
            double total = 1;
            for (Double x : conditionalProbs) {
                total *= x;
            }
            naiveBayes[c] = total;
        }
        return naiveBayes;
    }

    /**
     *
     * @return
     */
    @Override
    public Capabilities getCapabilities() {
        throw new UnsupportedOperationException("Not supported yet."); //To change body of generated methods, choose Tools | Templates.
    }

    public void prettyPrintProbabilities(double[] x) {
        System.out.println(Arrays.toString(x));
        double total = 0;
        for (int i = 0; i < x.length; i++) {
            total += x[i];
        }

        for (int i = 0; i < x.length; i++) {
            double probability = (x[i] / total);
            System.out.println("Probability of " + i + " Membership :" + (probability * 100) + "%");
        }

    }

    // prints (rather than returns) the percentage of test instances classified correctly so far
    public void getAccuracy() {
        double percent = (correctCount / testCount) * 100;
        DecimalFormat df = new DecimalFormat("#.####");
        System.out.print(df.format(percent) + " %");
    }
}
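
Usage

The listing above depends on the companion classes BasicNaiveBayesV1 and DataFound from the same package, which are not shown here. Below is a minimal, hypothetical sketch of how the classifier might be driven, assuming Weka is on the classpath and that data/weather.numeric.arff is a placeholder path to a numeric ARFF data set whose last attribute is the class; adjust both to your own project.

package machinelearningq2;

import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class ExtendedNaiveBayesDemo {

    public static void main(String[] args) throws Exception {
        // load a numeric data set (placeholder path) and mark the last
        // attribute as the class
        DataSource source = new DataSource("data/weather.numeric.arff");
        Instances train = source.getDataSet();
        train.setClassIndex(train.numAttributes() - 1);

        // "g" selects the Gaussian model, "d" the discretised model; the
        // boolean is the Laplace flag passed through to BasicNaiveBayesV1
        ExtendedNaiveBayes nb = new ExtendedNaiveBayes(true, "g");
        nb.buildClassifier(train);

        // re-classify the training data and print the resulting accuracy
        for (Instance inst : train) {
            nb.classifyInstance(inst);
        }
        nb.getAccuracy();
    }
}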