Example usage for weka.attributeSelection AttributeSelection setFolds

List of usage examples for weka.attributeSelection AttributeSelection setFolds

Introduction

On this page you can find example usages of weka.attributeSelection AttributeSelection setFolds.

Prototype

public void setFolds(int folds) 

Source Link

Document

Sets the number of folds for cross-validation.

Usage

From source file:adams.flow.transformer.WekaAttributeSelection.java

License:Open Source License

/**
 * Executes the flow item./*from  ww w. j a  v  a  2  s. co  m*/
 *
 * @return      null if everything is fine, otherwise error message
 */
@Override
protected String doExecute() {
    String result;
    Instances data;
    Instances reduced;
    Instances transformed;
    AttributeSelection eval;
    boolean crossValidate;
    int fold;
    Instances train;
    WekaAttributeSelectionContainer cont;
    SpreadSheet stats;
    int i;
    Row row;
    int[] selected;
    double[][] ranked;
    Range range;
    String rangeStr;
    boolean useReduced;

    result = null;

    try {
        if (m_InputToken.getPayload() instanceof Instances)
            data = (Instances) m_InputToken.getPayload();
        else
            data = (Instances) ((WekaTrainTestSetContainer) m_InputToken.getPayload())
                    .getValue(WekaTrainTestSetContainer.VALUE_TRAIN);

        if (result == null) {
            crossValidate = (m_Folds >= 2);

            // setup evaluation
            eval = new AttributeSelection();
            eval.setEvaluator(m_Evaluator);
            eval.setSearch(m_Search);
            eval.setFolds(m_Folds);
            eval.setSeed((int) m_Seed);
            eval.setXval(crossValidate);

            // select attributes
            if (crossValidate) {
                Random random = new Random(m_Seed);
                data = new Instances(data);
                data.randomize(random);
                if ((data.classIndex() > -1) && data.classAttribute().isNominal()) {
                    if (isLoggingEnabled())
                        getLogger().info("Stratifying instances...");
                    data.stratify(m_Folds);
                }
                for (fold = 0; fold < m_Folds; fold++) {
                    if (isLoggingEnabled())
                        getLogger().info("Creating splits for fold " + (fold + 1) + "...");
                    train = data.trainCV(m_Folds, fold, random);
                    if (isLoggingEnabled())
                        getLogger().info("Selecting attributes using all but fold " + (fold + 1) + "...");
                    eval.selectAttributesCVSplit(train);
                }
            } else {
                eval.SelectAttributes(data);
            }

            // generate reduced/transformed dataset
            reduced = null;
            transformed = null;
            if (!crossValidate) {
                reduced = eval.reduceDimensionality(data);
                if (m_Evaluator instanceof AttributeTransformer)
                    transformed = ((AttributeTransformer) m_Evaluator).transformedData(data);
            }

            // generated stats
            stats = null;
            if (!crossValidate) {
                stats = new DefaultSpreadSheet();
                row = stats.getHeaderRow();

                useReduced = false;
                if (m_Search instanceof RankedOutputSearch) {
                    i = reduced.numAttributes();
                    if (reduced.classIndex() > -1)
                        i--;
                    ranked = eval.rankedAttributes();
                    useReduced = (ranked.length == i);
                }

                if (useReduced) {
                    for (i = 0; i < reduced.numAttributes(); i++)
                        row.addCell("" + i).setContent(reduced.attribute(i).name());
                    row = stats.addRow();
                    for (i = 0; i < reduced.numAttributes(); i++)
                        row.addCell(i).setContent(0.0);
                } else {
                    for (i = 0; i < data.numAttributes(); i++)
                        row.addCell("" + i).setContent(data.attribute(i).name());
                    row = stats.addRow();
                    for (i = 0; i < data.numAttributes(); i++)
                        row.addCell(i).setContent(0.0);
                }

                if (m_Search instanceof RankedOutputSearch) {
                    ranked = eval.rankedAttributes();
                    for (i = 0; i < ranked.length; i++)
                        row.getCell((int) ranked[i][0]).setContent(ranked[i][1]);
                } else {
                    selected = eval.selectedAttributes();
                    for (i = 0; i < selected.length; i++)
                        row.getCell(selected[i]).setContent(1.0);
                }
            }

            // selected attributes
            rangeStr = null;
            if (!crossValidate) {
                range = new Range();
                range.setIndices(eval.selectedAttributes());
                rangeStr = range.getRange();
            }

            // setup container
            if (crossValidate)
                cont = new WekaAttributeSelectionContainer(data, reduced, transformed, eval, m_Seed, m_Folds);
            else
                cont = new WekaAttributeSelectionContainer(data, reduced, transformed, eval, stats, rangeStr);
            m_OutputToken = new Token(cont);
        }
    } catch (Exception e) {
        m_OutputToken = null;
        result = handleException("Failed to process data:", e);
    }

    return result;
}

From source file:miRdup.WekaModule.java

License:Open Source License

/**
 * Ranks the attributes of a dataset with an information gain evaluator and
 * a ranker search, and writes the selection results both to standard output
 * and to the given file.
 * <p>
 * If the loaded dataset has no class attribute set, the last attribute is
 * used as class. Failures are reported via a stack trace; no exception is
 * propagated to the caller.
 *
 * @param arff     the file to load the dataset from
 * @param outfile  path of the file the results are written to
 */
public static void attributeSelection(File arff, String outfile) {
    // try-with-resources closes (and thereby flushes) the output file even
    // when loading the data or the attribute selection throws
    try (PrintWriter pw = new PrintWriter(new FileWriter(outfile))) {
        // load data
        DataSource source = new DataSource(arff.toString());
        Instances data = source.getDataSet();
        if (data.classIndex() == -1) {
            data.setClassIndex(data.numAttributes() - 1);
        }

        AttributeSelection attrsel = new AttributeSelection();
        weka.attributeSelection.InfoGainAttributeEval eval = new weka.attributeSelection.InfoGainAttributeEval();

        weka.attributeSelection.Ranker rank = new weka.attributeSelection.Ranker();
        rank.setOptions(weka.core.Utils.splitOptions("-T -1.7976931348623157E308 -N -1"));
        if (Main.debug) {
            // the option dump continues the "Model options:" line, so it is
            // only emitted together with that label
            System.out.print("Model options: " + rank.getClass().getName().trim() + " ");
            for (String s : rank.getOptions()) {
                System.out.print(s + " ");
            }
        }
        attrsel.setEvaluator(eval);
        attrsel.setSearch(rank);
        attrsel.setFolds(10);

        attrsel.SelectAttributes(data);

        // build the report once and send it to both destinations
        String results = attrsel.toResultsString();
        System.out.println(results);
        pw.println(results);
    } catch (Exception e) {
        e.printStackTrace();
    }
}

From source file:old.CFS.java

/**
 * uses the low level approach/* w  w  w  .  j  a  v a  2s .com*/
   * @param data
 */
protected static void useLowLevel(Instances data) throws Exception {
    System.out.println("\n3. Low-level");
    AttributeSelection attsel = new AttributeSelection();
    ChiSquaredAttributeEval eval = new ChiSquaredAttributeEval();
    Ranker search = new Ranker();
    search.setThreshold(-1.7976931348623157E308);
    search.setNumToSelect(1000);
    attsel.setEvaluator(eval);
    attsel.setSearch(search);
    attsel.setFolds(10);
    attsel.setXval(true);
    attsel.SelectAttributes(data);
    //    System.out.println(data.toSummaryString());
    //    attsel.selectAttributesCVSplit(data);
    //    attsel.SelectAttributes(data);

    System.out.println(attsel.CrossValidateAttributes());
    //    attsel.SelectAttributes(data);
    //    attsel.selectAttributesCVSplit(data);
    Instances newData = attsel.reduceDimensionality(data);

    int[] indices = attsel.selectedAttributes();
    System.out.println(newData);
    System.out.println("selected attribute indices (starting with 0):\n" + Utils.arrayToString(indices));
}