Cluster Bagging J48 weka - Java Machine Learning AI

Java examples for Machine Learning AI:weka

Description

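Reduce the majority class of a training set to k-means cluster centroids, then train and evaluate a Bagging ensemble of J48 decision trees with Weka.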

Demo Code

import java.io.FileWriter;
import java.util.ArrayList;

import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.Prediction;
import weka.classifiers.meta.Bagging;
import weka.classifiers.trees.J48;
import weka.clusterers.SimpleKMeans;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Add;
import weka.filters.unsupervised.instance.RemoveFrequentValues;
import au.com.bytecode.opencsv.CSVWriter;

public class Main {

    public static void main(String[] args) throws Exception {
        /*from   www.j a  va2s  .c o  m*/
        Instances train = DataSource
        read("./train1.arff");
        int cid1 = train.numAttributes() - 1;
        train.setClassIndex(cid1);


        Instances validation = DataSource
        read("./validation1.arff");
        int cid2 = validation.numAttributes() - 1;
        validation.setClassIndex(cid2);


        Instances test = DataSource
        read("./test1.arff");
        int cid3 = test.numAttributes() - 1;
        test.setClassIndex(cid3);

        //Remove fraud class instances
        RemoveFrequentValues remove = new RemoveFrequentValues();
        remove.setInputFormat(train);
        remove.setAttributeIndex("last");
        remove.setNumValues(1);
        Instances train_ok = Filter.useFilter(train, remove);
        int cid4 = train_ok.numAttributes() - 1;
        train_ok.setClassIndex(cid4);

        //Remove ok class instances
        RemoveFrequentValues remove1 = new RemoveFrequentValues();
        remove1.setInputFormat(train);
        remove1.setAttributeIndex("last");
        remove1.setNumValues(1);
        remove1.setUseLeastValues(true);
        Instances train_fraud = Filter.useFilter(train, remove1);
        int cid5 = train_fraud.numAttributes() - 1;
        train_fraud.setClassIndex(cid5);

        //remove class attribute for clustering
        weka.filters.unsupervised.attribute.Remove filter = new weka.filters.unsupervised.attribute.Remove();
        filter.setAttributeIndices("" + (train_ok.classIndex() + 1));
        filter.setInputFormat(train_ok);
        Instances dataClusterer = Filter.useFilter(train_ok, filter);

        //cluster using K-means
        SimpleKMeans cluster = new SimpleKMeans();
        cluster.setNumClusters(146);
        cluster.buildClusterer(dataClusterer);
        train_ok = cluster.getClusterCentroids();

        //Add deleted class attribute
        Add add_attribute = new Add();
        add_attribute.setAttributeName("status");
        add_attribute.setAttributeIndex("last");
        add_attribute.setNominalLabels("0,1");
        //SelectedTag value=
        //add_attribute.setAttributeType(value);
        add_attribute.setInputFormat(train_ok);
        train_ok = Filter.useFilter(train_ok, add_attribute);
        for (int i = 0; i < train_ok.numInstances(); i++) {
            train_ok.instance(i)
            setValue(train_ok.numAttributes() - 1, "0");
        }
        int cid7 = train_ok.numAttributes() - 1;
        train_ok.setClassIndex(cid7);

        //combine train_ok and train_fraud
        for (int i = 0; i < train_fraud.numInstances(); i++)
            train_ok.add(train_fraud.instance(i));
        train = train_ok;
        int cid6 = train.numAttributes() - 1;
        train.setClassIndex(cid6);

        //Bagging J48
        J48 jtree = new J48();

        Bagging tree = new Bagging();
        tree.setClassifier(jtree);
        tree.buildClassifier(train);

        Evaluation eval = new Evaluation(train);
        eval.evaluateModel(tree, validation);
        System.out.println(eval.toSummaryString("\nResults_RF\n\n", false));
        System.out.println(eval.toClassDetailsString());
        System.out.println(eval.toMatrixString());

        ArrayList<Prediction> al = eval.predictions();
        ArrayList<String[]> as = new ArrayList<String[]>(al.size());
        for (int i = 0; i < al.size(); i++) {
            String[] s = new String[1];
            s[0] = al.get(i).toString();
            s[0] = s[0].substring(9, 11);
            as.add(s);
        }
        ArrayList<String[]> li = new ArrayList<String[]>(al.size());
        li.addAll(as);

        String csv = "./output.csv";
        CSVWriter writer = new CSVWriter(new FileWriter(csv));

        writer.writeAll(li);
        writer.close();
    }

}

Related Tutorials