Usage examples for weka.attributeSelection.AttributeSelection.setFolds
public void setFolds(int folds)
From source file: adams.flow.transformer.WekaAttributeSelection.java
License: Open Source License
/**
 * Executes the flow item: runs Weka attribute selection on the incoming data
 * and forwards a {@code WekaAttributeSelectionContainer} as the output token.
 *
 * <p>The input payload is either an {@code Instances} object or a
 * {@code WekaTrainTestSetContainer}, in which case its training set is used.
 * If {@code m_Folds >= 2} the selection is cross-validated over {@code m_Folds}
 * folds, and no reduced/transformed data, statistics or attribute range are
 * produced. Otherwise attributes are selected on the full data, and those
 * extra outputs are generated.
 *
 * @return null if everything is fine, otherwise error message
 */
@Override
protected String doExecute() {
  String result;
  Instances data;
  Instances reduced;
  Instances transformed;
  AttributeSelection eval;
  boolean crossValidate;
  int fold;
  Instances train;
  WekaAttributeSelectionContainer cont;
  SpreadSheet stats;
  int i;
  Row row;
  int[] selected;
  double[][] ranked;
  Range range;
  String rangeStr;
  boolean useReduced;

  result = null;

  try {
    // Accept either plain Instances or a train/test container (use its training set).
    if (m_InputToken.getPayload() instanceof Instances)
      data = (Instances) m_InputToken.getPayload();
    else
      data = (Instances) ((WekaTrainTestSetContainer) m_InputToken.getPayload())
          .getValue(WekaTrainTestSetContainer.VALUE_TRAIN);

    // NOTE(review): result was just set to null above, so this check is always true.
    if (result == null) {
      crossValidate = (m_Folds >= 2);

      // Set up the attribute selection (evaluator + search).
      eval = new AttributeSelection();
      eval.setEvaluator(m_Evaluator);
      eval.setSearch(m_Search);
      eval.setFolds(m_Folds);
      // NOTE(review): m_Seed is cast to int here but passed uncast to Random below;
      // presumably it is a long, so values outside the int range get truncated - TODO confirm.
      eval.setSeed((int) m_Seed);
      eval.setXval(crossValidate);

      // Select attributes: either once per CV training split, or once on all data.
      if (crossValidate) {
        Random random = new Random(m_Seed);
        // Copy before shuffling so the incoming dataset is not reordered in place.
        data = new Instances(data);
        data.randomize(random);
        // Stratification is only applied for a nominal class attribute.
        if ((data.classIndex() > -1) && data.classAttribute().isNominal()) {
          if (isLoggingEnabled())
            getLogger().info("Stratifying instances...");
          data.stratify(m_Folds);
        }
        for (fold = 0; fold < m_Folds; fold++) {
          if (isLoggingEnabled())
            getLogger().info("Creating splits for fold " + (fold + 1) + "...");
          train = data.trainCV(m_Folds, fold, random);
          if (isLoggingEnabled())
            getLogger().info("Selecting attributes using all but fold " + (fold + 1) + "...");
          eval.selectAttributesCVSplit(train);
        }
      }
      else {
        eval.SelectAttributes(data);
      }

      // Reduced (and, for attribute transformers, transformed) data: non-CV mode only.
      reduced = null;
      transformed = null;
      if (!crossValidate) {
        reduced = eval.reduceDimensionality(data);
        if (m_Evaluator instanceof AttributeTransformer)
          transformed = ((AttributeTransformer) m_Evaluator).transformedData(data);
      }

      // Statistics sheet (non-CV mode only).
      // The header row holds attribute names. A single data row holds a score per
      // attribute: the ranking merit for ranked searches, else 1.0 for selected
      // attributes and 0.0 otherwise.
      stats = null;
      if (!crossValidate) {
        stats = new DefaultSpreadSheet();
        row = stats.getHeaderRow();

        // Lay out columns by the reduced dataset when the number of ranked attributes
        // equals its non-class attribute count; otherwise use the original dataset.
        useReduced = false;
        if (m_Search instanceof RankedOutputSearch) {
          i = reduced.numAttributes();
          if (reduced.classIndex() > -1)
            i--;
          ranked = eval.rankedAttributes();
          useReduced = (ranked.length == i);
        }

        if (useReduced) {
          for (i = 0; i < reduced.numAttributes(); i++)
            row.addCell("" + i).setContent(reduced.attribute(i).name());
          row = stats.addRow();
          for (i = 0; i < reduced.numAttributes(); i++)
            row.addCell(i).setContent(0.0);
        }
        else {
          for (i = 0; i < data.numAttributes(); i++)
            row.addCell("" + i).setContent(data.attribute(i).name());
          row = stats.addRow();
          for (i = 0; i < data.numAttributes(); i++)
            row.addCell(i).setContent(0.0);
        }

        if (m_Search instanceof RankedOutputSearch) {
          // NOTE(review): rankedAttributes() is fetched a second time here. Its
          // column-0 indices presumably refer to the original dataset, so when
          // useReduced is true they may not match the reduced-dataset column
          // layout built above - verify.
          ranked = eval.rankedAttributes();
          for (i = 0; i < ranked.length; i++)
            row.getCell((int) ranked[i][0]).setContent(ranked[i][1]);
        }
        else {
          selected = eval.selectedAttributes();
          for (i = 0; i < selected.length; i++)
            row.getCell(selected[i]).setContent(1.0);
        }
      }

      // Selected attribute indices as a range string (non-CV mode only).
      rangeStr = null;
      if (!crossValidate) {
        range = new Range();
        range.setIndices(eval.selectedAttributes());
        rangeStr = range.getRange();
      }

      // Build the output container: CV mode records seed/folds, otherwise stats/range.
      if (crossValidate)
        cont = new WekaAttributeSelectionContainer(data, reduced, transformed, eval, m_Seed, m_Folds);
      else
        cont = new WekaAttributeSelectionContainer(data, reduced, transformed, eval, stats, rangeStr);

      m_OutputToken = new Token(cont);
    }
  }
  catch (Exception e) {
    m_OutputToken = null;
    result = handleException("Failed to process data:", e);
  }

  return result;
}
From source file: miRdup.WekaModule.java
License: Open Source License
public static void attributeSelection(File arff, String outfile) { // load data/*from ww w . j a va 2 s.c o m*/ try { PrintWriter pw = new PrintWriter(new FileWriter(outfile)); DataSource source = new DataSource(arff.toString()); Instances data = source.getDataSet(); if (data.classIndex() == -1) { data.setClassIndex(data.numAttributes() - 1); } AttributeSelection attrsel = new AttributeSelection(); weka.attributeSelection.InfoGainAttributeEval eval = new weka.attributeSelection.InfoGainAttributeEval(); weka.attributeSelection.Ranker rank = new weka.attributeSelection.Ranker(); rank.setOptions(weka.core.Utils.splitOptions("-T -1.7976931348623157E308 -N -1")); if (Main.debug) { System.out.print("Model options: " + rank.getClass().getName().trim() + " "); } for (String s : rank.getOptions()) { System.out.print(s + " "); } attrsel.setEvaluator(eval); attrsel.setSearch(rank); attrsel.setFolds(10); attrsel.SelectAttributes(data); //attrsel.CrossValidateAttributes(); System.out.println(attrsel.toResultsString()); pw.println(attrsel.toResultsString()); //evaluation.crossValidateModel(classifier, data, 10, new Random(1)); pw.flush(); pw.close(); } catch (Exception e) { e.printStackTrace(); } }
From source file: old.CFS.java
/** * uses the low level approach/* w w w . j a v a 2s .com*/ * @param data */ protected static void useLowLevel(Instances data) throws Exception { System.out.println("\n3. Low-level"); AttributeSelection attsel = new AttributeSelection(); ChiSquaredAttributeEval eval = new ChiSquaredAttributeEval(); Ranker search = new Ranker(); search.setThreshold(-1.7976931348623157E308); search.setNumToSelect(1000); attsel.setEvaluator(eval); attsel.setSearch(search); attsel.setFolds(10); attsel.setXval(true); attsel.SelectAttributes(data); // System.out.println(data.toSummaryString()); // attsel.selectAttributesCVSplit(data); // attsel.SelectAttributes(data); System.out.println(attsel.CrossValidateAttributes()); // attsel.SelectAttributes(data); // attsel.selectAttributesCVSplit(data); Instances newData = attsel.reduceDimensionality(data); int[] indices = attsel.selectedAttributes(); System.out.println(newData); System.out.println("selected attribute indices (starting with 0):\n" + Utils.arrayToString(indices)); }