Use k-nearest neighbors search via weka - Java Machine Learning AI

Java examples for Machine Learning AI: Weka

Description

Use k-nearest neighbors search via weka

Demo Code

import java.util.Arrays;

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class KNN {

    static Instances data;
    static int k;
    static int classIndex;
    static boolean printOn;

    /**/*from  w w  w. j a v  a  2s  .  c o  m*/
     * @param args
     * @throws Exception 
     */
    public static void main(String[] args) {
        String file = null;
        DataSource source;
        try {
            file = args[0];
            source = new DataSource(file);
            data = source.getDataSet();
        } catch (Exception e) {
            System.err.println("Cannot read from file " + file);
            return;
        }

        try {
            k = Integer.parseInt(args[1]);
        } catch (Exception e) {
            // Default 3, if k is unset or set to invalid value.
            k = 3;
        }

        try {
            // args[2] is the class name.
            classIndex = data.attribute(args[2]).index();
        } catch (Exception e) {
            // Default class index.
            classIndex = data.numAttributes() - 1;
            if (file.equals("autos.arff"))
                classIndex = data.numAttributes() - 2;
        }

        try {
            String p = args[3];
            printOn = false;
        } catch (Exception e) {
            printOn = true;
        }

        if (data.classIndex() < 0)
            data.setClassIndex(classIndex);

        normalize();

        doKNN();
    }

    /**
     * Normalize all numeric attributes to [0, 1]. 
     * Also delete the instances with missing attributes.
     */
    private static void normalize() {
        // Do normalization to each attribute.
        for (int attIndex = 0; attIndex < data.numAttributes(); attIndex++) {
            // Delete the instances with missing value of this attribute.
            data.deleteWithMissing(attIndex);

            if (data.attribute(attIndex).isNominal())
                continue;
            if (attIndex == classIndex)
                continue;

            // Normalize non-class and non-nominal attributes.
            double max = data.instance(0).value(attIndex);
            double min = max;
            // Find the max and min value of this attribute in the data set.
            for (int insIndex = 1; insIndex < data.numInstances(); insIndex++) {
                double value = data.instance(insIndex).value(attIndex);
                if (max < value)
                    max = value;
                if (min > value)
                    min = value;
            }
            //System.out.println("max="+max+",\tmin="+min+",\t"+data.attribute(attIndex).name());
            if (max == min)
                // No need to normalize if the value of this attribute is a constant.
                continue;
            // Normalize the value of this attribute in each instance to [0, 1].
            for (int insIndex = 0; insIndex < data.numInstances(); insIndex++) {
                double value = data.instance(insIndex).value(attIndex);
                double value_nm = (value - min) / (max - min);
                data.instance(insIndex).setValue(attIndex, value_nm);
            }
        }
        //System.out.println(data);
    }

    /**
     * 
     */
    private static void doKNN() {
        int testIndex;
        int numInstances = data.numInstances();
        int numClasses = data.numClasses();
        int numErrors = 0; // for nonimal prediction(classification)
        double[] errRate = new double[numInstances]; // for numeric prediction
        boolean isNominal = data.classAttribute().isNominal();
        boolean isNumeric = data.classAttribute().isNumeric();

        // Leave One Out Cross Validation.
        for (testIndex = 0; testIndex < numInstances; testIndex++) {
            if (printOn)
                System.out.printf("Instance %4d for testing.\t", testIndex);

            // Compute the distance to every instance in the data set
            // except the test instance itself. 
            int index = 0;
            double[] distanceTo = new double[numInstances];
            for (index = 0; index < numInstances; index++) {
                if (index == testIndex)
                    continue;
                distanceTo[index] = computeDistance(index, testIndex);
            }
            // Distance to myself is the largest.
            distanceTo[testIndex] = Double.MAX_VALUE;


            // Find the indexes of the k nearest neighbours.
            int[] nearestNbour = new int[k];
            double[] sortedDist = new double[numInstances];
            System.arraycopy(distanceTo, 0, sortedDist, 0, numInstances);
            Arrays.sort(sortedDist);


            for (int i = 0; i < k; i++) {
                if (i < k - 1 && sortedDist[i] == sortedDist[i + 1])
                    continue;
                for (index = 0; index < numInstances; index++) {
                    if (distanceTo[index] == sortedDist[i]) {
                        nearestNbour[i] = index;
                        if ((++i) == k)
                            break;
                    }
                }
            }

            if (isNominal) {
                // Each nearest neighbour gives a vote to its class value.
                String[] classvalue = new String[numClasses];
                int[] vote = new int[numClasses];
                for (int i = 0; i < numClasses; i++) {
                    classvalue[i] = data.classAttribute().value(i);
                    vote[i] = 0;
                }
                for (int j = 0; j < k; j++) {
                    String thisclass = data.instance(nearestNbour[j])
                            .stringValue(classIndex);
                    int i;
                    for (i = 0; i < numClasses; i++)
                        if (classvalue[i].equals(thisclass))
                            break;
                    vote[i]++;
                }
                // Find the most-voted class value as the prediction.
                int maxVote = 0;
                for (int i = 0; i < numClasses; i++) {
                    if (maxVote < vote[i])
                        maxVote = vote[i];
                }
                String prediction = "neverseethis";
                for (int i = 0; i < numClasses; i++) {
                    if (vote[i] == maxVote) {
                        prediction = classvalue[i];
                        break;
                    }
                }
                String target = data.instance(testIndex).stringValue(
                        classIndex);
                boolean correct = false;
                if (prediction.equals(target))
                    correct = true;
                else
                    numErrors++;
                if (printOn)
                    System.out.println("prediction=" + prediction
                            + ",\ttarget=" + target + ",\t" + correct);
            } 

   
            if (isNumeric) {
                double prediction, target;
                double[] nbClass = new double[k]; // class values of the nearest neighbours.
                double nbClassSum = 0;
                for (int i = 0; i < k; i++) {
                    nbClass[i] = data.instance(nearestNbour[i]).value(
                            classIndex);
                    nbClassSum += nbClass[i];
                }
                prediction = nbClassSum / k;
                target = data.instance(testIndex).value(classIndex);
                //TODO check error rate measure.
                errRate[testIndex] = Math.abs((prediction - target)
                        / target);
                if (printOn)
                    System.out
                            .printf("prediction=%.3f,\ttarget=%.3f,\terrRate=%.4f\n",
                                    prediction, target, errRate[testIndex]);
            } // end of if(isNumeric)
        } // end of for(testIndex)
          // Print LOOCV evaluation results.
        System.out.print("    LOOCV evaluation result:");
        System.out.println("algorithm:\t\t" + k + " Nearest Neighbour");
        System.out.println("relation:\t\t" + data.relationName());
        System.out.println("class attribute:\t"
                + data.classAttribute().name());
        System.out.print("class type:\t\t");
        double errorRate;
        int numTests = numInstances; // since it is LOOCV
        if (isNominal) {
            errorRate = (double) numErrors / (double) numTests;
            System.out.println("Nominal");
            System.out.println("Number of errors:\t" + numErrors
                    + "\nNumber of tests:\t" + numTests);
            System.out.println("Error Rate:\t\t" + errorRate);
        }
        if (isNumeric) {
            double errorRateSum = 0;
            for (int i = 0; i < numTests; i++)
                errorRateSum += errRate[i];
            errorRate = errorRateSum / (double) numTests;
            System.out.println("Numeric");
            System.out.println("Number of tests:\t" + numTests);
            System.out.println("Average Error Rate:\t" + errorRate);
        }
    } // end of doKNN()


    private static double computeDistance(int ins1, int ins2) {
        // Manhattan distance
        //TODO other distance?
        int numAtts = data.numAttributes();
        double distance = 0;
        for (int attIndex = 0; attIndex < numAtts; attIndex++) {
            if (attIndex == classIndex)
                continue;
            if (data.attribute(attIndex).isNominal()) {
                if (!data.instance(ins1).stringValue(attIndex)
                        .equals(data.instance(ins2).stringValue(attIndex))) {
                    // Distance between two different nominal value is 1.
                    distance += 1;
                    continue;
                }
            }
            // Else, the attributes is Numeric.
            distance += Math.abs(data.instance(ins1).value(attIndex)
                    - data.instance(ins2).value(attIndex));
        }
        return distance;
    }
}

Related Tutorials