org.jpmml.rexp.RandomForestConverter.java Source code

Java tutorial

Introduction

Here is the source code for org.jpmml.rexp.RandomForestConverter.java

Source

/*
 * Copyright (c) 2014 Villu Ruusmann
 *
 * This file is part of JPMML-R
 *
 * JPMML-R is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-R is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-R.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.List;

import com.google.common.math.DoubleMath;
import com.google.common.primitives.UnsignedLong;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.MultipleModelMethodType;
import org.dmg.pmml.Node;
import org.dmg.pmml.Output;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.Segmentation;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.TreeModel;
import org.dmg.pmml.True;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ListFeature;
import org.jpmml.converter.MiningModelUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;

public class RandomForestConverter extends TreeModelConverter<RGenericVector> {

    public RandomForestConverter(RGenericVector randomForest) {
        super(randomForest);
    }

    @Override
    public void encodeFeatures(FeatureMapper featureMapper) {
        RGenericVector randomForest = getObject();

        RGenericVector forest = (RGenericVector) randomForest.getValue("forest");

        RNumberVector<?> y;

        try {
            y = (RNumberVector<?>) randomForest.getValue("y");
        } catch (IllegalArgumentException iae) {
            y = null;
        }

        RNumberVector<?> ncat = (RNumberVector<?>) forest.getValue("ncat");
        RGenericVector xlevels = (RGenericVector) forest.getValue("xlevels");

        try {
            RExp terms = randomForest.getValue("terms");

            // The RF model was trained using the formula interface
            encodeFormula(terms, y, xlevels, ncat, featureMapper);
        } catch (IllegalArgumentException iae) {
            RStringVector xNames;

            try {
                xNames = (RStringVector) randomForest.getValue("xNames");
            } catch (IllegalArgumentException iaeChild) {
                xNames = xlevels.names();
            }

            // The RF model was trained using the matrix (ie. non-formula) interface
            encodeNonFormula(xNames, y, xlevels, ncat, featureMapper);
        }
    }

    @Override
    public MiningModel encodeModel(Schema schema) {
        RGenericVector randomForest = getObject();

        RStringVector type = (RStringVector) randomForest.getValue("type");
        RGenericVector forest = (RGenericVector) randomForest.getValue("forest");

        switch (type.asScalar()) {
        case "regression":
            return encodeRegression(forest, schema);
        case "classification":
            return encodeClassification(forest, schema);
        default:
            throw new IllegalArgumentException();
        }
    }

    private void encodeFormula(RExp terms, RNumberVector<?> y, RGenericVector xlevels, RNumberVector<?> ncat,
            FeatureMapper featureMapper) {
        RStringVector dataClasses = (RStringVector) terms.getAttributeValue("dataClasses");

        RStringVector dataClassNames = dataClasses.names();

        // Dependent variable
        {
            FieldName name = FieldName.create(dataClassNames.getValue(0));
            DataType dataType = RExpUtil.getDataType(dataClasses.getValue(0));

            if (y instanceof RIntegerVector) {
                RIntegerVector factor = (RIntegerVector) y;

                featureMapper.append(name, dataType, factor.getLevelValues());
            } else

            {
                featureMapper.append(name, dataType);
            }
        }

        RStringVector xlevelNames = xlevels.names();

        // Independent variables
        for (int i = 0; i < ncat.size(); i++) {
            int index = (dataClassNames.getValues()).indexOf(xlevelNames.getValue(i));
            if (index < 1) {
                throw new IllegalArgumentException();
            }

            FieldName name = FieldName.create(dataClassNames.getValue(index));
            DataType dataType = RExpUtil.getDataType(dataClasses.getValue(index));

            boolean categorical = ((ncat.getValue(i)).doubleValue() > 1d);
            if (categorical) {
                RStringVector levels = (RStringVector) xlevels.getValue(i);

                featureMapper.append(name, dataType, levels.getValues());
            } else

            {
                featureMapper.append(name, dataType);
            }
        }
    }

    private void encodeNonFormula(RStringVector xNames, RNumberVector<?> y, RGenericVector xlevels,
            RNumberVector<?> ncat, FeatureMapper featureMapper) {

        // Dependent variable
        {
            FieldName name = FieldName.create("_target");

            if (y instanceof RIntegerVector) {
                RIntegerVector factor = (RIntegerVector) y;

                featureMapper.append(name, factor.getLevelValues());
            } else

            {
                featureMapper.append(name, false);
            }
        }

        // Independernt variables
        for (int i = 0; i < ncat.size(); i++) {
            FieldName name = FieldName.create(xNames.getValue(i));

            boolean categorical = ((ncat.getValue(i)).doubleValue() > 1d);
            if (categorical) {
                RStringVector levels = (RStringVector) xlevels.getValue(i);

                featureMapper.append(name, levels.getValues());
            } else

            {
                featureMapper.append(name, false);
            }
        }
    }

    private MiningModel encodeRegression(RGenericVector forest, final Schema schema) {
        RNumberVector<?> leftDaughter = (RNumberVector<?>) forest.getValue("leftDaughter");
        RNumberVector<?> rightDaughter = (RNumberVector<?>) forest.getValue("rightDaughter");
        RDoubleVector nodepred = (RDoubleVector) forest.getValue("nodepred");
        RNumberVector<?> bestvar = (RNumberVector<?>) forest.getValue("bestvar");
        RDoubleVector xbestsplit = (RDoubleVector) forest.getValue("xbestsplit");
        RIntegerVector nrnodes = (RIntegerVector) forest.getValue("nrnodes");
        RDoubleVector ntree = (RDoubleVector) forest.getValue("ntree");

        ScoreEncoder<Double> scoreEncoder = new ScoreEncoder<Double>() {

            @Override
            public String encode(Double value) {
                return ValueUtil.formatValue(value);
            }
        };

        int rows = nrnodes.asScalar();
        int columns = ValueUtil.asInt(ntree.asScalar());

        Schema segmentSchema = schema.toAnonymousSchema();

        List<TreeModel> treeModels = new ArrayList<>();

        for (int i = 0; i < columns; i++) {
            TreeModel treeModel = encodeTreeModel(MiningFunctionType.REGRESSION, scoreEncoder,
                    RExpUtil.getColumn(leftDaughter.getValues(), rows, columns, i),
                    RExpUtil.getColumn(rightDaughter.getValues(), rows, columns, i),
                    RExpUtil.getColumn(nodepred.getValues(), rows, columns, i),
                    RExpUtil.getColumn(bestvar.getValues(), rows, columns, i),
                    RExpUtil.getColumn(xbestsplit.getValues(), rows, columns, i), segmentSchema);

            treeModels.add(treeModel);
        }

        Segmentation segmentation = MiningModelUtil.createSegmentation(MultipleModelMethodType.AVERAGE, treeModels);

        MiningSchema miningSchema = ModelUtil.createMiningSchema(schema);

        MiningModel miningModel = new MiningModel(MiningFunctionType.REGRESSION, miningSchema)
                .setSegmentation(segmentation);

        return miningModel;
    }

    private MiningModel encodeClassification(RGenericVector forest, final Schema schema) {
        RNumberVector<?> bestvar = (RNumberVector<?>) forest.getValue("bestvar");
        RNumberVector<?> treemap = (RNumberVector<?>) forest.getValue("treemap");
        RIntegerVector nodepred = (RIntegerVector) forest.getValue("nodepred");
        RDoubleVector xbestsplit = (RDoubleVector) forest.getValue("xbestsplit");
        RIntegerVector nrnodes = (RIntegerVector) forest.getValue("nrnodes");
        RDoubleVector ntree = (RDoubleVector) forest.getValue("ntree");

        ScoreEncoder<Integer> scoreEncoder = new ScoreEncoder<Integer>() {

            private List<String> targetCategories = schema.getTargetCategories();

            @Override
            public String encode(Integer value) {
                return this.targetCategories.get(value - 1);
            }
        };

        int rows = nrnodes.asScalar();
        int columns = ValueUtil.asInt(ntree.asScalar());

        Schema segmentSchema = schema.toAnonymousSchema();

        List<TreeModel> treeModels = new ArrayList<>();

        for (int i = 0; i < columns; i++) {
            List<? extends Number> daughters = RExpUtil.getColumn(treemap.getValues(), 2 * rows, columns, i);

            TreeModel treeModel = encodeTreeModel(MiningFunctionType.CLASSIFICATION, scoreEncoder,
                    RExpUtil.getColumn(daughters, rows, columns, 0),
                    RExpUtil.getColumn(daughters, rows, columns, 1),
                    RExpUtil.getColumn(nodepred.getValues(), rows, columns, i),
                    RExpUtil.getColumn(bestvar.getValues(), rows, columns, i),
                    RExpUtil.getColumn(xbestsplit.getValues(), rows, columns, i), segmentSchema);

            treeModels.add(treeModel);
        }

        Segmentation segmentation = MiningModelUtil.createSegmentation(MultipleModelMethodType.MAJORITY_VOTE,
                treeModels);

        Output output = ModelUtil.createProbabilityOutput(schema);

        MiningSchema miningSchema = ModelUtil.createMiningSchema(schema);

        MiningModel miningModel = new MiningModel(MiningFunctionType.CLASSIFICATION, miningSchema)
                .setSegmentation(segmentation).setOutput(output);

        return miningModel;
    }

    private <P extends Number> TreeModel encodeTreeModel(MiningFunctionType miningFunction,
            ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter,
            List<P> nodepred, List<? extends Number> bestvar, List<Double> xbestsplit, Schema schema) {
        Node root = new Node().setId("1").setPredicate(new True());

        encodeNode(root, 0, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, schema);

        MiningSchema miningSchema = ModelUtil.createMiningSchema(schema);

        TreeModel treeModel = new TreeModel(miningFunction, miningSchema, root)
                .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

        return treeModel;
    }

    private <P extends Number> void encodeNode(Node node, int i, ScoreEncoder<P> scoreEncoder,
            List<? extends Number> leftDaughter, List<? extends Number> rightDaughter,
            List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, Schema schema) {
        Predicate leftPredicate;
        Predicate rightPredicate;

        int var = ValueUtil.asInt(bestvar.get(i));
        if (var != 0) {
            Feature feature = schema.getFeature(var - 1);

            Double split = xbestsplit.get(i);

            if (feature instanceof ListFeature) {
                ListFeature listFeature = (ListFeature) feature;

                List<String> values = listFeature.getValues();

                leftPredicate = createSimpleSetPredicate(listFeature, selectValues(values, split, true));
                rightPredicate = createSimpleSetPredicate(listFeature, selectValues(values, split, false));
            } else

            if (feature instanceof ContinuousFeature) {
                String value = ValueUtil.formatValue(split);

                leftPredicate = createSimplePredicate(feature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
                rightPredicate = createSimplePredicate(feature, SimplePredicate.Operator.GREATER_THAN, value);
            } else

            {
                throw new IllegalArgumentException();
            }
        } else

        {
            P prediction = nodepred.get(i);

            node.setScore(scoreEncoder.encode(prediction));

            return;
        }

        int left = ValueUtil.asInt(leftDaughter.get(i));
        if (left != 0) {
            Node leftChild = new Node().setId(String.valueOf(left)).setPredicate(leftPredicate);

            encodeNode(leftChild, left - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit,
                    nodepred, schema);

            node.addNodes(leftChild);
        }

        int right = ValueUtil.asInt(rightDaughter.get(i));
        if (right != 0) {
            Node rightChild = new Node().setId(String.valueOf(right)).setPredicate(rightPredicate);

            encodeNode(rightChild, right - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit,
                    nodepred, schema);

            node.addNodes(rightChild);
        }
    }

    static <E> List<E> selectValues(List<E> values, Double split, boolean left) {
        UnsignedLong bits = toUnsignedLong(split.doubleValue());

        List<E> result = new ArrayList<>();

        for (int i = 0; i < values.size(); i++) {
            E value = values.get(i);

            boolean append;

            // Send "true" categories to the left
            if (left) {
                // Test if the least significant bit (LSB) is 1
                append = (bits.mod(RandomForestConverter.TWO)).equals(UnsignedLong.ONE);
            } else

            // Send all other categories to the right
            {
                // Test if the LSB is 0
                append = (bits.mod(RandomForestConverter.TWO)).equals(UnsignedLong.ZERO);
            } // End if

            if (append) {
                result.add(value);
            }

            bits = bits.dividedBy(RandomForestConverter.TWO);
        }

        return result;
    }

    static UnsignedLong toUnsignedLong(double value) {

        if (!DoubleMath.isMathematicalInteger(value)) {
            throw new IllegalArgumentException();
        }

        return UnsignedLong.fromLongBits((long) value);
    }

    static private interface ScoreEncoder<V extends Number> {

        String encode(V value);
    }

    private static final UnsignedLong TWO = UnsignedLong.valueOf(2L);
}