Example usage for weka.core Instances trainCV

List of usage examples for weka.core Instances trainCV

Introduction

On this page you can find example usages of the trainCV method of weka.core.Instances.

Prototype



public Instances trainCV(int numFolds, int numFold) 

Source Link

Document

Creates the training set for one fold of a cross-validation on the dataset.

Usage

From source file:CrossValidationMultipleRuns.java

License:Open Source License

/**
 * Performs the cross-validation. See Javadoc of class for information
 * on command-line parameters./*w  w  w.  ja v  a 2 s . c o m*/
 *
 * @param args   the command-line parameters
 * @throws Exception   if something goes wrong
 */
public static void main(String[] args) throws Exception {
    // loads data and set class index
    Instances data = DataSource.read(Utils.getOption("t", args));
    String clsIndex = Utils.getOption("c", args);
    if (clsIndex.length() == 0)
        clsIndex = "last";
    if (clsIndex.equals("first"))
        data.setClassIndex(0);
    else if (clsIndex.equals("last"))
        data.setClassIndex(data.numAttributes() - 1);
    else
        data.setClassIndex(Integer.parseInt(clsIndex) - 1);

    // classifier
    String[] tmpOptions;
    String classname;
    tmpOptions = Utils.splitOptions(Utils.getOption("W", args));
    classname = tmpOptions[0];
    tmpOptions[0] = "";
    Classifier cls = (Classifier) Utils.forName(Classifier.class, classname, tmpOptions);

    // other options
    int runs = Integer.parseInt(Utils.getOption("r", args));
    int folds = Integer.parseInt(Utils.getOption("x", args));

    // perform cross-validation
    for (int i = 0; i < runs; i++) {
        // randomize data
        int seed = i + 1;
        Random rand = new Random(seed);
        Instances randData = new Instances(data);
        randData.randomize(rand);
        //if (randData.classAttribute().isNominal())
        //   randData.stratify(folds);

        Evaluation eval = new Evaluation(randData);

        StringBuilder optionsString = new StringBuilder();
        for (String s : cls.getOptions()) {
            optionsString.append(s);
            optionsString.append(" ");
        }

        // output evaluation
        System.out.println();
        System.out.println("=== Setup run " + (i + 1) + " ===");
        System.out.println("Classifier: " + optionsString.toString());
        System.out.println("Dataset: " + data.relationName());
        System.out.println("Folds: " + folds);
        System.out.println("Seed: " + seed);
        System.out.println();

        for (int n = 0; n < folds; n++) {
            Instances train = randData.trainCV(folds, n);
            Instances test = randData.testCV(folds, n);

            // build and evaluate classifier
            Classifier clsCopy = Classifier.makeCopy(cls);
            clsCopy.buildClassifier(train);
            eval.evaluateModel(clsCopy, test);
            System.out.println(eval.toClassDetailsString());
        }

        System.out.println(
                eval.toSummaryString("=== " + folds + "-fold Cross-validation run " + (i + 1) + " ===", false));
    }
}

From source file:CopiaSeg3.java

/**
 * Shuffles a copy of the data and returns the train/test pair of the first
 * cross-validation fold.
 *
 * @param data          the dataset to split; the shuffle is done on a copy
 * @param numberOfFolds the number of folds handed to trainCV/testCV
 * @return a two-element array: index 0 is the training set of fold 0,
 *         index 1 is the test set of fold 0
 */
public static Instances[] split(Instances data, int numberOfFolds) {
    // Draw a random seed in the range 0..19 and seed the shuffling generator with it.
    int shuffleSeed = new Random().nextInt(20);

    // Shuffle a copy of the original data.
    Instances shuffled = new Instances(data);
    shuffled.randomize(new Random(shuffleSeed));

    // Only the first fold (index 0) is produced.
    Instances trainingPart = shuffled.trainCV(numberOfFolds, 0);
    Instances testPart = shuffled.testCV(numberOfFolds, 0);
    return new Instances[] { trainingPart, testPart };
}

From source file:algoritmogeneticocluster.NewClass.java

/**
 * Builds the training and test set of every cross-validation fold.
 *
 * @param data          the dataset to split; it is used in its current order
 * @param numberOfFolds the number of folds
 * @return a 2 x numberOfFolds array: row 0 holds the training sets and
 *         row 1 the test sets, both indexed by fold
 */
public static Instances[][] crossValidationSplit(Instances data, int numberOfFolds) {
    Instances[] trainingSets = new Instances[numberOfFolds];
    Instances[] testSets = new Instances[numberOfFolds];

    for (int fold = 0; fold < numberOfFolds; fold++) {
        trainingSets[fold] = data.trainCV(numberOfFolds, fold);
        testSets[fold] = data.testCV(numberOfFolds, fold);
    }

    return new Instances[][] { trainingSets, testSets };
}

From source file:asap.CrossValidation.java

/**
 *
 * @param dataInput/*  ww  w. ja v a 2  s. co m*/
 * @param classIndex
 * @param removeIndices
 * @param cls
 * @param seed
 * @param folds
 * @param modelOutputFile
 * @return
 * @throws Exception
 */
public static String performCrossValidation(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation ST");

    PerformanceCounters.startTimer("cross-validation init ST");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;

    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }

    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval = new Evaluation(randData);
    Instances trainSets[] = new Instances[folds];
    Instances testSets[] = new Instances[folds];
    Classifier foldCls[] = new Classifier[folds];

    for (int n = 0; n < folds; n++) {
        trainSets[n] = randData.trainCV(folds, n);
        testSets[n] = randData.testCV(folds, n);
        foldCls[n] = AbstractClassifier.makeCopy(cls);
    }

    PerformanceCounters.stopTimer("cross-validation init ST");
    PerformanceCounters.startTimer("cross-validation folds+train ST");
    //paralelize!!:--------------------------------------------------------------
    for (int n = 0; n < folds; n++) {
        Instances train = trainSets[n];
        Instances test = testSets[n];

        // the above code is used by the StratifiedRemoveFolds filter, the
        // code below by the Explorer/Experimenter:
        // Instances train = randData.trainCV(folds, n, rand);
        // build and evaluate classifier
        Classifier clsCopy = foldCls[n];
        clsCopy.buildClassifier(train);
        eval.evaluateModel(clsCopy, test);
    }

    cls.buildClassifier(data);
    //until here!-----------------------------------------------------------------

    PerformanceCounters.stopTimer("cross-validation folds+train ST");
    PerformanceCounters.startTimer("cross-validation post ST");
    // output evaluation
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (!modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post ST");
    PerformanceCounters.stopTimer("cross-validation ST");

    return out;
}

From source file:asap.CrossValidation.java

/**
 *
 * @param dataInput/*  www . ja v  a  2  s . c o m*/
 * @param classIndex
 * @param removeIndices
 * @param cls
 * @param seed
 * @param folds
 * @param modelOutputFile
 * @return
 * @throws Exception
 */
public static String performCrossValidationMT(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation MT");

    PerformanceCounters.startTimer("cross-validation init MT");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;

    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }

    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval = new Evaluation(randData);
    List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                AbstractClassifier.makeCopy(cls)));

        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init MT");
    PerformanceCounters.startTimer("cross-validation folds+train MT");
    //paralelize!!:--------------------------------------------------------------
    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        //use the current thread to run the cross-validation instead of using the Thread instance created here:
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    cls.buildClassifier(data);

    for (Thread foldThread : foldThreads) {
        foldThread.join();
    }

    //until here!-----------------------------------------------------------------
    PerformanceCounters.stopTimer("cross-validation folds+train MT");
    PerformanceCounters.startTimer("cross-validation post MT");
    // evaluation for output:
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (!modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post MT");
    PerformanceCounters.stopTimer("cross-validation MT");
    return out;
}

From source file:asap.CrossValidation.java

/**
 * Runs a k-fold cross-validation of a classifier on already loaded data using
 * worker threads, and builds the classifier on the complete data.
 *
 * <p>Failures are logged instead of being thrown. A worker thread is created
 * for each of the first {@code Config.getNumThreads() - 1} folds; how folds
 * are distributed among the workers is decided by
 * {@code CrossValidationFoldThread}, which is not part of this method.
 *
 * @param data            the data to cross-validate on; its class attribute
 *                        must already be set
 * @param cls             the classifier; a copy is created per fold and
 *                        {@code cls} itself is built on {@code data}
 * @param seed            seed used to shuffle the data
 * @param folds           number of cross-validation folds
 * @param modelOutputFile file the built classifier is serialized to; nothing
 *                        is written when it is {@code null} or empty
 * @return a text block with the setup and the evaluation summary, or an error
 *         message if the evaluation object could not be created
 */
static String performCrossValidationMT(Instances data, AbstractClassifier cls, int seed, int folds,
        String modelOutputFile) {

    PerformanceCounters.startTimer("cross-validation MT");

    PerformanceCounters.startTimer("cross-validation init MT");

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval;
    try {
        eval = new Evaluation(randData);
    } catch (Exception ex) {
        Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        return "Error creating evaluation instance for given data!";
    }
    List<Thread> foldThreads = Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        try {
            foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                    AbstractClassifier.makeCopy(cls)));
        } catch (Exception ex) {
            // the fold is skipped when the classifier cannot be copied
            Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        }

        // one thread is left free for the caller, which builds cls meanwhile
        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init MT");
    PerformanceCounters.startTimer("cross-validation folds+train MT");

    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        // single-threaded configuration: process the folds on the calling thread
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    try {
        cls.buildClassifier(data);
    } catch (Exception ex) {
        Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
    }

    // Wait until every worker has really finished. An interrupt must not make
    // us give up on a worker, because eval was handed to the workers and is
    // summarized right below. The interrupt is remembered and restored later.
    boolean interrupted = false;
    for (Thread foldThread : foldThreads) {
        while (foldThread.isAlive()) {
            try {
                foldThread.join();
            } catch (InterruptedException ex) {
                interrupted = true;
                Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }

    PerformanceCounters.stopTimer("cross-validation folds+train MT");
    PerformanceCounters.startTimer("cross-validation post MT");
    // evaluation for output:
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (modelOutputFile != null && !modelOutputFile.isEmpty()) {
        try {
            SerializationHelper.write(modelOutputFile, cls);
        } catch (Exception ex) {
            Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    PerformanceCounters.stopTimer("cross-validation post MT");
    PerformanceCounters.stopTimer("cross-validation MT");

    // hand the interrupt on to the caller instead of swallowing it
    if (interrupted) {
        Thread.currentThread().interrupt();
    }
    return out;
}

From source file:asap.NLPSystem.java

/**
 * Cross-validates the configured classifier on the training set using worker
 * threads and stores the resulting correlation coefficient.
 *
 * <p>Failures are logged instead of being thrown. A worker thread is created
 * for each of the first {@code Config.getNumThreads() - 1} folds; how folds
 * are distributed among the workers is decided by
 * {@code CrossValidationFoldThread}, which is not part of this method.
 *
 * @param seed            seed used to shuffle the training set copy
 * @param folds           number of cross-validation folds
 * @param modelOutputFile file the classifier is serialized to; nothing is
 *                        written when it is {@code null} or empty
 * @return a text block with the setup and the evaluation summary, or an error
 *         message if the evaluation object could not be created
 */
private String crossValidate(int seed, int folds, String modelOutputFile) {

    PerformanceCounters.startTimer("cross-validation");
    PerformanceCounters.startTimer("cross-validation init");

    AbstractClassifier abstractClassifier = (AbstractClassifier) classifier;
    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(trainingSet);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval;
    try {
        eval = new Evaluation(randData);
    } catch (Exception ex) {
        Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        return "Error creating evaluation instance for given data!";
    }
    List<Thread> foldThreads = Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        try {
            foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                    AbstractClassifier.makeCopy(abstractClassifier)));
        } catch (Exception ex) {
            // the fold is skipped when the classifier cannot be copied
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        }

        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init");
    PerformanceCounters.startTimer("cross-validation folds+train");

    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        // single-threaded configuration: process the folds on the calling thread
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    // Keep waiting for each worker even when interrupted, but remember the
    // interrupt so it can be restored instead of being silently dropped.
    boolean interrupted = false;
    for (Thread foldThread : foldThreads) {
        while (foldThread.isAlive()) {
            try {
                foldThread.join();
            } catch (InterruptedException ex) {
                interrupted = true;
                Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }

    PerformanceCounters.stopTimer("cross-validation folds+train");
    PerformanceCounters.startTimer("cross-validation post");
    // evaluation for output:
    String out = String.format(
            "\n=== Setup ===\nClassifier: %s %s\n" + "Dataset: %s\nFolds: %s\nSeed: %s\n\n%s\n",
            abstractClassifier.getClass().getName(), Utils.joinOptions(abstractClassifier.getOptions()),
            trainingSet.relationName(), folds, seed,
            eval.toSummaryString(String.format("=== %s-fold Cross-validation ===", folds), false));

    try {
        crossValidationPearsonsCorrelation = eval.correlationCoefficient();
    } catch (Exception ex) {
        Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
    }

    // NOTE(review): only per-fold copies are handed to the workers here;
    // abstractClassifier itself is not built in this method before it is
    // serialized - confirm that this is intended.
    if (modelOutputFile != null && !modelOutputFile.isEmpty()) {
        try {
            SerializationHelper.write(modelOutputFile, abstractClassifier);
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    classifierBuiltWithCrossValidation = true;
    PerformanceCounters.stopTimer("cross-validation post");
    PerformanceCounters.stopTimer("cross-validation");

    // hand the interrupt on to the caller instead of swallowing it
    if (interrupted) {
        Thread.currentThread().interrupt();
    }
    return out;
}

From source file:au.edu.usyd.it.yangpy.sampling.BPSO.java

License:Open Source License

/**
 * Starts the under-sampling procedure: for each of three stratified
 * cross-validation folds the optimizer state is reset, the optimization is
 * run and its results are saved; afterwards the balanced data set is created.
 */
public void underSampling() {
    final int numFolds = 3;

    // cross-validation works on a copy of the original data set
    Instances cvData = new Instances(dataset);
    cvData.stratify(numFolds);

    for (int fold = 0; fold < numFolds; fold++) {
        // internal training / test set of the current fold
        internalTrain = cvData.trainCV(numFolds, fold);
        internalTest = cvData.testCV(numFolds, fold);

        // count the major-class samples in the internal training set
        majorSize = 0;
        for (int idx = 0; idx < internalTrain.numInstances(); idx++) {
            if (internalTrain.instance(idx).classValue() != majorLabel) {
                continue;
            }
            majorSize++;
        }

        // reset the optimizer state; the per-particle arrays are sized by the
        // population and by the number of major-class samples
        dec = new DecimalFormat("##.####");
        localBest = new double[popSize];
        localBestParticles = new int[popSize][majorSize];
        // NOTE(review): Double.MIN_VALUE is the smallest positive double, not
        // the most negative one - confirm that this is the intended start value.
        globalBest = Double.MIN_VALUE;
        globalBestParticle = new int[majorSize];
        velocity = new double[popSize][majorSize];
        particles = new int[popSize][majorSize];
        searchSpace = new double[popSize][majorSize];

        System.out.println("-------------------- parameters ----------------------");
        System.out.println("CV fold = " + fold);
        System.out.println("inertia weight = " + w);
        System.out.println("c1,c2 = " + c1);
        System.out.println("iteration time = " + iteration);
        System.out.println("population size = " + popSize);

        // initialize, optimize, then store the results of this fold
        initialization();
        findMaxFit();
        saveResults();
    }

    // rank the selected samples and build the balanced dataset
    try {
        createBalanceData();
    } catch (IOException ioe) {
        ioe.printStackTrace();
    }

}

From source file:au.edu.usyd.it.yangpy.snp.ParallelGenetic.java

License:Open Source License

/**
 * Creates the training and test set for the current cross-validation fold
 * and advances the fold index, wrapping around to 0 after the last fold.
 *
 * <p>If the split cannot be created the failure is reported and the fold
 * index is left unchanged.
 */
public void crossValidate() {
    // create a copy of original training set for CV
    Instances randData = new Instances(data);

    // divide the data set with x-fold stratify measure
    randData.stratify(foldSize);

    try {

        cvTrain = randData.trainCV(foldSize, foldIndex);
        cvTest = randData.testCV(foldSize, foldIndex);

        foldIndex++;

        if (foldIndex >= foldSize) {
            foldIndex = 0;
        }

    } catch (Exception e) {
        // report the actual failure instead of dropping it
        System.err.println("Could not create CV split for fold " + foldIndex + " of " + foldSize + ": " + e);
        // cvTest may be null here; println(Object) prints "null" instead of throwing
        System.out.println(cvTest);
    }
}

From source file:br.unicamp.ic.recod.gpsi.gp.gpsiJGAPRoiFitnessFunction.java

/**
 * Computes the fitness as the mean accuracy of a {@code SimpleLogistic}
 * classifier over a 5-fold cross-validation on the training entities of the
 * dataset.
 *
 * <p>Errors while loading the dataset or while training/evaluating are
 * logged; folds that were not evaluated contribute 0 to the mean.
 *
 * @param igpp the GP program to evaluate; it is only used to construct the
 *             band combiner, which is not applied yet (see the TODO below)
 * @return the summed per-fold accuracy divided by the number of folds,
 *         scaled from percent to a fraction
 */
@Override
protected double evaluate(IGPProgram igpp) {

    double mean_accuracy = 0.0;

    gpsiRoiBandCombiner roiBandCombinator = new gpsiRoiBandCombiner(new gpsiJGAPVoxelCombiner(super.b, igpp));
    // TODO: The ROI descriptors must combine the images first
    //roiBandCombinator.combineEntity(this.dataset.getTrainingEntities());

    gpsiMLDataset mlDataset = new gpsiMLDataset(this.descriptor);
    try {
        mlDataset.loadWholeDataset(this.dataset, true);
    } catch (Exception ex) {
        Logger.getLogger(gpsiJGAPRoiFitnessFunction.class.getName()).log(Level.SEVERE, null, ex);
    }

    int dimensionality = mlDataset.getDimensionality();
    int n_classes = mlDataset.getTrainingEntities().keySet().size();
    int n_entities = mlDataset.getNumberOfTrainingEntities();
    ArrayList<Byte> listOfClasses = new ArrayList<>(mlDataset.getTrainingEntities().keySet());

    // one numeric attribute per feature plus a nominal class attribute
    Attribute[] attributes = new Attribute[dimensionality];
    FastVector fvClassVal = new FastVector(n_classes);

    for (int i = 0; i < dimensionality; i++)
        attributes[i] = new Attribute("f" + Integer.toString(i));
    for (int i = 0; i < n_classes; i++)
        fvClassVal.addElement(Integer.toString(listOfClasses.get(i)));

    Attribute classes = new Attribute("class", fvClassVal);

    FastVector fvWekaAttributes = new FastVector(dimensionality + 1);

    for (int i = 0; i < dimensionality; i++)
        fvWekaAttributes.addElement(attributes[i]);
    fvWekaAttributes.addElement(classes);

    Instances instances = new Instances("Rel", fvWekaAttributes, n_entities);
    instances.setClassIndex(dimensionality);

    // Copy every feature vector into a Weka instance. Walking listOfClasses
    // by position gives the index of the label inside the nominal class
    // attribute, which is the value a nominal attribute stores internally.
    for (int classPos = 0; classPos < n_classes; classPos++) {
        Byte label = listOfClasses.get(classPos);
        for (double[] featureVector : mlDataset.getTrainingEntities().get(label)) {
            Instance iExample = new Instance(dimensionality + 1);
            for (int j = 0; j < dimensionality; j++)
                iExample.setValue(j, featureVector[j]);
            iExample.setValue(dimensionality, classPos);
            instances.add(iExample);
        }
    }

    int folds = 5;
    // unseeded: the shuffle, and therefore the fitness, differs between calls
    Random rand = new Random();
    Instances randData = new Instances(instances);
    randData.randomize(rand);

    try {
        for (int fold = 0; fold < folds; fold++) {
            Classifier cModel = (Classifier) new SimpleLogistic();
            Instances trainingSet = randData.trainCV(folds, fold);
            Instances testingSet = randData.testCV(folds, fold);

            cModel.buildClassifier(trainingSet);

            Evaluation eTest = new Evaluation(trainingSet);
            eTest.evaluateModel(cModel, testingSet);

            mean_accuracy += eTest.pctCorrect();

        }
    } catch (Exception ex) {
        Logger.getLogger(gpsiJGAPRoiFitnessFunction.class.getName()).log(Level.SEVERE, null, ex);
    }

    // pctCorrect() is a percentage: average over the folds and scale to [0, 1]
    mean_accuracy /= (folds * 100);

    return mean_accuracy;

}