Example usage for weka.core Instances numInstances

List of usage examples for weka.core Instances numInstances

Introduction

In this page you can find the example usage for weka.core Instances numInstances.

Prototype


public int numInstances()

Source Link

Document

Returns the number of instances in the dataset.

Usage

From source file:ann.MyANN.java

/**
 * Evaluates the trained model on testSet and returns a confusion matrix.
 * buildClassifier must have been called beforehand.
 *
 * @param testSet instances used to test the model
 * @return for a nominal class: an NxN matrix (N = number of class values)
 *         indexed [actual][predicted]; for a numeric class: a 1x2 matrix
 *         where element [0][0] is the number of correct predictions (within
 *         an absolute tolerance of 0.001) and [0][1] the number of wrong ones
 */
public int[][] evaluate(Instances testSet) {
    int[][] confusionMatrix;
    if (testSet.classAttribute().isNominal()) {
        int numClasses = testSet.classAttribute().numValues();
        confusionMatrix = new int[numClasses][numClasses];
    } else {
        confusionMatrix = new int[1][2];
    }

    for (int i = 0; i < testSet.numInstances(); i++) {
        try {
            double[] prob = distributionForInstance(testSet.instance(i));
            if (testSet.classAttribute().isNominal()) {
                // Row = actual class index, column = predicted class index.
                int idx = predictClassIndex(prob);
                confusionMatrix[(int) testSet.instance(i).classValue()][idx]++;
            } else {
                // Numeric class: a prediction counts as correct when it is
                // within 0.001 of the target value.
                if (Math.abs(prob[0] - testSet.instance(i).classValue()) <= 0.001)
                    confusionMatrix[0][0]++;
                else
                    confusionMatrix[0][1]++;
            }
        } catch (Exception ex) {
            Logger.getLogger(MyANN.class.getName()).log(Level.SEVERE, null, ex);
        }
    }
    return confusionMatrix;
}

From source file:ann.MyANN.java

/**
 * Converts every instance of the given dataset into the internal
 * array-of-data representation and stores the result in the datas field.
 *
 * @param instances dataset to convert
 */
private void instancesToDatas(Instances instances) {
    datas = new ArrayList<>();
    int count = instances.numInstances();
    for (int idx = 0; idx < count; idx++) {
        datas.add(instanceToData(instances.instance(idx)));
    }
}

From source file:ann.SingleLayerPerceptron.java

/**
 * Trains the single-layer perceptron on the given data for at most
 * annOptions.maxIteration epochs, stopping early once the squared error of
 * an epoch drops to annOptions.threshold or below.
 *
 * Two weight-update modes (annOptions.topologyOpt):
 *   2         = batch: gradient contributions are accumulated over the whole
 *               epoch and applied once, after the last instance;
 *   otherwise = incremental (online): weights are updated after every
 *               instance via the delta rule.
 *
 * The last attribute index is treated as the bias input (fixed to 1).
 *
 * @param data training instances; the class value is the target output
 */
public void doPerceptron(Instances data) {
    for (int epoch = 0; epoch < annOptions.maxIteration; epoch++) {
        double deltaWeight = 0.0;
        // Batch-mode gradient accumulator, one row per output node so the
        // contributions of different output nodes do not overwrite each
        // other (the previous single array was shared across nodes).
        double[][] deltaWeightUpdate = new double[output.size()][data.numAttributes()];

        for (int i = 0; i < data.numInstances(); i++) {
            for (int j = 0; j < output.size(); j++) {
                // weighted sum of inputs: sum(xi * wi); last attribute = bias
                double sum = 0;
                double weight, input;
                for (int k = 0; k < data.numAttributes(); k++) {
                    if (k == data.numAttributes() - 1) { // bias
                        input = 1;
                    } else {
                        input = data.instance(i).value(k);
                    }
                    weight = output.get(j).weights.get(k);
                    sum += weight * input;
                }

                // update the input weights of output node j
                for (int k = 0; k < data.numAttributes(); k++) {
                    if (k == data.numAttributes() - 1) { // bias
                        input = 1;
                    } else {
                        input = data.instance(i).value(k);
                    }

                    // pass the weighted sum through the activation function
                    double newOutput = Util.activationFunction(sum, annOptions);

                    // Target: one-vs-all encoding when there are several
                    // output nodes, the raw class value otherwise.
                    double target;
                    if (output.size() > 1) {
                        if (data.instance(i).classValue() == j) {
                            target = 1;
                        } else {
                            target = 0;
                        }
                    } else {
                        target = data.instance(i).classValue();
                    }
                    weight = output.get(j).weights.get(k);

                    // delta rule: delta w = learningRate * (T - O) * xi
                    if (annOptions.topologyOpt == 2) { // batch
                        deltaWeightUpdate[j][k] += (target - newOutput) * input;
                        if (i == data.numInstances() - 1) { // apply once per epoch
                            // w <- w + learningRate * sum((T - O) * xi).
                            // (The previous code computed
                            // learningRate * (w + delta), which scales the
                            // weight itself by the learning rate.)
                            output.get(j).weights.set(k,
                                    weight + annOptions.learningRate * deltaWeightUpdate[j][k]);
                        }
                    } else { // incremental
                        deltaWeight = annOptions.learningRate * (target - newOutput) * input;
                        output.get(j).weights.set(k, weight + deltaWeight);
                    }
                }
            }
        }

        // epoch error: 0.5 * sum of squared (target - output) over all
        // instances and output nodes
        double errorEpoch = 0;
        for (int i = 0; i < data.numInstances(); i++) {
            for (int j = 0; j < output.size(); j++) {
                // the weighted sum is reset per output node (previously it
                // was reset only per instance and leaked across nodes)
                double sum = 0;
                for (int k = 0; k < data.numAttributes(); k++) {
                    double input;
                    if (k == data.numAttributes() - 1) { // bias
                        input = 1;
                    } else {
                        input = data.instance(i).value(k);
                    }
                    double weight = output.get(j).weights.get(k);
                    sum += weight * input;
                }
                // pass through the activation function
                sum = Util.activationFunction(sum, annOptions);
                double target;
                if (output.size() > 1) {
                    if (data.instance(i).classValue() == j) {
                        target = 1;
                    } else {
                        target = 0;
                    }
                } else {
                    target = data.instance(i).classValue();
                }
                double error = target - sum;
                errorEpoch += error * error;
            }
        }
        errorEpoch *= 0.5;
        // convergent: stop when the epoch error reaches the threshold
        if (errorEpoch <= annOptions.threshold) {
            break;
        }
    }
}

From source file:ann.SingleLayerPerceptron.java

/**
 * Classifies every instance in data with the trained perceptron, printing
 * the per-instance target/output pair and the overall accuracy.
 *
 * @param data instances to classify; instances with a missing class are
 *             removed and the nominal-to-binary filter is applied first
 * @return predicted class index per (filtered) instance; note the array is
 *         allocated with the pre-filter instance count, so trailing entries
 *         stay 0 when instances were removed
 * @throws Exception if the nominal-to-binary filter fails
 */
public int[] classifyInstances(Instances data) throws Exception {
    int[] classValue = new int[data.numInstances()];
    // remove instances with missing class (work on a copy of the dataset)
    data = new Instances(data);
    data.deleteWithMissingClass();

    // nominal to binary filter
    ntb.setInputFormat(data);
    data = new Instances(Filter.useFilter(data, ntb));
    int right = 0;

    for (int i = 0; i < data.numInstances(); i++) {
        int outputSize = output.size();
        double[] result = new double[outputSize];
        for (int j = 0; j < outputSize; j++) {
            result[j] = 0.0;
            for (int k = 0; k < data.numAttributes(); k++) {
                double input = 1; // last attribute acts as the bias input
                if (k < data.numAttributes() - 1) {
                    input = data.instance(i).value(k);
                }
                result[j] += output.get(j).weights.get(k) * input;
            }
            result[j] = Util.activationFunction(result[j], annOptions);
        }

        if (outputSize >= 2) {
            // multi-class: pick the output node with the highest activation
            for (int j = 0; j < outputSize; j++) {
                if (result[j] > result[classValue[i]]) {
                    classValue[i] = j;
                }
            }
        } else {
            classValue[i] = (int) result[0];
        }
        double target = data.instance(i).classValue();
        // renamed from "output": the old local shadowed the output-node
        // list field of the same name used above in this loop
        double predicted = classValue[i];
        System.out.println("Intance-" + i + " target: " + target + " output: " + predicted);
        if (target == predicted) {
            right = right + 1;
        }
    }

    System.out.println("Percentage: " + ((double) right / (double) data.numInstances()));

    return classValue;
}

From source file:ANN_Single.SinglelayerPerceptron.java

public SinglelayerPerceptron(Instances i, double rate, int itter) {
    learningRate = rate;/*from  ww  w  .j  av a 2  s.com*/
    //        listOutput = new ArrayList<>();
    //        for (int num =0; num<i.numClasses(); num++) {
    //            listOutput.add(new Node(i.numAttributes()));
    //        }
    itteration = itter;
    listDoubleinstance = new double[i.numInstances()];
    for (int numIns = 0; numIns < i.numInstances(); numIns++) {
        listDoubleinstance[numIns] = i.instance(numIns).toDoubleArray()[i.classIndex()];
    }
}

From source file:ANN_Single.SinglelayerPerceptron.java

/**
 * Trains the single-layer perceptron: creates one output Node per class,
 * then repeats full passes over the data (output computation, error
 * calculation, weight update) until the summed squared error reaches 0,
 * finally printing per-fold diagnostics.
 *
 * NOTE(review): the loop exits only when error <= 0, so training may never
 * terminate on data that cannot be fit exactly — confirm this is intended.
 * NOTE(review): the error sums re-read the node errors left over from the
 * LAST processed instance, once per instance, rather than accumulating
 * per-instance errors — verify against calculateError's contract.
 */
@Override
public void buildClassifier(Instances i) {
    listOutput = new ArrayList<>();
    for (int num = 0; num < i.numClasses(); num++) {
        listOutput.add(new Node(i.numAttributes()));
    }
    while (true) { // repeat passes until converged
        for (int idxInstance = 0; idxInstance < i.numInstances(); idxInstance++) {
            // build the input list: bias (1.0) first, then attribute values
            ArrayList<Double> listInput = new ArrayList<>();
            listInput.add(1.0);
            for (int idx = 0; idx < i.numAttributes() - 1; idx++) {
                listInput.add(i.get(idxInstance).value(idx));
            }

            // compute the output layer
            for (int idxOutput = 0; idxOutput < listOutput.size(); idxOutput++) {
                output(listInput, idxOutput);
            }
            // compute the error
            calculateError(idxInstance);
            // update the weights
            updateWeight(listInput);
        }
        // summed squared error used as the stopping criterion
        double error = 0;
        for (int idxErr = 0; idxErr < i.numInstances(); idxErr++) {
            for (int idx = 0; idx < listOutput.size(); idx++) {
                error += Math.pow(listOutput.get(idx).getError(), 2) / 2;
            }
        }
        System.out.println(error);
        System.out.println();
        if (error <= 0)
            break;
    }
    fold++;
    System.out.println("Fold ke-" + fold);
    // recompute and print the final error for this fold
    double error = 0;
    for (int idxErr = 0; idxErr < i.numInstances(); idxErr++) {
        for (Node listOutput1 : listOutput) {
            error += Math.pow(listOutput1.getError(), 2) / 2;
        }
    }
    System.out.println("error " + error);
    // dump the value, error and weights of every output node
    for (int idx = 0; idx < listOutput.size(); idx++) {
        System.out.println("Output value " + listOutput.get(idx).getValue());
        System.out.println("Output error " + listOutput.get(idx).getError());
        for (int idx2 = 0; idx2 < listOutput.get(idx).getWeightSize(); idx2++)
            System.out.println("Output weight" + listOutput.get(idx).getWeightFromList(idx2));
    }
}

From source file:ANN_single2.MultilayerPerceptron.java

/**
 * Trains the multilayer perceptron for a fixed 10000 passes over the data:
 * per instance, feed-forward through the hidden then the output layer,
 * compute the error and update the weights.
 *
 * NOTE(review): the iteration count is hard-coded and the threshold-based
 * early stop is disabled (commented out); the final error printed re-reads
 * the node errors left from the last processed instance, once per instance
 * — confirm this is intended.
 */
@Override
public void buildClassifier(Instances i) {

    // convert the class of every instance to a numeric value (its index)
    listDoubleinstance = new double[i.numInstances()];
    for (int numIns = 0; numIns < i.numInstances(); numIns++) {
        listDoubleinstance[numIns] = i.instance(numIns).toDoubleArray()[i.classIndex()];
    }
    int cnt = 0;
    for (int itt = 0; itt < 10000; itt++) {
        for (int idxInstance = 0; idxInstance < i.numInstances(); idxInstance++) {
            // build the input list: bias (1.0) first, then attribute values
            ArrayList<Double> listInput = new ArrayList<>();
            listInput.add(1.0); // bias
            for (int ins = 0; ins < i.get(idxInstance).numAttributes() - 1; ins++) {
                listInput.add(i.get(idxInstance).value(ins));
            }

            // hidden-layer activations, prefixed with the bias for the
            // hidden -> output pass
            ArrayList<Double> listHide = new ArrayList<>();
            listHide.add(1.0);
            // compute the hidden-layer outputs (index 0 is the bias node)
            for (int idxHidden = 1; idxHidden < listHidden.size(); idxHidden++) {
                output(listHidden, listInput, idxHidden);
                listHide.add(listHidden.get(idxHidden).getValue());
            }

            // compute the output-layer outputs
            for (int idxOutput = 0; idxOutput < listOutput.size(); idxOutput++) {
                output(listOutput, listHide, idxOutput);
            }

            // compute the error
            calculateError(idxInstance);
            // update the weights
            updateBobot(listInput);
        }
    }
    // print the final error for this fold
    double error = 0;
    fold++;
    for (int idx = 0; idx < i.numInstances(); idx++) {
        for (int idxOut = 0; idxOut < listOutput.size(); idxOut++) {
            error += Math.pow(listOutput.get(idxOut).getError(), 2) / 2;
        }
    }
    System.out.println("Fold " + fold);
    System.out.println("error " + error);

}

From source file:ANN_single2.SinglelayerPerceptron.java

/**
 * Trains the single-layer perceptron: creates one output Node per class,
 * then runs up to itteration passes over the data (feed-forward, error
 * calculation, weight update), stopping early once the summed squared
 * error of the output nodes drops to the threshold.
 *
 * NOTE(review): the stopping error is computed from the node errors left
 * by the LAST instance of the pass only — confirm this is intended.
 *
 * @param i training instances
 */
@Override
public void buildClassifier(Instances i) {
    listOutput = new ArrayList<>();
    for (int idx = 0; idx < i.numClasses(); idx++) {
        listOutput.add(new Node(i.numAttributes()));
    }

    // convert the class of every instance to a numeric value (its index)
    listDoubleinstance = new double[i.numInstances()];
    for (int numIns = 0; numIns < i.numInstances(); numIns++) {
        listDoubleinstance[numIns] = i.instance(numIns).toDoubleArray()[i.classIndex()];
    }

    // (removed an unused local "double error = 0;" left over from the
    // commented-out fold diagnostics)
    for (int iter = 0; iter < itteration; iter++) {
        double errorThres = 0;
        for (int idxInstance = 0; idxInstance < i.numInstances(); idxInstance++) {
            // build the input list: bias (1.0) first, then attribute values
            ArrayList<Double> listInput = new ArrayList<>();
            listInput.add(1.0);
            for (int idx = 0; idx < i.numAttributes() - 1; idx++) {
                listInput.add(i.get(idxInstance).value(idx));
            }

            // output = activation (sigmoid) of the weighted sum
            for (int idxOut = 0; idxOut < listOutput.size(); idxOut++) {
                output(listInput, idxOut);
            }

            // compute the error
            calculateError(idxInstance);
            // update the weights
            updateBobot(listInput);
        }
        for (int idxOut = 0; idxOut < listOutput.size(); idxOut++) {
            errorThres += Math.pow(listOutput.get(idxOut).getError(), 2) / 2;
        }
        if (errorThres <= threshold)
            break;
    }
}

From source file:app.RunApp.java

License:Open Source License

/**
 * Preprocesses the loaded dataset according to the GUI selections:
 * optional instance selection, optional feature selection, and an optional
 * train/test split (random holdout/CV, iterative stratified holdout/CV, or
 * label-powerset stratified holdout/CV). Results are stored in the
 * trainDataset(s)/testDataset(s) fields.
 *
 * @return Positive number if successfull and negative otherwise
 */
private int preprocess() {
    trainDatasets = new ArrayList();
    testDatasets = new ArrayList();

    Instances train, test;

    if (dataset == null) {
        JOptionPane.showMessageDialog(null, "You must load a dataset.", "alert", JOptionPane.ERROR_MESSAGE);
        return -1;
    }

    MultiLabelInstances preprocessDataset = dataset.clone();

    // --- Instance selection ---
    if (!radioNoIS.isSelected()) {
        //Do Instance Selection
        if (radioRandomIS.isSelected()) {
            int nInstances = Integer.parseInt(textRandomIS.getText());

            if (nInstances < 1) {
                JOptionPane.showMessageDialog(null,
                        "The number of instances must be a positive natural number.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nInstances > dataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of instances to select must be less than the original.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            // shuffle, then keep the first nInstances instances
            Instances dataIS;
            try {
                Randomize randomize = new Randomize();
                dataIS = dataset.getDataSet();

                randomize.setInputFormat(dataIS);
                dataIS = Filter.useFilter(dataIS, randomize);
                randomize.batchFinished();

                RemoveRange removeRange = new RemoveRange();
                removeRange.setInputFormat(dataIS);
                removeRange.setInstancesIndices((nInstances + 1) + "-last");

                dataIS = Filter.useFilter(dataIS, removeRange);
                removeRange.batchFinished();

                preprocessDataset = dataset.reintegrateModifiedDataSet(dataIS);
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }

            if (preprocessDataset == null) {
                JOptionPane.showMessageDialog(null, "Error when selecting instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            preprocessedDataset = preprocessDataset;
        }
    }

    // --- Feature selection ---
    if (!radioNoFS.isSelected()) {
        //FS_BR
        if (radioBRFS.isSelected()) {
            int nFeatures = Integer.parseInt(textBRFS.getText());
            if (nFeatures < 1) {
                JOptionPane.showMessageDialog(null, "The number of features must be a positive natural number.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFeatures > dataset.getFeatureIndices().length) {
                JOptionPane.showMessageDialog(null,
                        "The number of features to select must be less than the original.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            String combination = jComboBoxBRFSComb.getSelectedItem().toString();
            String normalization = jComboBoxBRFSNorm.getSelectedItem().toString();
            String output = jComboBoxBRFSOut.getSelectedItem().toString();

            FeatureSelector fs;
            if (radioNoIS.isSelected()) {
                fs = new FeatureSelector(dataset, nFeatures);
            } else {
                //If IS have been done
                fs = new FeatureSelector(preprocessDataset, nFeatures);
            }

            preprocessedDataset = fs.select(combination, normalization, output);

            if (preprocessedDataset == null) {
                JOptionPane.showMessageDialog(null, "Error when selecting features.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            preprocessDataset = preprocessedDataset;
        } else if (radioRandomFS.isSelected()) {
            int nFeatures = Integer.parseInt(textRandomFS.getText());

            if (nFeatures < 1) {
                JOptionPane.showMessageDialog(null, "The number of features must be a positive natural number.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFeatures > dataset.getFeatureIndices().length) {
                JOptionPane.showMessageDialog(null,
                        "The number of features to select must be less than the original.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            FeatureSelector fs;

            if (radioNoIS.isSelected()) {
                fs = new FeatureSelector(dataset, nFeatures);
            } else {
                //If IS have been done
                fs = new FeatureSelector(preprocessDataset, nFeatures);
            }

            preprocessedDataset = fs.randomSelect();

            if (preprocessedDataset == null) {
                JOptionPane.showMessageDialog(null, "Error when selecting features.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            preprocessDataset = preprocessedDataset;
        }
    }

    // --- Train/test split ---
    if (!radioNoSplit.isSelected()) {
        //Random Holdout
        if (radioRandomHoldout.isSelected()) {
            String split = textRandomHoldout.getText();
            double percentage = Double.parseDouble(split);
            if ((percentage <= 0) || (percentage >= 100)) {
                JOptionPane.showMessageDialog(null, "The percentage must be a number in the range (0, 100).",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                RandomTrainTest pre = new RandomTrainTest();
                MultiLabelInstances[] partitions = pre.split(preprocessDataset, percentage);
                trainDataset = partitions[0];
                testDataset = partitions[1];
            } catch (InvalidDataFormatException ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //Random CV
        else if (radioRandomCV.isSelected()) {
            String split = textRandomCV.getText();

            if (split.equals("")) {
                JOptionPane.showMessageDialog(null, "You must enter the number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            int nFolds;

            try {
                nFolds = Integer.parseInt(split);
            } catch (Exception e) {
                JOptionPane.showMessageDialog(null, "Introduce a correct number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            if (nFolds < 2) {
                JOptionPane.showMessageDialog(null, "The number of folds must be greater or equal to 2.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFolds > preprocessDataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of folds can not be greater than the number of instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                MultiLabelInstances temp = preprocessDataset.clone();
                Instances dataTemp = temp.getDataSet();

                int seed = (int) (Math.random() * 100) + 100;
                Random rand = new Random(seed);

                dataTemp.randomize(rand);

                // deal instances round-robin into nFolds empty copies
                Instances[] foldsCV = new Instances[nFolds];
                for (int i = 0; i < nFolds; i++) {
                    foldsCV[i] = new Instances(dataTemp);
                    foldsCV[i].clear();
                }

                for (int i = 0; i < dataTemp.numInstances(); i++) {
                    foldsCV[i % nFolds].add(dataTemp.get(i));
                }

                // fold i is the test set; the remaining folds form train
                train = new Instances(dataTemp);
                test = new Instances(dataTemp);
                for (int i = 0; i < nFolds; i++) {
                    train.clear();
                    test.clear();
                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            System.out.println("Add fold " + j + " to train");
                            train.addAll(foldsCV[j]);
                        }
                    }
                    System.out.println("Add fold " + i + " to test");
                    test.addAll(foldsCV[i]);
                    System.out.println(train.get(0).toString());
                    System.out.println(test.get(0).toString());
                    trainDatasets.add(new MultiLabelInstances(new Instances(train),
                            preprocessDataset.getLabelsMetaData()));
                    testDatasets.add(new MultiLabelInstances(new Instances(test),
                            preprocessDataset.getLabelsMetaData()));
                    System.out.println(trainDatasets.get(i).getDataSet().get(0).toString());
                    System.out.println(testDatasets.get(i).getDataSet().get(0).toString());
                    System.out.println("---");
                }
            }

            catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //Iterative stratified holdout
        else if (radioIterativeStratifiedHoldout.isSelected()) {
            String split = textIterativeStratifiedHoldout.getText();
            double percentage = Double.parseDouble(split);
            if ((percentage <= 0) || (percentage >= 100)) {
                JOptionPane.showMessageDialog(null, "The percentage must be a number in the range (0, 100).",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                IterativeTrainTest pre = new IterativeTrainTest();
                MultiLabelInstances[] partitions = pre.split(preprocessDataset, percentage);

                trainDataset = partitions[0];
                testDataset = partitions[1];
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //Iterative stratified CV
        else if (radioIterativeStratifiedCV.isSelected()) {
            String split = textIterativeStratifiedCV.getText();

            if (split.equals("")) {
                JOptionPane.showMessageDialog(null, "You must enter the number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            int nFolds = 0;

            try {
                nFolds = Integer.parseInt(split);
            } catch (Exception e) {
                JOptionPane.showMessageDialog(null, "Introduce a correct number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            if (nFolds < 2) {
                JOptionPane.showMessageDialog(null, "The number of folds must be greater or equal to 2.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFolds > preprocessDataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of folds can not be greater than the number of instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            IterativeStratification strat = new IterativeStratification();
            MultiLabelInstances folds[] = strat.stratify(preprocessDataset, nFolds);

            for (int i = 0; i < nFolds; i++) {
                try {

                    // presize train/test from the stratified fold sizes
                    int trainSize = 0, testSize = 0;
                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            trainSize += folds[j].getNumInstances();
                        }
                    }
                    testSize += folds[i].getNumInstances();

                    train = new Instances(preprocessDataset.getDataSet(), trainSize);
                    test = new Instances(preprocessDataset.getDataSet(), testSize);
                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            train.addAll(folds[j].getDataSet());
                        }
                    }
                    test.addAll(folds[i].getDataSet());

                    trainDatasets.add(new MultiLabelInstances(train, preprocessDataset.getLabelsMetaData()));
                    testDatasets.add(new MultiLabelInstances(test, preprocessDataset.getLabelsMetaData()));
                } catch (InvalidDataFormatException ex) {
                    Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                }
            }

        }
        //LP stratified holdout
        else if (radioLPStratifiedHoldout.isSelected()) {
            String split = textLPStratifiedHoldout.getText();
            double percentage = Double.parseDouble(split);
            if ((percentage <= 0) || (percentage >= 100)) {
                JOptionPane.showMessageDialog(null, "The percentage must be a number in the range (0, 100).",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                IterativeTrainTest pre = new IterativeTrainTest();
                MultiLabelInstances[] partitions = pre.split(preprocessDataset, percentage);

                trainDataset = partitions[0];
                testDataset = partitions[1];
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //LP stratified CV
        else if (radioLPStratifiedCV.isSelected()) {
            String split = textLPStratifiedCV.getText();

            if (split.equals("")) {
                JOptionPane.showMessageDialog(null, "You must enter the number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            int nFolds = 0;

            try {
                nFolds = Integer.parseInt(split);
            } catch (Exception e) {
                JOptionPane.showMessageDialog(null, "Introduce a correct number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            if (nFolds < 2) {
                JOptionPane.showMessageDialog(null, "The number of folds must be greater or equal to 2.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFolds > preprocessDataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of folds can not be greater than the number of instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            LabelPowersetTrainTest strat = new LabelPowersetTrainTest();
            MultiLabelInstances folds[] = strat.stratify(preprocessDataset, nFolds);

            for (int i = 0; i < nFolds; i++) {
                try {
                    train = new Instances(preprocessDataset.getDataSet(), 0);
                    test = new Instances(preprocessDataset.getDataSet(), 0);

                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            train.addAll(folds[j].getDataSet());
                        }
                    }
                    test.addAll(folds[i].getDataSet());

                    trainDatasets.add(new MultiLabelInstances(train, preprocessDataset.getLabelsMetaData()));
                    testDatasets.add(new MultiLabelInstances(test, preprocessDataset.getLabelsMetaData()));
                } catch (InvalidDataFormatException ex) {
                    Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
        }
    }

    jButtonSaveDatasets.setEnabled(true);
    jComboBoxSaveFormat.setEnabled(true);

    return 1;
}

From source file:asap.PostProcess.java

/**
 * Writes the instances plus the given predictions to outputFilename as a
 * delimited text file.
 *
 * The prediction column is inserted before the column at
 * predictionsColumnIndex (or appended last when predictionsColumnIndex
 * equals columnNames.length). When writeColumnsHeaderLine is true a header
 * line with the column names is emitted first (implemented by starting the
 * row index at -1). Columns whose name contains "id" are written as ints;
 * other values use at most 3 fraction digits (US locale). Columns with no
 * matching attribute are written as 0.
 *
 * @param instances              dataset whose attribute values are written
 * @param predictions            one prediction per instance
 * @param columnNames            names of the columns to write, in order
 * @param predictionsColumnIndex position of the predictions column
 * @param predictionsColumnName  header name for the predictions column
 * @param columnSeparator        separator placed between columns
 * @param outputFilename         destination file (parent dirs are created)
 * @param writeColumnsHeaderLine whether to write a header line
 */
private void formatPredictions(Instances instances, double[] predictions, String[] columnNames,
        int predictionsColumnIndex, String predictionsColumnName, String columnSeparator, String outputFilename,
        boolean writeColumnsHeaderLine) {
    PerformanceCounters.startTimer("formatPredictions");

    System.out.println("Formatting predictions to file " + outputFilename + "...");
    File outputFile = new File(outputFilename);

    try {
        outputFile.getParentFile().mkdirs();
        outputFile.createNewFile();

        // try-with-resources closes the writer even if formatting throws
        // (the original only closed it on the success path)
        try (PrintWriter writer = new PrintWriter(outputFile, "UTF-8")) {
            StringBuilder sb = new StringBuilder();
            DecimalFormat df = new DecimalFormat("#.#", new DecimalFormatSymbols(Locale.US));
            df.setMaximumFractionDigits(3);

            // i == -1 produces the header line
            int i = -1;
            if (!writeColumnsHeaderLine) {
                i = 0;
            }
            for (; i < instances.numInstances(); i++) {
                sb.delete(0, sb.length());

                for (int j = 0; j < columnNames.length; j++) {
                    if (j > 0) {
                        sb.append(columnSeparator);
                    }

                    // insert the predictions column before column j
                    if (j == predictionsColumnIndex) {
                        if (i < 0) {
                            sb.append(predictionsColumnName);
                        } else {
                            sb.append(df.format(predictions[i]));
                        }
                        sb.append(columnSeparator);
                    }
                    if (i < 0) {
                        sb.append(columnNames[j]);
                    } else {
                        if (columnNames[j].toLowerCase().contains("id")) {
                            // id-like columns are rendered as integers
                            Attribute attribute = instances.attribute(columnNames[j]);
                            if (attribute != null) {
                                sb.append((int) instances.instance(i).value(attribute.index()));
                            } else {
                                sb.append(0);
                            }
                        } else {
                            Attribute attribute = instances.attribute(columnNames[j]);
                            if (attribute != null) {
                                sb.append(instances.instance(i).value(attribute.index()));
                            } else {
                                sb.append(df.format(0d));
                            }
                        }
                    }
                }

                // predictions column placed after the last data column
                if (columnNames.length == predictionsColumnIndex) {
                    sb.append(columnSeparator);
                    if (i < 0) {
                        sb.append(predictionsColumnName);
                    } else {
                        sb.append(df.format(predictions[i]));
                    }
                }

                writer.println(sb);
            }
            writer.flush();
        }
    } catch (IOException ex) {
        Logger.getLogger(PostProcess.class.getName()).log(Level.SEVERE, null, ex);
        return;
    }
    System.out.println("\tdone.");
    PerformanceCounters.stopTimer("formatPredictions");
}