Example usage for weka.core Instances setClassIndex

List of usage examples for weka.core Instances setClassIndex

Introduction

On this page you can find usage examples for the setClassIndex method of weka.core.Instances.

Prototype

public void setClassIndex(int classIndex) 

Source Link

Document

Sets the class index of the set.

Usage

From source file:anndl.Anndl.java

/**
 * Builds, trains and saves a neural-network model described by an ANNDL definition.
 *
 * <p>The definition read from {@code input} is parsed and visited to obtain a
 * {@code ModelClassifier}. Its settings configure a {@code MultilayerPerceptron},
 * which is trained on the data set named by {@code filetraining} and then written
 * to a file called {@code namamodel + ".model"}.
 *
 * @param input stream containing the ANNDL model definition
 * @throws Exception if any of the parsing, data-loading, training or saving steps throws
 */
private static void buildModel(InputStream input) throws Exception {
    // Lex and parse the definition, starting from the "model" rule.
    ANNDLParser definitionParser = new ANNDLParser(
            new CommonTokenStream(new ANNDLLexer(new ANTLRInputStream(input))));
    ParseTree definitionTree = definitionParser.model();

    ModelClassifier themodel = (ModelClassifier) new ModelVisitor().visit(definitionTree);
    themodel.extracthidden();

    // Console message: "Reading training file..."
    System.out.println("Membaca File Training...");
    Instances trainingData = new DataSource(themodel.filetraining).getDataSet();
    if (trainingData.classIndex() == -1) {
        // No class attribute declared by the source: fall back to the last attribute.
        int lastAttribute = trainingData.numAttributes() - 1;
        trainingData.setClassIndex(lastAttribute);
    }

    // Console message: "Configuring ANN ..."
    System.out.println("Melakukan konfigurasi ANN ... ");
    MultilayerPerceptron network = new MultilayerPerceptron();
    network.setLearningRate(themodel.learningrate);
    network.setMomentum(themodel.momentum);
    network.setTrainingTime(themodel.epoch);
    network.setHiddenLayers(themodel.hidden);

    // Console message: "Training on data ..."
    System.out.println("Melakukan Training data ...");
    network.buildClassifier(trainingData);

    String modelFile = themodel.namamodel + ".model";
    Debug.saveToFile(modelFile, network);

    // Console banner: "ANN model successfully created with file name : <file>"
    System.out.println("\n~~ .. ~~ .. ~~ .. ~~ .. ~~ .. ~~ .. ~~ .. ~~ .. ~~ ..");
    System.out.println("Model ANN Berhasil Diciptakan dengan nama file : " + modelFile);
    System.out.println("~~ .. ~~ .. ~~ .. ~~ .. ~~ .. ~~ .. ~~ .. ~~ .. ~~ .. \n");
}

From source file:ANN_Single.SinglelayerPerceptron.java

/**
 * Trains a {@code SinglelayerPerceptron} on the normalized diabetes data set and prints
 * the evaluation summary and confusion matrix obtained on that same training data.
 *
 * @param args command line arguments (not read)
 * @throws Exception if loading, filtering, training or evaluation throws
 */
public static void main(String[] args) throws Exception {
    Instances rawData = new ConverterUtils.DataSource("D:\\Program Files\\Weka-3-8\\data\\diabetes.arff")
            .getDataSet();

    // NOTE(review): the filter runs before the class index is set, so the class attribute
    // is not excluded from normalization by the filter itself - confirm this is intended.
    Normalize normalizer = new Normalize();
    normalizer.setInputFormat(rawData);
    Instances train = Filter.useFilter(rawData, normalizer);

    // The last attribute is used as the class.
    int lastAttribute = train.numAttributes() - 1;
    train.setClassIndex(lastAttribute);
    System.out.println();

    SinglelayerPerceptron perceptron = new SinglelayerPerceptron(train, 0.1, 5000);
    perceptron.buildClassifier(train);

    // Evaluation is performed on the training data itself (no hold-out set).
    Evaluation evaluation = new Evaluation(train);
    evaluation.evaluateModel(perceptron, train);
    System.out.println(evaluation.toSummaryString());
    System.out.print(evaluation.toMatrixString());
}

From source file:ANN_single2.MultilayerPerceptron.java

/**
 * Runs a 10-fold cross-validation of a {@code MultilayerPerceptron} on the normalized
 * Team data set and prints the evaluation summary and confusion matrix.
 *
 * @param args command line arguments (not read)
 * @throws Exception if loading, filtering or evaluation throws
 */
public static void main(String[] args) throws Exception {
    Instances rawData = new ConverterUtils.DataSource("D:\\Program Files\\Weka-3-8\\data\\Team.arff")
            .getDataSet();

    // NOTE(review): the filter runs before the class index is set, so the class attribute
    // is not excluded from normalization by the filter itself - confirm this is intended.
    Normalize normalizer = new Normalize();
    normalizer.setInputFormat(rawData);
    Instances train = Filter.useFilter(rawData, normalizer);

    // The last attribute is used as the class.
    int lastAttribute = train.numAttributes() - 1;
    train.setClassIndex(lastAttribute);

    // The classifier is not built here; it is handed to the cross-validation below.
    MultilayerPerceptron network = new MultilayerPerceptron(train, 13, 0.1, 0.5);
    Evaluation evaluation = new Evaluation(train);
    evaluation.crossValidateModel(network, train, 10, new Random(1));

    System.out.println(evaluation.toSummaryString());
    System.out.println(evaluation.toMatrixString());
}

From source file:ANN_single2.SinglelayerPerceptron.java

/**
 * Sweeps two constructor arguments of {@code SinglelayerPerceptron} over a grid and, for
 * every combination, trains on the normalized Team data set and prints the evaluation
 * summary and confusion matrix obtained on that same training data.
 *
 * @param args command line arguments (not read)
 * @throws Exception if loading, filtering, training or evaluation throws
 */
public static void main(String[] args) throws Exception {
    Instances rawData = new ConverterUtils.DataSource("D:\\Program Files\\Weka-3-8\\data\\Team.arff")
            .getDataSet();

    // NOTE(review): the filter runs before the class index is set, so the class attribute
    // is not excluded from normalization by the filter itself - confirm this is intended.
    Normalize normalizer = new Normalize();
    normalizer.setInputFormat(rawData);
    Instances train = Filter.useFilter(rawData, normalizer);

    // The last attribute is used as the class.
    int lastAttribute = train.numAttributes() - 1;
    train.setClassIndex(lastAttribute);

    // Grid: first argument 100, 200, ... below 3000; second argument stepped by 0.01 below 1.
    // The second value is accumulated in floating point, so printed values carry rounding noise.
    for (int intSetting = 100; intSetting < 3000; intSetting += 100) {
        for (double realSetting = 0.01; realSetting < 1; realSetting += 0.01) {
            System.out.println(intSetting + " " + realSetting);

            SinglelayerPerceptron perceptron = new SinglelayerPerceptron(intSetting, realSetting, 0.0);
            perceptron.buildClassifier(train);

            // Evaluation is performed on the training data itself (no hold-out set).
            Evaluation evaluation = new Evaluation(train);
            evaluation.evaluateModel(perceptron, train);
            System.out.println(evaluation.toSummaryString());
            System.out.println(evaluation.toMatrixString());
        }
    }
}

From source file:ap.mavenproject1.HelloWeka.java

/**
 * Loads an ARFF file, converts its numeric attributes to nominal ones and prints the
 * association rules found by Apriori.
 *
 * <p>If the file cannot be loaded the failure is logged and the method returns without
 * attempting to mine rules.
 *
 * @param args command line arguments (not read)
 */
public static void main(String args[]) {
    Instances data;
    ArffLoader loader = new ArffLoader();
    try {
        loader.setFile(new File("C:\\Users\\USER\\Desktop\\data.arff"));

        data = loader.getDataSet();
        data.setClassIndex(data.numAttributes() - 1);
    } catch (IOException ex) {
        Logger.getLogger(HelloWeka.class.getName()).log(Level.SEVERE, null, ex);
        // Nothing to mine without data; carrying on would only add a misleading
        // NullPointerException to the log.
        return;
    }

    Apriori apriori = new Apriori();
    try {
        // The data is converted to nominal attributes before rule mining.
        NumericToNominal numericToNominal = new NumericToNominal();
        numericToNominal.setInputFormat(data);
        Instances nominalData = Filter.useFilter(data, numericToNominal);

        apriori.buildAssociations(nominalData);
        FastVector[] allTheRules = apriori.getAllTheRules();
        for (FastVector ruleComponent : allTheRules) {
            System.out.println(ruleComponent);
        }
    } catch (Exception ex) {
        Logger.getLogger(HelloWeka.class.getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:asap.CrossValidation.java

/**
 * Runs a single-threaded k-fold cross-validation and trains a final model on the full data set.
 *
 * <p>The data set is loaded, its class attribute selected, the attributes listed in
 * {@code removeIndices} removed, and the instances shuffled (and stratified when the class
 * attribute is nominal). A copy of {@code cls} is trained and evaluated per fold;
 * {@code cls} itself is then trained on the whole filtered data set and optionally serialized.
 *
 * @param dataInput location of the data set, passed to {@code DataSource.read}
 * @param classIndex {@code "first"}, {@code "last"}, a 1-based attribute number, or an attribute name
 * @param removeIndices attribute indices handed to the {@code Remove} filter
 * @param cls classifier to evaluate; it is trained on the full data set as a side effect
 * @param seed seed for the random shuffle
 * @param folds number of cross-validation folds
 * @param modelOutputFile file the trained {@code cls} is written to; skipped when {@code null} or empty
 * @return a textual report containing the setup and the cross-validation summary
 * @throws IllegalArgumentException if {@code classIndex} names an attribute that does not exist
 * @throws Exception if loading, filtering, training, evaluation or serialization throws
 */
public static String performCrossValidation(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation ST");

    PerformanceCounters.startTimer("cross-validation init ST");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);

    switch (classIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            // Numeric specifications are 1-based, Weka indices are 0-based.
            data.setClassIndex(Integer.parseInt(classIndex) - 1);
        } catch (NumberFormatException e) {
            // Not a number: treat the specification as an attribute name.
            if (data.attribute(classIndex) == null) {
                throw new IllegalArgumentException("Unknown class attribute '" + classIndex
                        + "' in data set " + data.relationName());
            }
            data.setClassIndex(data.attribute(classIndex).index());
        }
        break;
    }

    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data (on a copy, so the final model below sees the unshuffled set)
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // prepare the per-fold train/test splits and classifier copies up front so that the
    // "init" timer covers all set-up work
    Evaluation eval = new Evaluation(randData);
    Instances[] trainSets = new Instances[folds];
    Instances[] testSets = new Instances[folds];
    Classifier[] foldCls = new Classifier[folds];

    for (int n = 0; n < folds; n++) {
        trainSets[n] = randData.trainCV(folds, n);
        testSets[n] = randData.testCV(folds, n);
        foldCls[n] = AbstractClassifier.makeCopy(cls);
    }

    PerformanceCounters.stopTimer("cross-validation init ST");
    PerformanceCounters.startTimer("cross-validation folds+train ST");

    // build and evaluate one classifier copy per fold
    for (int n = 0; n < folds; n++) {
        Classifier clsCopy = foldCls[n];
        clsCopy.buildClassifier(trainSets[n]);
        eval.evaluateModel(clsCopy, testSets[n]);
    }

    // final model: the caller's classifier trained on the complete filtered data set
    cls.buildClassifier(data);

    PerformanceCounters.stopTimer("cross-validation folds+train ST");
    PerformanceCounters.startTimer("cross-validation post ST");

    // output evaluation
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (modelOutputFile != null && !modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post ST");
    PerformanceCounters.stopTimer("cross-validation ST");

    return out;
}

From source file:asap.CrossValidation.java

/**
 *
 * @param dataInput/* w w  w.  j  av a2  s. c o m*/
 * @param classIndex
 * @param removeIndices
 * @param cls
 * @param seed
 * @param folds
 * @param modelOutputFile
 * @return
 * @throws Exception
 */
public static String performCrossValidationMT(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation MT");

    PerformanceCounters.startTimer("cross-validation init MT");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;

    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }

    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval = new Evaluation(randData);
    List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                AbstractClassifier.makeCopy(cls)));

        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init MT");
    PerformanceCounters.startTimer("cross-validation folds+train MT");
    //paralelize!!:--------------------------------------------------------------
    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        //use the current thread to run the cross-validation instead of using the Thread instance created here:
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    cls.buildClassifier(data);

    for (Thread foldThread : foldThreads) {
        foldThread.join();
    }

    //until here!-----------------------------------------------------------------
    PerformanceCounters.stopTimer("cross-validation folds+train MT");
    PerformanceCounters.startTimer("cross-validation post MT");
    // evaluation for output:
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (!modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post MT");
    PerformanceCounters.stopTimer("cross-validation MT");
    return out;
}

From source file:assign00.ExperimentShell.java

/**
 * Trains a {@code NeuralNetworkClassifier} on a shuffled 70/30 split of the data set and
 * prints the evaluation summary for the held-out part.
 *
 * <p>Both parts are standardized with one filter whose statistics come from the
 * training part.
 *
 * @param args the command line arguments (not read)
 * @throws Exception if loading, filtering, training or evaluation throws
 */
public static void main(String[] args) throws Exception {
    DataSource source = new DataSource(file);
    Instances dataSet = source.getDataSet();

    // Set up data: last attribute is the class, fixed seed keeps the shuffle repeatable.
    dataSet.setClassIndex(dataSet.numAttributes() - 1);
    dataSet.randomize(new Random(1));

    // Determine sizes: 70% training, remainder test.
    int trainingSize = (int) Math.round(dataSet.numInstances() * .7);
    int testSize = dataSet.numInstances() - trainingSize;

    Instances training = new Instances(dataSet, 0, trainingSize);
    Instances test = new Instances(dataSet, trainingSize, testSize);

    Standardize standardize = new Standardize();
    standardize.setInputFormat(training);

    // The first batch pushed through the filter determines its mean/stddev, so the
    // training set must go first; the test set is then scaled with those statistics.
    Instances newTraining = Filter.useFilter(training, standardize);
    Instances newTest = Filter.useFilter(test, standardize);

    NeuralNetworkClassifier network = new NeuralNetworkClassifier();
    network.buildClassifier(newTraining);

    Evaluation eval = new Evaluation(newTraining);
    eval.evaluateModel(network, newTest);

    System.out.println(eval.toSummaryString("\nResults\n======\n", false));
}

From source file:at.aictopic1.sentimentanalysis.machinelearning.impl.TwitterClassifer.java

/**
 * Loads the labelled training set, stores and serializes its structure, and converts the
 * string attribute into word features.
 *
 * @return the data set produced by the {@code StringToWordVector} filter, or {@code null}
 *         if loading or filtering failed
 */
public Instances loadTrainingData() {
    try {
        DataSource labelledSource = new DataSource(
                "C:\\Users\\David\\Documents\\Datalogi\\TU Wien\\2014W_Advanced Internet Computing\\Labs\\Data sets\\labelled.arff");
        Instances labelled = labelledSource.getDataSet();

        // Remember the header of the raw data set and persist it next to the model.
        dataStructure = labelledSource.getStructure();
        try {
            weka.core.SerializationHelper.write(modelDir + algorithm + ".dataStruct", dataStructure);
        } catch (Exception ex) {
            // A failed write is logged only; loading continues.
            Logger.getLogger(TwitterClassifer.class.getName()).log(Level.SEVERE, null, ex);
        }

        // The third attribute is the class.
        labelled.setClassIndex(2);

        // Give the attributes unique names before the strings are converted.
        labelled.renameAttribute(2, "class_attr");
        labelled.renameAttribute(0, "twitter_id");

        // Turn the string attribute into word features.
        StringToWordVector wordVectorFilter = new StringToWordVector();
        wordVectorFilter.setInputFormat(labelled);
        Instances wordFeatures = Filter.useFilter(labelled, wordVectorFilter);

        // Print the first three attributes of the filtered data.
        for (int attributeIndex = 0; attributeIndex < 3; attributeIndex++) {
            System.out.println("filteredData struct: " + wordFeatures.attribute(attributeIndex));
        }

        return wordFeatures;
    } catch (Exception ex) {
        System.out.println("Error loading training set: " + ex);
        return null;
    }
}

From source file:at.aictopic1.sentimentanalysis.machinelearning.impl.TwitterClassifer.java

/**
 * Classifies tweets with {@code trainedModel} and prints each prediction to standard output.
 *
 * <p>NOTE(review): the {@code tweets} argument is currently overwritten with four
 * hard-coded examples (see the TEST section), and the method always returns {@code 0};
 * the predictions are only printed. This looks like work in progress - confirm before
 * relying on it.
 *
 * @param tweets tweets to classify; replaced by hard-coded test data inside the method
 * @return always {@code 0}
 */
public Integer classify(Tweet[] tweets) {
    // TEST: the caller's tweets are replaced by the fixed examples built below.

    // Generate the tweet examples
    Tweet exOne = new Tweet("This is good and fantastic");
    exOne.setPreprocessedText("This is good and fantastic");
    Tweet exTwo = new Tweet("Horribly, terribly bad and more");
    exTwo.setPreprocessedText("Horribly, terribly bad and more");
    Tweet exThree = new Tweet(
            "I want to update lj and read my friends list, but I\\'m groggy and sick and blargh.");
    exThree.setPreprocessedText(
            "I want to update lj and read my friends list, but I\\'m groggy and sick and blargh.");
    Tweet exFour = new Tweet("bad hate worst sick");
    exFour.setPreprocessedText("bad hate worst sick");
    tweets = new Tweet[] { exOne, exTwo, exThree, exFour };
    // TEST (end)

    // Load model
    //        loadModel();
    // Convert Tweet to Instance type
    // Get String Data
    // Create attributes for the Instances set
    // NOTE(review): twitter_id, classVal and class_attr are not used further down in this
    // method; the lines that added them to attrVector are commented out.
    Attribute twitter_id = new Attribute("twitter_id");
    //        Attribute body = new Attribute("body");

    FastVector classVal = new FastVector(2);
    classVal.addElement("pos");
    classVal.addElement("neg");

    Attribute class_attr = new Attribute("class_attr", classVal);

    // Add them to a list
    FastVector attrVector = new FastVector(3);
    //        attrVector.addElement(twitter_id);
    //        attrVector.addElement(new Attribute("body", (FastVector) null));
    //        attrVector.addElement(class_attr);

    // Get the number of tweets and then create predictSet
    int numTweets = tweets.length;
    // The attributes of the prediction set are taken from the stored dataStructure.
    // NOTE(review): dataStructure must have been set beforehand (loadTrainingData assigns
    // it); otherwise this line throws a NullPointerException.
    Enumeration structAttrs = dataStructure.enumerateAttributes();

    //        ArrayList<Attribute> attrList = new ArrayList<Attribute>(dataStructure.numAttributes());
    while (structAttrs.hasMoreElements()) {
        attrVector.addElement((Attribute) structAttrs.nextElement());
    }
    Instances predictSet = new Instances("predictInstances", attrVector, numTweets);
    //        Instances predictSet = new Instances(dataStructure);
    // The third attribute is the class, matching loadTrainingData.
    predictSet.setClassIndex(2);

    // init prediction
    double prediction = -1;

    System.out.println("PredictSet matches source structure: " + predictSet.equalHeaders(dataStructure));

    System.out.println("PredSet struct: " + predictSet.attribute(0));
    System.out.println("PredSet struct: " + predictSet.attribute(1));
    System.out.println("PredSet struct: " + predictSet.attribute(2));
    // Array to return predictions
    //double[] tweetsClassified = new double[2][numTweets];
    //List<Integer, Double> tweetsClass = new ArrayList<Integer, Double>(numTweets);
    for (int i = 0; i < numTweets; i++) {
        String content = (String) tweets[i].getPreprocessedText();

        System.out.println("Tweet content: " + content);

        // Build one instance per tweet: attribute 0 receives the loop index, attribute 1
        // the preprocessed text, and the class value is left missing.
        Instance tweetInstance = new Instance(predictSet.numAttributes());

        tweetInstance.setDataset(predictSet);

        tweetInstance.setValue(predictSet.attribute(0), i);
        tweetInstance.setValue(predictSet.attribute(1), content);
        tweetInstance.setClassMissing();

        predictSet.add(tweetInstance);

        try {
            // Apply string filter
            // NOTE(review): a fresh StringToWordVector is created on every iteration and
            // applied to all tweets added so far, so its output presumably does not match
            // the attributes the model was trained on - confirm.
            StringToWordVector filter = new StringToWordVector();

            filter.setInputFormat(predictSet);
            Instances filteredPredictSet = Filter.useFilter(predictSet, filter);

            // Apply model to the instance that was just added
            prediction = trainedModel.classifyInstance(filteredPredictSet.instance(i));
            filteredPredictSet.instance(i).setClassValue(prediction);
            System.out.println("Classification: " + filteredPredictSet.instance(i).toString());
            System.out.println("Prediction: " + prediction);

        } catch (Exception ex) {
            // Failures are logged and the loop continues with the next tweet.
            Logger.getLogger(TwitterClassifer.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    // The predictions are not returned; see the review note in the method documentation.
    return 0;
}