sklearn.tree.TreeModelUtil.java Source code

Java tutorial

Introduction

Here is the source code for sklearn.tree.TreeModelUtil.java

Source

/*
 * Copyright (c) 2015 Villu Ruusmann
 *
 * This file is part of JPMML-SkLearn
 *
 * JPMML-SkLearn is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SkLearn is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-SkLearn.  If not, see <http://www.gnu.org/licenses/>.
 */
package sklearn.tree;

import java.util.ArrayList;
import java.util.List;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import sklearn.Estimator;

public class TreeModelUtil {

    private TreeModelUtil() {
    }

    static public <E extends Estimator & HasTree> List<TreeModel> encodeTreeModelSegmentation(List<E> estimators,
            final MiningFunction miningFunction, final Schema schema) {
        Function<E, TreeModel> function = new Function<E, TreeModel>() {

            private Schema segmentSchema = schema.toAnonymousSchema();

            @Override
            public TreeModel apply(E estimator) {
                return TreeModelUtil.encodeTreeModel(estimator, miningFunction, this.segmentSchema);
            }
        };

        return new ArrayList<>(Lists.transform(estimators, function));
    }

    static public <E extends Estimator & HasTree> TreeModel encodeTreeModel(E estimator,
            MiningFunction miningFunction, Schema schema) {
        Tree tree = estimator.getTree();

        int[] leftChildren = tree.getChildrenLeft();
        int[] rightChildren = tree.getChildrenRight();
        int[] features = tree.getFeature();
        double[] thresholds = tree.getThreshold();
        double[] values = tree.getValues();

        Node root = new Node().setId("1").setPredicate(new True());

        encodeNode(root, 0, leftChildren, rightChildren, features, thresholds, values, miningFunction, schema);

        TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema), root)
                .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

        return treeModel;
    }

    static private void encodeNode(Node node, int index, int[] leftChildren, int[] rightChildren, int[] features,
            double[] thresholds, double[] values, MiningFunction miningFunction, Schema schema) {
        int featureIndex = features[index];

        // A non-leaf (binary split) node
        if (featureIndex >= 0) {
            Feature feature = schema.getFeature(featureIndex);

            float threshold = (float) thresholds[index];

            Predicate leftPredicate;
            Predicate rightPredicate;

            if (feature instanceof ContinuousFeature) {
                ContinuousFeature continuousFeature = (ContinuousFeature) feature;

                String value = ValueUtil.formatValue(threshold);

                leftPredicate = new SimplePredicate(continuousFeature.getName(),
                        SimplePredicate.Operator.LESS_OR_EQUAL).setValue(value);

                rightPredicate = new SimplePredicate(continuousFeature.getName(),
                        SimplePredicate.Operator.GREATER_THAN).setValue(value);
            } else

            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;

                if (threshold < 0 || threshold > 1) {
                    throw new IllegalArgumentException();
                }

                leftPredicate = new SimplePredicate(binaryFeature.getName(), SimplePredicate.Operator.NOT_EQUAL)
                        .setValue(binaryFeature.getValue());

                rightPredicate = new SimplePredicate(binaryFeature.getName(), SimplePredicate.Operator.EQUAL)
                        .setValue(binaryFeature.getValue());
            } else

            {
                throw new IllegalArgumentException();
            }

            int leftIndex = leftChildren[index];
            int rightIndex = rightChildren[index];

            Node leftChild = new Node().setId(String.valueOf(leftIndex + 1)).setPredicate(leftPredicate);

            encodeNode(leftChild, leftIndex, leftChildren, rightChildren, features, thresholds, values,
                    miningFunction, schema);

            Node rightChild = new Node().setId(String.valueOf(rightIndex + 1)).setPredicate(rightPredicate);

            encodeNode(rightChild, rightIndex, leftChildren, rightChildren, features, thresholds, values,
                    miningFunction, schema);

            node.addNodes(leftChild, rightChild);
        } else

        // A leaf node
        {
            if ((MiningFunction.CLASSIFICATION).equals(miningFunction)) {
                List<String> targetCategories = schema.getTargetCategories();

                double[] scoreRecordCounts = getRow(values, leftChildren.length, targetCategories.size(), index);

                double recordCount = 0;

                for (double scoreRecordCount : scoreRecordCounts) {
                    recordCount += scoreRecordCount;
                }

                node.setRecordCount(recordCount);

                String score = null;

                Double probability = null;

                for (int i = 0; i < targetCategories.size(); i++) {
                    String targetCategory = targetCategories.get(i);

                    ScoreDistribution scoreDistribution = new ScoreDistribution(targetCategory,
                            scoreRecordCounts[i]);

                    node.addScoreDistributions(scoreDistribution);

                    double scoreProbability = (scoreRecordCounts[i] / recordCount);

                    if (probability == null || probability.compareTo(scoreProbability) < 0) {
                        score = scoreDistribution.getValue();

                        probability = scoreProbability;
                    }
                }

                node.setScore(score);
            } else

            if ((MiningFunction.REGRESSION).equals(miningFunction)) {
                String score = ValueUtil.formatValue(values[index]);

                node.setScore(score);
            } else

            {
                throw new IllegalArgumentException();
            }
        }
    }

    static private double[] getRow(double[] values, int rows, int columns, int row) {

        if (values.length != (rows * columns)) {
            throw new IllegalArgumentException();
        }

        double[] result = new double[columns];

        System.arraycopy(values, (row * columns), result, 0, columns);

        return result;
    }
}