Example usage for weka.classifiers.trees.j48 C45PruneableClassifierTree buildClassifier

List of usage examples for weka.classifiers.trees.j48 C45PruneableClassifierTree buildClassifier

Introduction

In this page you can find the example usage for weka.classifiers.trees.j48 C45PruneableClassifierTree buildClassifier.

Prototype

public void buildClassifier(Instances data) throws Exception 

Source Link

Document

Method for building a pruneable classifier tree.

Usage

From source file:marytts.tools.newlanguage.LTSTrainer.java

License: Open Source License

/**
 * Train the tree, using binary decision nodes.
 * /*from  w w  w  . java  2 s. com*/
 * @param minLeafData
 *            the minimum number of instances that have to occur in at least two subsets induced by split
 * @return bigTree
 * @throws IOException
 *             IOException
 */
public CART trainTree(int minLeafData) throws IOException {

    Map<String, List<String[]>> grapheme2align = new HashMap<String, List<String[]>>();
    for (String gr : this.graphemeSet) {
        grapheme2align.put(gr, new ArrayList<String[]>());
    }

    Set<String> phChains = new HashSet<String>();

    // for every alignment pair collect counts
    for (int i = 0; i < this.inSplit.size(); i++) {

        StringPair[] alignment = this.getAlignment(i);

        for (int inNr = 0; inNr < alignment.length; inNr++) {

            // System.err.println(alignment[inNr]);

            // quotation signs needed to represent empty string
            String outAlNr = "'" + alignment[inNr].getString2() + "'";

            // TODO: don't consider alignments to more than three characters
            if (outAlNr.length() > 5)
                continue;

            phChains.add(outAlNr);

            // storing context and target
            String[] datapoint = new String[2 * context + 2];

            for (int ct = 0; ct < 2 * context + 1; ct++) {
                int pos = inNr - context + ct;

                if (pos >= 0 && pos < alignment.length) {
                    datapoint[ct] = alignment[pos].getString1();
                } else {
                    datapoint[ct] = "null";
                }

            }

            // set target
            datapoint[2 * context + 1] = outAlNr;

            // add datapoint
            grapheme2align.get(alignment[inNr].getString1()).add(datapoint);
        }
    }

    // for conversion need feature definition file
    FeatureDefinition fd = this.graphemeFeatureDef(phChains);

    int centerGrapheme = fd.getFeatureIndex("att" + (context + 1));

    List<CART> stl = new ArrayList<CART>(fd.getNumberOfValues(centerGrapheme));

    for (String gr : fd.getPossibleValues(centerGrapheme)) {
        System.out.println("      Training decision tree for: " + gr);
        logger.debug("      Training decision tree for: " + gr);

        ArrayList<Attribute> attributeDeclarations = new ArrayList<Attribute>();

        // attributes with values
        for (int att = 1; att <= context * 2 + 1; att++) {

            // ...collect possible values
            ArrayList<String> attVals = new ArrayList<String>();

            String featureName = "att" + att;

            for (String usableGrapheme : fd.getPossibleValues(fd.getFeatureIndex(featureName))) {
                attVals.add(usableGrapheme);
            }

            attributeDeclarations.add(new Attribute(featureName, attVals));
        }

        List<String[]> datapoints = grapheme2align.get(gr);

        // maybe training is faster with targets limited to grapheme
        Set<String> graphSpecPh = new HashSet<String>();
        for (String[] dp : datapoints) {
            graphSpecPh.add(dp[dp.length - 1]);
        }

        // targetattribute
        // ...collect possible values
        ArrayList<String> targetVals = new ArrayList<String>();
        for (String phc : graphSpecPh) {// todo: use either fd of phChains
            targetVals.add(phc);
        }
        attributeDeclarations.add(new Attribute(TrainedLTS.PREDICTED_STRING_FEATURENAME, targetVals));

        // now, create the dataset adding the datapoints
        Instances data = new Instances(gr, attributeDeclarations, 0);

        // datapoints
        for (String[] point : datapoints) {

            Instance currInst = new DenseInstance(data.numAttributes());
            currInst.setDataset(data);

            for (int i = 0; i < point.length; i++) {

                currInst.setValue(i, point[i]);
            }

            data.add(currInst);
        }

        // Make the last attribute be the class
        data.setClassIndex(data.numAttributes() - 1);

        // build the tree without using the J48 wrapper class
        // standard parameters are:
        // binary split selection with minimum x instances at the leaves, tree is pruned, confidenced value, subtree raising,
        // cleanup, don't collapse
        // Here is used a modifed version of C45PruneableClassifierTree that allow using Unary Classes (see Issue #51)
        C45PruneableClassifierTree decisionTree;
        try {
            decisionTree = new C45PruneableClassifierTreeWithUnary(
                    new BinC45ModelSelection(minLeafData, data, true), true, 0.25f, true, true, false);
            decisionTree.buildClassifier(data);
        } catch (Exception e) {
            throw new RuntimeException("couldn't train decisiontree using weka: ", e);
        }

        CART maryTree = TreeConverter.c45toStringCART(decisionTree, fd, data);

        stl.add(maryTree);
    }

    DecisionNode.ByteDecisionNode rootNode = new DecisionNode.ByteDecisionNode(centerGrapheme, stl.size(), fd);
    for (CART st : stl) {
        rootNode.addDaughter(st.getRootNode());
    }

    Properties props = new Properties();
    props.setProperty("lowercase", String.valueOf(convertToLowercase));
    props.setProperty("stress", String.valueOf(considerStress));
    props.setProperty("context", String.valueOf(context));

    CART bigTree = new CART(rootNode, fd, props);

    return bigTree;
}

From source file:marytts.tools.voiceimport.PauseDurationTrainer.java

License: Open Source License

/**
 * Trains a pruned C4.5 tree on the given pause-duration data and converts it
 * into a string prediction tree.
 *
 * @param data the training instances; the class attribute must already be set
 * @param fd feature definition used for the conversion to a MARY tree
 * @return the trained tree converted to a StringPredictionTree
 * @throws Exception if weka fails to build the classifier
 */
private StringPredictionTree trainTree(Instances data, FeatureDefinition fd) throws Exception {

    System.out.println("training duration tree (" + data.numInstances() + " instances) ...");

    // train directly on C45PruneableClassifierTree instead of the J48 wrapper;
    // standard parameters: binary splits with at least 2 instances at the leaves,
    // pruning with confidence 0.25, subtree raising, cleanup, no collapsing
    BinC45ModelSelection modelSelection = new BinC45ModelSelection(2, data, true);
    C45PruneableClassifierTree tree = new C45PruneableClassifierTree(modelSelection, true, 0.25f, true, true, false);
    tree.buildClassifier(data);

    System.out.println("...done");

    return TreeConverter.c45toStringPredictionTree(tree, fd, data);
}