Example usage for weka.core Instances randomize

Introduction

On this page you can find usage examples for weka.core.Instances.randomize.

Prototype

public void randomize(Random random) 

Document

Shuffles the instances in the set so that they are ordered randomly.
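
Below is a minimal, self-contained sketch of calling randomize directly. The file name data.arff and the last-attribute class index are placeholders for illustration, not taken from the examples that follow:

import java.util.Random;

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class RandomizeExample {
    public static void main(String[] args) throws Exception {
        // Load a dataset; "data.arff" is a placeholder path.
        Instances data = DataSource.read("data.arff");
        // Assume the class attribute is the last one (a common Weka convention).
        data.setClassIndex(data.numAttributes() - 1);

        // Shuffle the instances in place; a fixed seed makes the order reproducible.
        data.randomize(new Random(42));
        System.out.println("First instance after shuffling: " + data.instance(0));
    }
}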

Usage

From source file:org.esa.nest.gpf.SGD.java

/**
 * Method for building the classifier.
 *
 * @param data the set of training instances.
 * @throws Exception if the classifier can't be built successfully.
 */
@Override
public void buildClassifier(Instances data) throws Exception {
    reset();

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    data = new Instances(data);
    data.deleteWithMissingClass();

    if (data.numInstances() > 0 && !m_dontReplaceMissing) {
        m_replaceMissing = new ReplaceMissingValues();
        m_replaceMissing.setInputFormat(data);
        data = Filter.useFilter(data, m_replaceMissing);
    }

    // check for only numeric attributes
    boolean onlyNumeric = true;
    for (int i = 0; i < data.numAttributes(); i++) {
        if (i != data.classIndex()) {
            if (!data.attribute(i).isNumeric()) {
                onlyNumeric = false;
                break;
            }
        }
    }

    if (!onlyNumeric) {
        if (data.numInstances() > 0) {
            m_nominalToBinary = new weka.filters.supervised.attribute.NominalToBinary();
        } else {
            m_nominalToBinary = new weka.filters.unsupervised.attribute.NominalToBinary();
        }
        m_nominalToBinary.setInputFormat(data);
        data = Filter.useFilter(data, m_nominalToBinary);
    }

    if (!m_dontNormalize && data.numInstances() > 0) {

        m_normalize = new Normalize();
        m_normalize.setInputFormat(data);
        data = Filter.useFilter(data, m_normalize);
    }

    m_numInstances = data.numInstances();

    m_weights = new double[data.numAttributes() + 1];
    m_data = new Instances(data, 0);

    if (data.numInstances() > 0) {
        data.randomize(new Random(getSeed())); // randomize the data
        train(data);
    }
}
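
Two details worth noting in this example: randomize is invoked only when the filtered dataset still contains instances, and the Random is seeded from getSeed(), so repeated builds on the same data shuffle identically.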

From source file:org.scripps.branch.classifier.ManualTree.java

License:Open Source License

/**
 * Builds classifier.
 * 
 * @param data
 *            the data to train with
 * @throws Exception
 *             if something goes wrong or the data doesn't fit
 */
@Override
public void buildClassifier(Instances data) throws Exception {
    // Make sure K value is in range
    if (m_KValue > data.numAttributes() - 1)
        m_KValue = data.numAttributes() - 1;
    if (m_KValue < 1)
        m_KValue = (int) Utils.log2(data.numAttributes()) + 1;

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    // only class? -> build ZeroR model
    if (data.numAttributes() == 1) {
        System.err.println(
                "Cannot build model (only class attribute present in data!), " + "using ZeroR model instead!");
        m_ZeroR = new weka.classifiers.rules.ZeroR();
        m_ZeroR.buildClassifier(data);
        return;
    } else {
        m_ZeroR = null;
    }

    // Figure out appropriate datasets
    Instances train = null;
    Instances backfit = null;
    Random rand = data.getRandomNumberGenerator(m_randomSeed);
    if (m_NumFolds <= 0) {
        train = data;
    } else {
        data.randomize(rand);
        data.stratify(m_NumFolds);
        train = data.trainCV(m_NumFolds, 1, rand);
        backfit = data.testCV(m_NumFolds, 1);
    }

    //Set Default Instances for selection.
    setRequiredInst(data);

    // Create the attribute indices window
    int[] attIndicesWindow = new int[data.numAttributes() - 1];
    int j = 0;
    for (int i = 0; i < attIndicesWindow.length; i++) {
        if (j == data.classIndex())
            j++; // do not include the class
        attIndicesWindow[i] = j++;
    }

    // Compute initial class counts
    double[] classProbs = new double[train.numClasses()];
    for (int i = 0; i < train.numInstances(); i++) {
        Instance inst = train.instance(i);
        classProbs[(int) inst.classValue()] += inst.weight();
    }

    Instances requiredInstances = getRequiredInst();
    // Build tree
    if (jsontree != null) {
        buildTree(train, classProbs, new Instances(data, 0), m_Debug, 0, jsontree, 0, m_distributionData,
                requiredInstances, listOfFc, cSetList, ccSer, d);
    } else {
        System.out.println("No json tree specified, failing to process tree");
    }
    setRequiredInst(requiredInstances);
    // Backfit if required
    if (backfit != null) {
        backfitData(backfit);
    }
}
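
Note the order of the calls here: randomize shuffles the data before stratify reorders it so that each of the m_NumFolds folds approximates the overall class distribution; trainCV and testCV then carve the same partition into the training and backfit sets.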

From source file:predictors.HelixIndexer.java

License:Open Source License

/**
 * Trains the Weka Classifier.
 */
public void trainClassifier() {
    try {
        RandomForest classifier = new weka.classifiers.trees.RandomForest();
        Instances data = this.dataset;

        if (data.classIndex() == -1) {
            data.setClassIndex(data.numAttributes() - 1);
        }

        data.randomize(new Random(data.size()));

        String[] optClassifier = weka.core.Utils.splitOptions("-I 100 -K 9 -S 1 -num-slots 3");

        classifier.setOptions(optClassifier);
        classifier.setSeed(data.size());

        classifier.buildClassifier(data);

        this.classifier = classifier;
        this.isTrained = true;
    } catch (Exception e) {
        ErrorUtils.printError(HelixIndexer.class, "Training failed", e);
    }
}
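
Here both the shuffle and the RandomForest seed are derived from data.size(), which makes a run deterministic for a given dataset, but it also means any two datasets with the same number of instances are shuffled with the same seed.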

From source file:predictors.HelixPredictor.java

License:Open Source License

/**
 * Trains the Weka Classifier.
 */
public void trainClassifier() {
    try {
        MultilayerPerceptron classifier = new weka.classifiers.functions.MultilayerPerceptron();
        Instances data = this.dataset;

        if (data.classIndex() == -1) {
            data.setClassIndex(data.numAttributes() - 1);
        }

        data.randomize(new Random(data.size()));

        String[] optClassifier = weka.core.Utils
                .splitOptions("-L 0.01 -M 0.8 -N 256 -V 20 -S 0 -E 5 -H 25 -B -I -D -C");

        classifier.setOptions(optClassifier);
        classifier.setSeed(data.size());

        classifier.buildClassifier(data);

        this.classifier = classifier;
        this.isTrained = true;
    } catch (Exception e) {
        ErrorUtils.printError(HelixPredictor.class, "Training failed", e);
    }
}

From source file:predictors.TopologyPredictor.java

License:Open Source License

/**
 * Trains the Weka Classifier.
 */
public void trainClassifier() {
    try {
        RandomForest classifier = new weka.classifiers.trees.RandomForest();
        Instances data = this.dataset;

        if (data.classIndex() == -1) {
            data.setClassIndex(data.numAttributes() - 1);
        }

        data.randomize(new Random(data.size()));

        String[] optClassifier = weka.core.Utils.splitOptions("-I 100 -K 7 -S 1 -num-slots 1");

        classifier.setOptions(optClassifier);
        classifier.setSeed(data.size());

        classifier.buildClassifier(data);

        this.classifier = classifier;
        this.isTrained = true;
    } catch (Exception e) {
        ErrorUtils.printError(TopologyPredictor.class, "Training failed", e);
    }
}

From source file:sentinets.Prediction.java

License:Open Source License

public String updateModel(String inputFile, ArrayList<Double[]> metrics) {
    String output = "";
    this.setInstances(inputFile);
    FilteredClassifier fcls = (FilteredClassifier) this.cls;
    SGD cls = (SGD) fcls.getClassifier();
    Filter filter = fcls.getFilter();
    Instances insAll;
    try {
        insAll = Filter.useFilter(this.unlabled, filter);
        if (insAll.size() > 0) {
            Random rand = new Random(10);
            int folds = 10 > insAll.size() ? 2 : 10;
            Instances randData = new Instances(insAll);
            randData.randomize(rand);
            if (randData.classAttribute().isNominal()) {
                randData.stratify(folds);
            }
            Evaluation eval = new Evaluation(randData);
            eval.evaluateModel(cls, insAll);
            System.out.println("Initial Evaluation");
            System.out.println(eval.toSummaryString());
            System.out.println(eval.toClassDetailsString());
            metrics.add(new Double[] { eval.fMeasure(0), eval.fMeasure(1), eval.weightedFMeasure() });
            output += "\n====" + "Initial Evaluation" + "====\n";
            output += "\n" + eval.toSummaryString();
            output += "\n" + eval.toClassDetailsString();
            System.out.println("Cross Validated Evaluation");
            output += "\n====" + "Cross Validated Evaluation" + "====\n";
            for (int n = 0; n < folds; n++) {
                Instances train = randData.trainCV(folds, n);
                Instances test = randData.testCV(folds, n);

                for (int i = 0; i < train.numInstances(); i++) {
                    cls.updateClassifier(train.instance(i));
                }

                eval.evaluateModel(cls, test);
                System.out.println("Cross Validated Evaluation fold: " + n);
                output += "\n====" + "Cross Validated Evaluation fold (" + n + ")====\n";
                System.out.println(eval.toSummaryString());
                System.out.println(eval.toClassDetailsString());
                output += "\n" + eval.toSummaryString();
                output += "\n" + eval.toClassDetailsString();
                metrics.add(new Double[] { eval.fMeasure(0), eval.fMeasure(1), eval.weightedFMeasure() });
            }
            for (int i = 0; i < insAll.numInstances(); i++) {
                cls.updateClassifier(insAll.instance(i));
            }
            eval.evaluateModel(cls, insAll);
            System.out.println("Final Evaluation");
            System.out.println(eval.toSummaryString());
            System.out.println(eval.toClassDetailsString());
            output += "\n====" + "Final Evaluation" + "====\n";
            output += "\n" + eval.toSummaryString();
            output += "\n" + eval.toClassDetailsString();
            metrics.add(new Double[] { eval.fMeasure(0), eval.fMeasure(1), eval.weightedFMeasure() });
            fcls.setClassifier(cls);
            String modelFilePath = outputDir + "/" + Utils.getOutDir(Utils.OutDirIndex.MODELS)
                    + "/updatedClassifier.model";
            weka.core.SerializationHelper.write(modelFilePath, fcls);
            output += "\n" + "Updated Model saved at: " + modelFilePath;
        } else {
            output += "No new instances for training the model.";
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
    return output;
}
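
One caveat when reading the printed summaries: the same Evaluation object is reused for the initial pass, every fold, and the final pass, and Weka's Evaluation accumulates statistics across evaluateModel calls, so each later summary reflects all predictions made up to that point.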

From source file:sirius.clustering.main.ClustererClassificationPane.java

License:Open Source License

private void start() {
    //Run Classifier
    if (this.inputDirectoryTextField.getText().length() == 0) {
        JOptionPane.showMessageDialog(parent, "Please set Input Directory to where the clusterer output are!",
                "Evaluate Classifier", JOptionPane.ERROR_MESSAGE);
        return;
    }
    if (m_ClassifierEditor.getValue() == null) {
        JOptionPane.showMessageDialog(parent, "Please choose Classifier!", "Evaluate Classifier",
                JOptionPane.ERROR_MESSAGE);
        return;
    }
    if (validateStatsSettings(1) == false) {
        return;
    }
    if (this.clusteringClassificationThread == null) {
        startButton.setEnabled(false);
        stopButton.setEnabled(true);
        tabbedClassifierPane.setSelectedIndex(0);
        this.clusteringClassificationThread = (new Thread() {
            public void run() {
                //Clear the output text area
                levelOneClassifierOutputTextArea.setText("");
                resultsTableModel.reset();
                //double threshold = Double.parseDouble(classifierOneThresholdTextField.getText());                
                //cross-validation
                int numFolds;
                if (jackKnifeRadioButton.isSelected())
                    numFolds = -1;
                else
                    numFolds = Integer.parseInt(foldsField.getText());
                StringTokenizer st = new StringTokenizer(inputDirectoryTextField.getText(), File.separator);
                String filename = "";
                while (st.hasMoreTokens()) {
                    filename = st.nextToken();
                }
                StringTokenizer st2 = new StringTokenizer(filename, "_.");
                numOfCluster = 0;
                if (st2.countTokens() >= 2) {
                    st2.nextToken();
                    String numOfClusterString = st2.nextToken().replaceAll("cluster", "");
                    try {
                        numOfCluster = Integer.parseInt(numOfClusterString);
                    } catch (NumberFormatException e) {
                        JOptionPane.showMessageDialog(parent,
                                "Please choose the correct file! (Output from Utilize Clusterer)", "ERROR",
                                JOptionPane.ERROR_MESSAGE);
                    }
                }
                Classifier template = (Classifier) m_ClassifierEditor.getValue();

                for (int x = 0; x <= numOfCluster && clusteringClassificationThread != null; x++) {//Test each cluster
                    try {
                        long totalTimeStart = 0, totalTimeElapsed = 0;
                        totalTimeStart = System.currentTimeMillis();
                        statusLabel.setText("Reading in cluster" + x + " file..");
                        String inputFilename = inputDirectoryTextField.getText()
                                .replaceAll("_cluster" + numOfCluster + ".arff", "_cluster" + x + ".arff");
                        String outputScoreFilename = inputDirectoryTextField.getText()
                                .replaceAll("_cluster" + numOfCluster + ".arff", "_cluster" + x + ".score");
                        BufferedWriter output = new BufferedWriter(new FileWriter(outputScoreFilename));
                        Instances inst = new Instances(new FileReader(inputFilename));
                        //Assume that class attribute is the last attribute - This should be the case for all Sirius produced Arff files
                        inst.setClassIndex(inst.numAttributes() - 1);
                        Random random = new Random(1);//Simply set to 1, shall implement the random seed option later
                        inst.randomize(random);
                        if (inst.attribute(inst.classIndex()).isNominal())
                            inst.stratify(numFolds);
                        // for timing
                        ClassifierResults classifierResults = new ClassifierResults(false, 0);
                        String classifierName = m_ClassifierEditor.getValue().getClass().getName();
                        classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ",
                                classifierName);
                        classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ",
                                inputFilename);
                        classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ",
                                "NA");
                        //ArrayList<Double> resultList = new ArrayList<Double>();     
                        if (jackKnifeRadioButton.isSelected() || numFolds > inst.numInstances() - 1)
                            numFolds = inst.numInstances() - 1;
                        for (int fold = 0; fold < numFolds && clusteringClassificationThread != null; fold++) {//Doing cross-validation          
                            statusLabel.setText("Cluster: " + x + " - Training Fold " + (fold + 1) + "..");
                            Instances train = inst.trainCV(numFolds, fold, random);
                            Classifier current = null;
                            try {
                                current = Classifier.makeCopy(template);
                                current.buildClassifier(train);
                                Instances test = inst.testCV(numFolds, fold);
                                statusLabel.setText("Cluster: " + x + " - Testing Fold " + (fold + 1) + "..");
                                for (int jj = 0; jj < test.numInstances(); jj++) {
                                    double[] result = current.distributionForInstance(test.instance(jj));
                                    output.write("Cluster: " + x);
                                    output.newLine();
                                    output.newLine();
                                    output.write(test.instance(jj).stringValue(test.classAttribute()) + ",0="
                                            + result[0]);
                                    output.newLine();
                                }
                            } catch (Exception ex) {
                                ex.printStackTrace();
                                statusLabel.setText("Error in cross-validation!");
                                startButton.setEnabled(true);
                                stopButton.setEnabled(false);
                            }
                        }
                        output.close();
                        totalTimeElapsed = System.currentTimeMillis() - totalTimeStart;
                        classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ",
                                Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes "
                                        + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2)
                                        + " seconds");
                        double threshold = validateFieldAsThreshold(classifierOneThresholdTextField.getText(),
                                "Threshold Field", classifierOneThresholdTextField);
                        String filename2 = inputDirectoryTextField.getText()
                                .replaceAll("_cluster" + numOfCluster + ".arff", "_cluster" + x + ".score");
                        PredictionStats classifierStats = new PredictionStats(filename2, 0, threshold);

                        resultsTableModel.add("Cluster " + x, classifierResults, classifierStats);
                        resultsTable.setRowSelectionInterval(x, x);
                        computeStats(numFolds);//compute and display the results                    
                    } catch (Exception e) {
                        e.printStackTrace();
                        statusLabel.setText("Error in reading file!");
                        startButton.setEnabled(true);
                        stopButton.setEnabled(false);
                    }
                } //end of cluster for loop

                resultsTableModel.add("Summary - Equal Weightage", null, null);
                resultsTable.setRowSelectionInterval(numOfCluster + 1, numOfCluster + 1);
                computeStats(numFolds);
                resultsTableModel.add("Summary - Weighted Average", null, null);
                resultsTable.setRowSelectionInterval(numOfCluster + 2, numOfCluster + 2);
                computeStats(numFolds);

                if (clusteringClassificationThread != null)
                    statusLabel.setText("Done!");
                else
                    statusLabel.setText("Interrupted..");
                startButton.setEnabled(true);
                stopButton.setEnabled(false);
                if (classifierOne != null) {
                    levelOneClassifierOutputScrollPane.getVerticalScrollBar()
                            .setValue(levelOneClassifierOutputScrollPane.getVerticalScrollBar().getMaximum());
                }
                clusteringClassificationThread = null;

            }
        });
        this.clusteringClassificationThread.setPriority(Thread.MIN_PRIORITY);
        this.clusteringClassificationThread.start();
    } else {
        JOptionPane.showMessageDialog(parent,
                "Cannot start new job as previous job still running. Click stop to terminate previous job",
                "ERROR", JOptionPane.ERROR_MESSAGE);
    }
}
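
In this example the seed for randomize is hard-coded to 1 (the in-code comment notes that a configurable seed was planned), so every run shuffles a given cluster file identically before the trainCV/testCV loop.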

From source file:smo2.SMO.java

License:Open Source License

/**
 * Method for building the classifier. Implements a one-against-one wrapper
 * for multi-class problems.
 *
 * @param insts
 *            the set of training instances
 * @exception Exception
 *                if the classifier can't be built successfully
 */
public void buildClassifier(Instances insts) throws Exception {

    if (!m_checksTurnedOff) {
        if (insts.checkForStringAttributes()) {
            throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
        }
        if (insts.classAttribute().isNumeric()) {
            throw new UnsupportedClassTypeException(
                    "mySMO can't handle a numeric class! Use" + "SMOreg for performing regression.");
        }
        insts = new Instances(insts);
        insts.deleteWithMissingClass();
        if (insts.numInstances() == 0) {
            throw new Exception("No training instances without a missing class!");
        }

        /*
         * Removes all the instances with weight equal to 0. MUST be done
         * since condition (8) of Keerthi's paper is made with the assertion
         * Ci > 0 (See equation (3a)).
         */
        Instances data = new Instances(insts, insts.numInstances());
        for (int i = 0; i < insts.numInstances(); i++) {
            if (insts.instance(i).weight() > 0)
                data.add(insts.instance(i));
        }
        if (data.numInstances() == 0) {
            throw new Exception("No training instances left after removing "
                    + "instance with either a weight null or a missing class!");
        }
        insts = data;

    }

    m_onlyNumeric = true;
    if (!m_checksTurnedOff) {
        for (int i = 0; i < insts.numAttributes(); i++) {
            if (i != insts.classIndex()) {
                if (!insts.attribute(i).isNumeric()) {
                    m_onlyNumeric = false;
                    break;
                }
            }
        }
    }

    if (!m_checksTurnedOff) {
        m_Missing = new ReplaceMissingValues();
        m_Missing.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_Missing);
    } else {
        m_Missing = null;
    }

    if (!m_onlyNumeric) {
        m_NominalToBinary = new NominalToBinary();
        m_NominalToBinary.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_NominalToBinary);
    } else {
        m_NominalToBinary = null;
    }

    if (m_filterType == FILTER_STANDARDIZE) {
        m_Filter = new Standardize();
        m_Filter.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_Filter);
    } else if (m_filterType == FILTER_NORMALIZE) {
        m_Filter = new Normalize();
        m_Filter.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_Filter);
    } else {
        m_Filter = null;
    }

    m_classIndex = insts.classIndex();
    m_classAttribute = insts.classAttribute();

    // Generate subsets representing each class
    Instances[] subsets = new Instances[insts.numClasses()];
    for (int i = 0; i < insts.numClasses(); i++) {
        subsets[i] = new Instances(insts, insts.numInstances());
    }
    for (int j = 0; j < insts.numInstances(); j++) {
        Instance inst = insts.instance(j);
        subsets[(int) inst.classValue()].add(inst);
    }
    for (int i = 0; i < insts.numClasses(); i++) {
        subsets[i].compactify();
    }

    // Build the binary classifiers
    Random rand = new Random(m_randomSeed);
    m_classifiers = new BinarymySMO[insts.numClasses()][insts.numClasses()];
    for (int i = 0; i < insts.numClasses(); i++) {
        for (int j = i + 1; j < insts.numClasses(); j++) {
            m_classifiers[i][j] = new BinarymySMO();
            Instances data = new Instances(insts, insts.numInstances());
            for (int k = 0; k < subsets[i].numInstances(); k++) {
                data.add(subsets[i].instance(k));
            }
            for (int k = 0; k < subsets[j].numInstances(); k++) {
                data.add(subsets[j].instance(k));
            }
            data.compactify();
            data.randomize(rand);
            m_classifiers[i][j].buildClassifier(data, i, j, m_fitLogisticModels, m_numFolds, m_randomSeed);
        }
    }
}
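
A single Random seeded with m_randomSeed is created once and shared across all class pairs, so each pairwise training set is shuffled reproducibly before being passed to BinarymySMO, but the shuffles differ from pair to pair.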

From source file:src.BigDataClassifier.GenerateFolds.java

public void generateFolds(Instances trainDataset) throws Exception {

    //randomize data
    Random rand = new Random(1);
    //set folds
    int folds = 3;
    //create random dataset
    Instances randData = new Instances(trainDataset);
    randData.randomize(rand);

    Instances[] result = new Instances[folds * 2];
    //cross-validate
    for (int n = 0; n < folds; n++) {
        trainDataset = randData.trainCV(folds, n);
        System.out.println("Train dataset size is = " + trainDataset.size());
        Instances testDataset = randData.testCV(folds, n);
        System.out.println("Test dataset size is = " + testDataset.size());
        result[2 * n] = trainDataset;
        result[2 * n + 1] = testDataset;
        trainDataset2 = trainDataset;
        testDataset2 = testDataset;
    }
    trainDatasetSize = trainDataset2.size();
    testDatasetSize = testDataset2.size();
}

From source file:text_clustering.Cobweb.java

License:Open Source License

@Override
public void buildClusterer(Instances data) throws Exception {
    m_numberOfClusters = -1;
    m_cobwebTree = null;
    m_numberSplits = 0;
    m_numberMerges = 0;

    // can clusterer handle the data?
    getCapabilities().testWithFail(data);

    // randomize the instances
    //data = new Instances(data);

    if (getSeed() >= 0) {
        data.randomize(new Random(getSeed()));
    }

    for (int i = 0; i < data.numInstances(); i++) {
        updateClusterer(data.instance(i));
        System.out.println(((i + 1) * 100 / data.numInstances()) + "%");
        Menu_File.taOutput.appendText(((i + 1) * 100 / data.numInstances()) + "%\n");
    }

    updateFinished();
}
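
Shuffling matters here because Cobweb clusters incrementally, one instance at a time, so the resulting tree depends on presentation order; randomizing first (only done when a non-negative seed is supplied) removes any ordering bias in the input.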