Example usage for the weka.filters.unsupervised.instance.Randomize constructor Randomize()

List of usage examples for the weka.filters.unsupervised.instance.Randomize constructor Randomize()

Introduction

On this page you can find example usages of the weka.filters.unsupervised.instance.Randomize constructor Randomize().

Prototype

Randomize

Source Link

Usage

From source file:app.RunApp.java

License:Open Source License

/**
 * Preprocess dataset/*from w w w  .java2  s .co m*/
 * 
 * @return Positive number if successfull and negative otherwise
 */
private int preprocess() {
    trainDatasets = new ArrayList();
    testDatasets = new ArrayList();

    Instances train, test;

    if (dataset == null) {
        JOptionPane.showMessageDialog(null, "You must load a dataset.", "alert", JOptionPane.ERROR_MESSAGE);
        return -1;
    }

    MultiLabelInstances preprocessDataset = dataset.clone();

    if (!radioNoIS.isSelected()) {
        //Do Instance Selection
        if (radioRandomIS.isSelected()) {
            int nInstances = Integer.parseInt(textRandomIS.getText());

            if (nInstances < 1) {
                JOptionPane.showMessageDialog(null,
                        "The number of instances must be a positive natural number.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nInstances > dataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of instances to select must be less than the original.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            Instances dataIS;
            try {
                Randomize randomize = new Randomize();
                dataIS = dataset.getDataSet();

                randomize.setInputFormat(dataIS);
                dataIS = Filter.useFilter(dataIS, randomize);
                randomize.batchFinished();

                RemoveRange removeRange = new RemoveRange();
                removeRange.setInputFormat(dataIS);
                removeRange.setInstancesIndices((nInstances + 1) + "-last");

                dataIS = Filter.useFilter(dataIS, removeRange);
                removeRange.batchFinished();

                preprocessDataset = dataset.reintegrateModifiedDataSet(dataIS);
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }

            if (preprocessDataset == null) {
                JOptionPane.showMessageDialog(null, "Error when selecting instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            preprocessedDataset = preprocessDataset;
        }
    }

    if (!radioNoFS.isSelected()) {
        //FS_BR
        if (radioBRFS.isSelected()) {
            int nFeatures = Integer.parseInt(textBRFS.getText());
            if (nFeatures < 1) {
                JOptionPane.showMessageDialog(null, "The number of features must be a positive natural number.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFeatures > dataset.getFeatureIndices().length) {
                JOptionPane.showMessageDialog(null,
                        "The number of features to select must be less than the original.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            String combination = jComboBoxBRFSComb.getSelectedItem().toString();
            String normalization = jComboBoxBRFSNorm.getSelectedItem().toString();
            String output = jComboBoxBRFSOut.getSelectedItem().toString();

            FeatureSelector fs;
            if (radioNoIS.isSelected()) {
                fs = new FeatureSelector(dataset, nFeatures);
            } else {
                //If IS have been done
                fs = new FeatureSelector(preprocessDataset, nFeatures);
            }

            preprocessedDataset = fs.select(combination, normalization, output);

            if (preprocessedDataset == null) {
                JOptionPane.showMessageDialog(null, "Error when selecting features.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            preprocessDataset = preprocessedDataset;
        } else if (radioRandomFS.isSelected()) {
            int nFeatures = Integer.parseInt(textRandomFS.getText());

            if (nFeatures < 1) {
                JOptionPane.showMessageDialog(null, "The number of features must be a positive natural number.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFeatures > dataset.getFeatureIndices().length) {
                JOptionPane.showMessageDialog(null,
                        "The number of features to select must be less than the original.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            FeatureSelector fs;

            if (radioNoIS.isSelected()) {
                fs = new FeatureSelector(dataset, nFeatures);
            } else {
                //If IS have been done
                fs = new FeatureSelector(preprocessDataset, nFeatures);
            }

            preprocessedDataset = fs.randomSelect();

            if (preprocessedDataset == null) {
                JOptionPane.showMessageDialog(null, "Error when selecting features.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            preprocessDataset = preprocessedDataset;
        }
    }

    if (!radioNoSplit.isSelected()) {
        //Random Holdout
        if (radioRandomHoldout.isSelected()) {
            String split = textRandomHoldout.getText();
            double percentage = Double.parseDouble(split);
            if ((percentage <= 0) || (percentage >= 100)) {
                JOptionPane.showMessageDialog(null, "The percentage must be a number in the range (0, 100).",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                RandomTrainTest pre = new RandomTrainTest();
                MultiLabelInstances[] partitions = pre.split(preprocessDataset, percentage);
                trainDataset = partitions[0];
                testDataset = partitions[1];
            } catch (InvalidDataFormatException ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //Random CV
        else if (radioRandomCV.isSelected()) {
            String split = textRandomCV.getText();

            if (split.equals("")) {
                JOptionPane.showMessageDialog(null, "You must enter the number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            int nFolds;

            try {
                nFolds = Integer.parseInt(split);
            } catch (Exception e) {
                JOptionPane.showMessageDialog(null, "Introduce a correct number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            if (nFolds < 2) {
                JOptionPane.showMessageDialog(null, "The number of folds must be greater or equal to 2.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFolds > preprocessDataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of folds can not be greater than the number of instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                MultiLabelInstances temp = preprocessDataset.clone();
                Instances dataTemp = temp.getDataSet();

                int seed = (int) (Math.random() * 100) + 100;
                Random rand = new Random(seed);

                dataTemp.randomize(rand);

                Instances[] foldsCV = new Instances[nFolds];
                for (int i = 0; i < nFolds; i++) {
                    foldsCV[i] = new Instances(dataTemp);
                    foldsCV[i].clear();
                }

                for (int i = 0; i < dataTemp.numInstances(); i++) {
                    foldsCV[i % nFolds].add(dataTemp.get(i));
                }

                train = new Instances(dataTemp);
                test = new Instances(dataTemp);
                for (int i = 0; i < nFolds; i++) {
                    train.clear();
                    test.clear();
                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            System.out.println("Add fold " + j + " to train");
                            train.addAll(foldsCV[j]);
                        }
                    }
                    System.out.println("Add fold " + i + " to test");
                    test.addAll(foldsCV[i]);
                    System.out.println(train.get(0).toString());
                    System.out.println(test.get(0).toString());
                    trainDatasets.add(new MultiLabelInstances(new Instances(train),
                            preprocessDataset.getLabelsMetaData()));
                    testDatasets.add(new MultiLabelInstances(new Instances(test),
                            preprocessDataset.getLabelsMetaData()));
                    System.out.println(trainDatasets.get(i).getDataSet().get(0).toString());
                    System.out.println(testDatasets.get(i).getDataSet().get(0).toString());
                    System.out.println("---");
                }
            }

            catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //Iterative stratified holdout
        else if (radioIterativeStratifiedHoldout.isSelected()) {
            String split = textIterativeStratifiedHoldout.getText();
            double percentage = Double.parseDouble(split);
            if ((percentage <= 0) || (percentage >= 100)) {
                JOptionPane.showMessageDialog(null, "The percentage must be a number in the range (0, 100).",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                IterativeTrainTest pre = new IterativeTrainTest();
                MultiLabelInstances[] partitions = pre.split(preprocessDataset, percentage);

                trainDataset = partitions[0];
                testDataset = partitions[1];
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //Iterative stratified CV
        else if (radioIterativeStratifiedCV.isSelected()) {
            String split = textIterativeStratifiedCV.getText();

            if (split.equals("")) {
                JOptionPane.showMessageDialog(null, "You must enter the number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            int nFolds = 0;

            try {
                nFolds = Integer.parseInt(split);
            } catch (Exception e) {
                JOptionPane.showMessageDialog(null, "Introduce a correct number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            if (nFolds < 2) {
                JOptionPane.showMessageDialog(null, "The number of folds must be greater or equal to 2.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFolds > preprocessDataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of folds can not be greater than the number of instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            IterativeStratification strat = new IterativeStratification();
            MultiLabelInstances folds[] = strat.stratify(preprocessDataset, nFolds);

            for (int i = 0; i < nFolds; i++) {
                try {

                    int trainSize = 0, testSize = 0;
                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            trainSize += folds[j].getNumInstances();
                        }
                    }
                    testSize += folds[i].getNumInstances();

                    train = new Instances(preprocessDataset.getDataSet(), trainSize);
                    test = new Instances(preprocessDataset.getDataSet(), testSize);
                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            train.addAll(folds[j].getDataSet());
                        }
                    }
                    test.addAll(folds[i].getDataSet());

                    trainDatasets.add(new MultiLabelInstances(train, preprocessDataset.getLabelsMetaData()));
                    testDatasets.add(new MultiLabelInstances(test, preprocessDataset.getLabelsMetaData()));
                } catch (InvalidDataFormatException ex) {
                    Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                }
            }

        }
        //LP stratified holdout
        else if (radioLPStratifiedHoldout.isSelected()) {
            String split = textLPStratifiedHoldout.getText();
            double percentage = Double.parseDouble(split);
            if ((percentage <= 0) || (percentage >= 100)) {
                JOptionPane.showMessageDialog(null, "The percentage must be a number in the range (0, 100).",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            try {
                IterativeTrainTest pre = new IterativeTrainTest();
                MultiLabelInstances[] partitions = pre.split(preprocessDataset, percentage);

                trainDataset = partitions[0];
                testDataset = partitions[1];
            } catch (Exception ex) {
                Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        //LP stratified CV
        else if (radioLPStratifiedCV.isSelected()) {
            String split = textLPStratifiedCV.getText();

            if (split.equals("")) {
                JOptionPane.showMessageDialog(null, "You must enter the number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            int nFolds = 0;

            try {
                nFolds = Integer.parseInt(split);
            } catch (Exception e) {
                JOptionPane.showMessageDialog(null, "Introduce a correct number of folds.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            if (nFolds < 2) {
                JOptionPane.showMessageDialog(null, "The number of folds must be greater or equal to 2.",
                        "alert", JOptionPane.ERROR_MESSAGE);
                return -1;
            } else if (nFolds > preprocessDataset.getNumInstances()) {
                JOptionPane.showMessageDialog(null,
                        "The number of folds can not be greater than the number of instances.", "alert",
                        JOptionPane.ERROR_MESSAGE);
                return -1;
            }

            LabelPowersetTrainTest strat = new LabelPowersetTrainTest();
            MultiLabelInstances folds[] = strat.stratify(preprocessDataset, nFolds);

            for (int i = 0; i < nFolds; i++) {
                try {
                    train = new Instances(preprocessDataset.getDataSet(), 0);
                    test = new Instances(preprocessDataset.getDataSet(), 0);

                    for (int j = 0; j < nFolds; j++) {
                        if (i != j) {
                            train.addAll(folds[j].getDataSet());
                        }
                    }
                    test.addAll(folds[i].getDataSet());

                    trainDatasets.add(new MultiLabelInstances(train, preprocessDataset.getLabelsMetaData()));
                    testDatasets.add(new MultiLabelInstances(test, preprocessDataset.getLabelsMetaData()));
                } catch (InvalidDataFormatException ex) {
                    Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
        }
    }

    jButtonSaveDatasets.setEnabled(true);
    jComboBoxSaveFormat.setEnabled(true);

    return 1;
}

From source file:clases.Preproceso.java

/**
 * Returns a shuffled copy of {@code data} using a fixed seed of 42, so repeated
 * calls on the same input use the same seed.
 *
 * @param data instances to shuffle
 * @return the filtered (shuffled) instances
 * @throws Exception if the filter cannot be configured or applied
 */
public static Instances randomize(Instances data) throws Exception {
    return randomize(data, 42);
}

/**
 * Returns a shuffled copy of {@code data} using the given random seed.
 *
 * @param data instances to shuffle
 * @param seed seed handed to the Randomize filter
 * @return the filtered (shuffled) instances
 * @throws Exception if the filter cannot be configured or applied
 */
public static Instances randomize(Instances data, int seed) throws Exception {
    Randomize ran = new Randomize();
    // The seed is set before setInputFormat so the filter is fully configured
    // when its input format is fixed.
    ran.setRandomSeed(seed);
    ran.setInputFormat(data);
    return Filter.useFilter(data, ran);
}