List of usage examples for weka.classifiers.trees.j48 C45PruneableClassifierTree buildClassifier
public void buildClassifier(Instances data) throws Exception
From source file:marytts.tools.newlanguage.LTSTrainer.java
License:Open Source License
/**
 * Trains the letter-to-sound tree, using binary decision nodes.
 *
 * One C4.5 decision tree is trained per grapheme (predicting the aligned phone chain
 * from a symmetric window of {@code context} graphemes on each side), and the
 * per-grapheme trees are then joined under a single root decision node that
 * switches on the center grapheme.
 *
 * @param minLeafData
 *            the minimum number of instances that have to occur in at least two subsets induced by split
 * @return bigTree the combined CART with one sub-tree per grapheme
 * @throws IOException
 *             if building the feature definition fails
 */
public CART trainTree(int minLeafData) throws IOException {
    // one bucket of training datapoints per grapheme
    Map<String, List<String[]>> grapheme2align = new HashMap<String, List<String[]>>();
    for (String gr : this.graphemeSet) {
        grapheme2align.put(gr, new ArrayList<String[]>());
    }
    // set of all phone chains observed as alignment targets
    Set<String> phChains = new HashSet<String>();
    // for every alignment pair collect counts
    for (int i = 0; i < this.inSplit.size(); i++) {
        StringPair[] alignment = this.getAlignment(i);
        for (int inNr = 0; inNr < alignment.length; inNr++) {
            // System.err.println(alignment[inNr]);
            // quotation signs needed to represent empty string
            String outAlNr = "'" + alignment[inNr].getString2() + "'";
            // TODO: don't consider alignments to more than three characters
            // (length > 5 == more than three chars once the two quote signs are counted)
            if (outAlNr.length() > 5)
                continue;
            phChains.add(outAlNr);
            // storing context and target: 2*context+1 window slots plus one target slot
            String[] datapoint = new String[2 * context + 2];
            for (int ct = 0; ct < 2 * context + 1; ct++) {
                int pos = inNr - context + ct;
                if (pos >= 0 && pos < alignment.length) {
                    datapoint[ct] = alignment[pos].getString1();
                } else {
                    // window position falls outside the word: pad with the literal "null"
                    datapoint[ct] = "null";
                }
            }
            // set target
            datapoint[2 * context + 1] = outAlNr;
            // add datapoint, keyed by the grapheme at the window center
            grapheme2align.get(alignment[inNr].getString1()).add(datapoint);
        }
    }
    // for conversion need feature definition file
    FeatureDefinition fd = this.graphemeFeatureDef(phChains);
    // "att<context+1>" is the feature for the center grapheme of the window
    int centerGrapheme = fd.getFeatureIndex("att" + (context + 1));
    List<CART> stl = new ArrayList<CART>(fd.getNumberOfValues(centerGrapheme));
    // train one decision tree per possible center grapheme
    for (String gr : fd.getPossibleValues(centerGrapheme)) {
        System.out.println(" Training decision tree for: " + gr);
        logger.debug(" Training decision tree for: " + gr);
        ArrayList<Attribute> attributeDeclarations = new ArrayList<Attribute>();
        // attributes with values
        for (int att = 1; att <= context * 2 + 1; att++) {
            // ...collect possible values
            ArrayList<String> attVals = new ArrayList<String>();
            String featureName = "att" + att;
            for (String usableGrapheme : fd.getPossibleValues(fd.getFeatureIndex(featureName))) {
                attVals.add(usableGrapheme);
            }
            attributeDeclarations.add(new Attribute(featureName, attVals));
        }
        List<String[]> datapoints = grapheme2align.get(gr);
        // maybe training is faster with targets limited to grapheme
        Set<String> graphSpecPh = new HashSet<String>();
        for (String[] dp : datapoints) {
            graphSpecPh.add(dp[dp.length - 1]);
        }
        // targetattribute
        // ...collect possible values
        ArrayList<String> targetVals = new ArrayList<String>();
        for (String phc : graphSpecPh) { // todo: use either fd of phChains
            targetVals.add(phc);
        }
        attributeDeclarations.add(new Attribute(TrainedLTS.PREDICTED_STRING_FEATURENAME, targetVals));
        // now, create the dataset adding the datapoints
        Instances data = new Instances(gr, attributeDeclarations, 0);
        // datapoints
        for (String[] point : datapoints) {
            Instance currInst = new DenseInstance(data.numAttributes());
            currInst.setDataset(data);
            for (int i = 0; i < point.length; i++) {
                currInst.setValue(i, point[i]);
            }
            data.add(currInst);
        }
        // Make the last attribute be the class
        data.setClassIndex(data.numAttributes() - 1);
        // build the tree without using the J48 wrapper class
        // standard parameters are:
        // binary split selection with minimum x instances at the leaves, tree is pruned, confidence value,
        // subtree raising, cleanup, don't collapse
        // Here is used a modified version of C45PruneableClassifierTree that allows using Unary Classes (see Issue #51)
        C45PruneableClassifierTree decisionTree;
        try {
            decisionTree = new C45PruneableClassifierTreeWithUnary(
                    new BinC45ModelSelection(minLeafData, data, true), true, 0.25f, true, true, false);
            decisionTree.buildClassifier(data);
        } catch (Exception e) {
            // Weka declares a broad checked Exception; rewrap with the cause preserved
            throw new RuntimeException("couldn't train decisiontree using weka: ", e);
        }
        CART maryTree = TreeConverter.c45toStringCART(decisionTree, fd, data);
        stl.add(maryTree);
    }
    // join the per-grapheme trees under one root node that switches on the center grapheme
    DecisionNode.ByteDecisionNode rootNode = new DecisionNode.ByteDecisionNode(centerGrapheme, stl.size(), fd);
    for (CART st : stl) {
        rootNode.addDaughter(st.getRootNode());
    }
    // record the training settings so they can be reproduced at prediction time
    Properties props = new Properties();
    props.setProperty("lowercase", String.valueOf(convertToLowercase));
    props.setProperty("stress", String.valueOf(considerStress));
    props.setProperty("context", String.valueOf(context));
    CART bigTree = new CART(rootNode, fd, props);
    return bigTree;
}
From source file:marytts.tools.voiceimport.PauseDurationTrainer.java
License:Open Source License
private StringPredictionTree trainTree(Instances data, FeatureDefinition fd) throws Exception { System.out.println("training duration tree (" + data.numInstances() + " instances) ..."); // build the tree without using the J48 wrapper class // standard parameters are: // binary split selection with minimum x instances at the leaves, tree is pruned, confidence value, subtree raising, // cleanup, don't collapse C45PruneableClassifierTree decisionTree = new C45PruneableClassifierTree( new BinC45ModelSelection(2, data, true), true, 0.25f, true, true, false); decisionTree.buildClassifier(data); System.out.println("...done"); return TreeConverter.c45toStringPredictionTree(decisionTree, fd, data); }