com.cloudera.oryx.app.mllib.rdf.RDFUpdate.java Source code

Introduction

Here is the source code for com.cloudera.oryx.app.mllib.rdf.RDFUpdate.java. This class implements the batch-layer model build for the Oryx random decision forest application: it trains a Spark MLlib RandomForestModel on the input data, evaluates it (accuracy for classification, inverse RMSE for regression), and converts the trained forest to PMML.

Source

/*
 * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */

package com.cloudera.oryx.app.mllib.rdf;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.AtomicLongMap;
import com.typesafe.config.Config;
import org.apache.hadoop.fs.Path;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.configuration.Algo;
import org.apache.spark.mllib.tree.configuration.FeatureType;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.model.Predict;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.tree.model.Split;
import org.dmg.pmml.Array;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.MissingValueStrategyType;
import org.dmg.pmml.Model;
import org.dmg.pmml.MultipleModelMethodType;
import org.dmg.pmml.Node;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.Segment;
import org.dmg.pmml.Segmentation;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.SimpleSetPredicate;
import org.dmg.pmml.TreeModel;
import org.dmg.pmml.True;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.collection.JavaConversions;

import com.cloudera.oryx.app.common.fn.MLFunctions;
import com.cloudera.oryx.app.pmml.AppPMMLUtils;
import com.cloudera.oryx.app.rdf.RDFPMMLUtils;
import com.cloudera.oryx.app.rdf.ToExampleFn;
import com.cloudera.oryx.app.rdf.example.Example;
import com.cloudera.oryx.app.rdf.tree.DecisionForest;
import com.cloudera.oryx.app.schema.CategoricalValueEncodings;
import com.cloudera.oryx.app.schema.InputSchema;
import com.cloudera.oryx.common.collection.Pair;
import com.cloudera.oryx.common.pmml.PMMLUtils;
import com.cloudera.oryx.common.random.RandomManager;
import com.cloudera.oryx.common.text.TextUtils;
import com.cloudera.oryx.ml.MLUpdate;
import com.cloudera.oryx.ml.param.HyperParamValues;
import com.cloudera.oryx.ml.param.HyperParams;

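/**
 * Update function that builds and evaluates random decision forest models in the batch
 * layer, using Spark MLlib's {@link RandomForest} implementation, and converts the
 * resulting model to PMML.
 */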
public final class RDFUpdate extends MLUpdate<String> {

    private static final Logger log = LoggerFactory.getLogger(RDFUpdate.class);

    private final int numTrees;
    private final List<HyperParamValues<?>> hyperParamValues;
    private final InputSchema inputSchema;

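    /**
     * @param config app configuration; reads {@code oryx.rdf.num-trees}, the
     *  {@code oryx.rdf.hyperparams.*} keys, and the input schema configuration
     */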
    public RDFUpdate(Config config) {
        super(config);
        numTrees = config.getInt("oryx.rdf.num-trees");
        Preconditions.checkArgument(numTrees >= 1);
        hyperParamValues = Arrays.asList(
                HyperParams.fromConfig(config, "oryx.rdf.hyperparams.max-split-candidates"),
                HyperParams.fromConfig(config, "oryx.rdf.hyperparams.max-depth"),
                HyperParams.fromConfig(config, "oryx.rdf.hyperparams.impurity"));

        inputSchema = new InputSchema(config);
        Preconditions.checkArgument(inputSchema.hasTarget());
    }

    @Override
    public List<HyperParamValues<?>> getHyperParameterValues() {
        return hyperParamValues;
    }

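    /**
     * @param sparkContext active Spark Context
     * @param trainData training data on which to build a model
     * @param hyperParameters ordered list of hyperparameter values: max split candidates,
     *  max depth, and impurity
     * @param candidatePath directory where additional model files can be written (unused here)
     * @return PMML representation of a random decision forest trained on the given data
     */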
    @Override
    public PMML buildModel(JavaSparkContext sparkContext, JavaRDD<String> trainData, List<?> hyperParameters,
            Path candidatePath) {

        int maxSplitCandidates = (Integer) hyperParameters.get(0);
        int maxDepth = (Integer) hyperParameters.get(1);
        String impurity = (String) hyperParameters.get(2);
        Preconditions.checkArgument(maxSplitCandidates >= 2, "max-split-candidates must be at least 2");
        Preconditions.checkArgument(maxDepth > 0, "max-depth must be at least 1");

        JavaRDD<String[]> parsedRDD = trainData.map(MLFunctions.PARSE_FN);
        CategoricalValueEncodings categoricalValueEncodings = new CategoricalValueEncodings(
                getDistinctValues(parsedRDD));
        JavaRDD<LabeledPoint> trainPointData = parseToLabeledPointRDD(parsedRDD, categoricalValueEncodings);

        Map<Integer, Integer> categoryInfo = categoricalValueEncodings.getCategoryCounts();
        categoryInfo.remove(inputSchema.getTargetFeatureIndex()); // Don't specify target count
        // Need to translate overall feature indices to predictor indices
        Map<Integer, Integer> categoryInfoByPredictor = new HashMap<>(categoryInfo.size());
        for (Map.Entry<Integer, Integer> e : categoryInfo.entrySet()) {
            categoryInfoByPredictor.put(inputSchema.featureToPredictorIndex(e.getKey()), e.getValue());
        }

        int seed = RandomManager.getRandom().nextInt();

        RandomForestModel model;
        if (inputSchema.isClassification()) {
            int numTargetClasses = categoricalValueEncodings.getValueCount(inputSchema.getTargetFeatureIndex());
            model = RandomForest.trainClassifier(trainPointData, numTargetClasses, categoryInfoByPredictor,
                    numTrees, "auto", impurity, maxDepth, maxSplitCandidates, seed);
        } else {
            model = RandomForest.trainRegressor(trainPointData, categoryInfoByPredictor, numTrees, "auto", impurity,
                    maxDepth, maxSplitCandidates, seed);
        }

        List<Map<Integer, Long>> treeNodeIDCounts = treeNodeExampleCounts(trainPointData, model);
        Map<Integer, Long> predictorIndexCounts = predictorExampleCounts(trainPointData, model);

        return rdfModelToPMML(model, categoricalValueEncodings, maxDepth, maxSplitCandidates, impurity,
                treeNodeIDCounts, predictorIndexCounts);
    }

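    /**
     * @param sparkContext active Spark Context
     * @param model PMML model to evaluate
     * @param modelParentPath directory containing the model file (unused here)
     * @param testData held-out data to evaluate against
     * @param trainData training data (unused here)
     * @return accuracy for classification tasks, or inverse RMSE for regression tasks,
     *  so that a higher value is always better
     */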
    @Override
    public double evaluate(JavaSparkContext sparkContext, PMML model, Path modelParentPath,
            JavaRDD<String> testData, JavaRDD<String> trainData) {
        RDFPMMLUtils.validatePMMLVsSchema(model, inputSchema);
        Pair<DecisionForest, CategoricalValueEncodings> forestAndEncoding = RDFPMMLUtils.read(model);
        DecisionForest forest = forestAndEncoding.getFirst();
        CategoricalValueEncodings valueEncodings = forestAndEncoding.getSecond();

        JavaRDD<Example> examplesRDD = testData.map(MLFunctions.PARSE_FN)
                .map(new ToExampleFn(inputSchema, valueEncodings));

        double eval;
        if (inputSchema.isClassification()) {
            double accuracy = Evaluation.accuracy(forest, examplesRDD);
            log.info("Accuracy: {}", accuracy);
            eval = accuracy;
        } else {
            double rmse = Evaluation.rmse(forest, examplesRDD);
            log.info("RMSE: {}", rmse);
            eval = 1.0 / rmse;
        }
        return eval;
    }

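    /**
     * @param parsedRDD parsed input data as String arrays
     * @return map of categorical feature index to the distinct values that feature takes
     *  on in the input, computed per partition and then merged
     */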
    private Map<Integer, Collection<String>> getDistinctValues(JavaRDD<String[]> parsedRDD) {
        final List<Integer> categoricalIndices = new ArrayList<>();
        for (int i = 0; i < inputSchema.getNumFeatures(); i++) {
            if (inputSchema.isCategorical(i)) {
                categoricalIndices.add(i);
            }
        }

        JavaRDD<Map<Integer, Collection<String>>> distinctValuesByPartition = parsedRDD
                .mapPartitions(new FlatMapFunction<Iterator<String[]>, Map<Integer, Collection<String>>>() {
                    @Override
                    public Iterable<Map<Integer, Collection<String>>> call(Iterator<String[]> data) {
                        Map<Integer, Collection<String>> distinctCategoricalValues = new HashMap<>();
                        for (int i : categoricalIndices) {
                            distinctCategoricalValues.put(i, new HashSet<String>());
                        }
                        while (data.hasNext()) {
                            String[] datum = data.next();
                            for (Map.Entry<Integer, Collection<String>> e : distinctCategoricalValues.entrySet()) {
                                e.getValue().add(datum[e.getKey()]);
                            }
                        }
                        return Collections.singletonList(distinctCategoricalValues);
                    }
                });

        return distinctValuesByPartition.reduce(
                new Function2<Map<Integer, Collection<String>>, Map<Integer, Collection<String>>, Map<Integer, Collection<String>>>() {
                    @Override
                    public Map<Integer, Collection<String>> call(Map<Integer, Collection<String>> v1,
                            Map<Integer, Collection<String>> v2) {
                        for (Map.Entry<Integer, Collection<String>> e : v1.entrySet()) {
                            e.getValue().addAll(v2.get(e.getKey()));
                        }
                        return v1;
                    }
                });
    }

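    /**
     * @param parsedRDD parsed input data as String arrays
     * @param categoricalValueEncodings encodings of categorical values as numbers
     * @return input encoded as MLlib {@link LabeledPoint}s: the target as the label and
     *  the remaining predictors as a dense numeric vector
     */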
    private JavaRDD<LabeledPoint> parseToLabeledPointRDD(JavaRDD<String[]> parsedRDD,
            final CategoricalValueEncodings categoricalValueEncodings) {

        return parsedRDD.map(new Function<String[], LabeledPoint>() {
            @Override
            public LabeledPoint call(String[] data) {
                double[] features = new double[inputSchema.getNumPredictors()];
                double target = Double.NaN;
                for (int featureIndex = 0; featureIndex < data.length; featureIndex++) {
                    double encoded;
                    if (inputSchema.isNumeric(featureIndex)) {
                        encoded = Double.parseDouble(data[featureIndex]);
                    } else if (inputSchema.isCategorical(featureIndex)) {
                        Map<String, Integer> valueEncoding = categoricalValueEncodings
                                .getValueEncodingMap(featureIndex);
                        encoded = valueEncoding.get(data[featureIndex]);
                    } else {
                        continue;
                    }
                    if (inputSchema.isTarget(featureIndex)) {
                        target = encoded;
                    } else {
                        features[inputSchema.featureToPredictorIndex(featureIndex)] = encoded;
                    }
                }
                Preconditions.checkState(!Double.isNaN(target));
                return new LabeledPoint(target, Vectors.dense(features));
            }
        });
    }

    /**
     * @param trainPointData data to run down trees
     * @param model random decision forest model to count on
     * @return maps of node IDs to the count of training examples that reached that node, one
     *  per tree in the model
     * @see #predictorExampleCounts(JavaRDD,RandomForestModel)
     */
    private static List<Map<Integer, Long>> treeNodeExampleCounts(JavaRDD<LabeledPoint> trainPointData,
            final RandomForestModel model) {
        List<AtomicLongMap<Integer>> maps = trainPointData
                .mapPartitions(new FlatMapFunction<Iterator<LabeledPoint>, List<AtomicLongMap<Integer>>>() {
                    @Override
                    public Iterable<List<AtomicLongMap<Integer>>> call(Iterator<LabeledPoint> data) {
                        DecisionTreeModel[] trees = model.trees();
                        int numTrees = trees.length;
                        List<AtomicLongMap<Integer>> treeNodeIDCounts = new ArrayList<>(numTrees);
                        for (int i = 0; i < numTrees; i++) {
                            treeNodeIDCounts.add(AtomicLongMap.<Integer>create());
                        }
                        while (data.hasNext()) {
                            LabeledPoint datum = data.next();
                            double[] featureVector = datum.features().toArray();
                            for (int i = 0; i < trees.length; i++) {
                                DecisionTreeModel tree = trees[i];
                                AtomicLongMap<Integer> nodeIDCount = treeNodeIDCounts.get(i);
                                org.apache.spark.mllib.tree.model.Node node = tree.topNode();
                                // This logic cloned from Node.predict:
                                while (!node.isLeaf()) {
                                    // Count node ID
                                    nodeIDCount.incrementAndGet(node.id());
                                    Split split = node.split().get();
                                    int featureIndex = split.feature();
                                    node = nextNode(featureVector, node, split, featureIndex);
                                }
                                nodeIDCount.incrementAndGet(node.id());
                            }
                        }
                        return Collections.singleton(treeNodeIDCounts);
                    }
                })
                .reduce(new Function2<List<AtomicLongMap<Integer>>, List<AtomicLongMap<Integer>>, List<AtomicLongMap<Integer>>>() {
                    @Override
                    public List<AtomicLongMap<Integer>> call(List<AtomicLongMap<Integer>> a,
                            List<AtomicLongMap<Integer>> b) {
                        Preconditions.checkArgument(a.size() == b.size());
                        for (int i = 0; i < a.size(); i++) {
                            merge(a.get(i), b.get(i));
                        }
                        return a;
                    }
                });

        List<Map<Integer, Long>> result = new ArrayList<>(maps.size());
        for (AtomicLongMap<Integer> map : maps) {
            result.add(map.asMap());
        }
        return result;
    }

    /**
     * @param trainPointData data to run down trees
     * @param model random decision forest model to count on
     * @return map of predictor index to the number of training examples that reached a
     *  node whose decision is based on that feature. The index is among predictors, not all
     *  features, since there are fewer predictors than features. That is, the index will
     *  match the one used in the {@link RandomForestModel}.
     */
    private static Map<Integer, Long> predictorExampleCounts(JavaRDD<LabeledPoint> trainPointData,
            final RandomForestModel model) {
        return trainPointData.mapPartitions(new FlatMapFunction<Iterator<LabeledPoint>, AtomicLongMap<Integer>>() {
            @Override
            public Iterable<AtomicLongMap<Integer>> call(Iterator<LabeledPoint> data) {
                AtomicLongMap<Integer> featureIndexCount = AtomicLongMap.create();
                while (data.hasNext()) {
                    LabeledPoint datum = data.next();
                    double[] featureVector = datum.features().toArray();
                    for (DecisionTreeModel tree : model.trees()) {
                        org.apache.spark.mllib.tree.model.Node node = tree.topNode();
                        // This logic cloned from Node.predict:
                        while (!node.isLeaf()) {
                            Split split = node.split().get();
                            int featureIndex = split.feature();
                            // Count feature
                            featureIndexCount.incrementAndGet(featureIndex);
                            node = nextNode(featureVector, node, split, featureIndex);
                        }
                    }
                }
                return Collections.singleton(featureIndexCount);
            }
        }).reduce(new Function2<AtomicLongMap<Integer>, AtomicLongMap<Integer>, AtomicLongMap<Integer>>() {
            @Override
            public AtomicLongMap<Integer> call(AtomicLongMap<Integer> a, AtomicLongMap<Integer> b) {
                return merge(a, b);
            }
        }).asMap();
    }

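    /**
     * @return the child node that the given feature vector selects: for continuous
     *  features, left if the value is at most the split threshold, else right; for
     *  categorical features, left if the value is in the split's category set, else right
     */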
    private static org.apache.spark.mllib.tree.model.Node nextNode(double[] featureVector,
            org.apache.spark.mllib.tree.model.Node node, Split split, int featureIndex) {
        double featureValue = featureVector[featureIndex];
        if (split.featureType().equals(FeatureType.Continuous())) {
            if (featureValue <= split.threshold()) {
                return node.leftNode().get();
            } else {
                return node.rightNode().get();
            }
        } else {
            if (split.categories().contains(featureValue)) {
                return node.leftNode().get();
            } else {
                return node.rightNode().get();
            }
        }
    }

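    /**
     * Adds all counts in {@code b} into {@code a} and returns {@code a}.
     */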
    private static AtomicLongMap<Integer> merge(AtomicLongMap<Integer> a, AtomicLongMap<Integer> b) {
        for (Map.Entry<Integer, Long> e : b.asMap().entrySet()) {
            a.addAndGet(e.getKey(), e.getValue());
        }
        return a;
    }

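    /**
     * Translates an MLlib {@link RandomForestModel} to PMML: a single {@link TreeModel}
     * when there is one tree, otherwise a {@link MiningModel} whose segments vote
     * (classification) or average (regression) over the trees. Hyperparameters are
     * recorded as PMML extensions.
     */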
    private PMML rdfModelToPMML(RandomForestModel rfModel, CategoricalValueEncodings categoricalValueEncodings,
            int maxDepth, int maxSplitCandidates, String impurity, List<Map<Integer, Long>> nodeIDCounts,
            Map<Integer, Long> predictorIndexCounts) {

        boolean classificationTask = rfModel.algo().equals(Algo.Classification());
        Preconditions.checkState(classificationTask == inputSchema.isClassification());

        DecisionTreeModel[] trees = rfModel.trees();

        Model model;
        if (trees.length == 1) {
            model = toTreeModel(trees[0], categoricalValueEncodings, nodeIDCounts.get(0));
        } else {
            MiningModel miningModel = new MiningModel();
            model = miningModel;
            MultipleModelMethodType multipleModelMethodType = classificationTask
                    ? MultipleModelMethodType.WEIGHTED_MAJORITY_VOTE
                    : MultipleModelMethodType.WEIGHTED_AVERAGE;
            List<Segment> segments = new ArrayList<>(trees.length);
            for (int treeID = 0; treeID < trees.length; treeID++) {
                TreeModel treeModel = toTreeModel(trees[treeID], categoricalValueEncodings,
                        nodeIDCounts.get(treeID));
                Segment segment = new Segment();
                segment.setId(Integer.toString(treeID));
                segment.setPredicate(new True());
                segment.setModel(treeModel);
                segment.setWeight(1.0); // No weights in MLlib impl now
                segments.add(segment);
            }
            miningModel.setSegmentation(new Segmentation(multipleModelMethodType, segments));
        }

        model.setFunctionName(
                classificationTask ? MiningFunctionType.CLASSIFICATION : MiningFunctionType.REGRESSION);

        double[] importances = countsToImportances(predictorIndexCounts);
        model.setMiningSchema(AppPMMLUtils.buildMiningSchema(inputSchema, importances));
        DataDictionary dictionary = AppPMMLUtils.buildDataDictionary(inputSchema, categoricalValueEncodings);

        PMML pmml = PMMLUtils.buildSkeletonPMML();
        pmml.setDataDictionary(dictionary);
        pmml.getModels().add(model);

        AppPMMLUtils.addExtension(pmml, "maxDepth", maxDepth);
        AppPMMLUtils.addExtension(pmml, "maxSplitCandidates", maxSplitCandidates);
        AppPMMLUtils.addExtension(pmml, "impurity", impurity);

        return pmml;
    }

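    /**
     * Translates a single MLlib {@link DecisionTreeModel} into a PMML {@link TreeModel}
     * by walking the tree breadth-first, attaching to each node the predicate that selects
     * it from its parent, its training record count, and, for leaves, a score or score
     * distribution.
     */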
    private TreeModel toTreeModel(DecisionTreeModel dtModel, CategoricalValueEncodings categoricalValueEncodings,
            Map<Integer, Long> nodeIDCounts) {

        boolean classificationTask = dtModel.algo().equals(Algo.Classification());
        Preconditions.checkState(classificationTask == inputSchema.isClassification());

        Node root = new Node();
        root.setId("r");

        Queue<Node> modelNodes = new ArrayDeque<>();
        modelNodes.add(root);

        Queue<Pair<org.apache.spark.mllib.tree.model.Node, Split>> treeNodes = new ArrayDeque<>();
        treeNodes.add(new Pair<>(dtModel.topNode(), (Split) null));

        while (!treeNodes.isEmpty()) {

            Pair<org.apache.spark.mllib.tree.model.Node, Split> treeNodePredicate = treeNodes.remove();
            Node modelNode = modelNodes.remove();

            // This is the decision that got us here from the parent, if any;
            // not the predicate at this node
            Predicate predicate = buildPredicate(treeNodePredicate.getSecond(), categoricalValueEncodings);
            modelNode.setPredicate(predicate);

            org.apache.spark.mllib.tree.model.Node treeNode = treeNodePredicate.getFirst();
            long nodeCount = nodeIDCounts.get(treeNode.id());
            modelNode.setRecordCount((double) nodeCount);

            if (treeNode.isLeaf()) {

                Predict prediction = treeNode.predict();
                if (classificationTask) {
                    int targetEncodedValue = (int) prediction.predict();
                    Map<Integer, String> targetEncodingToValue = categoricalValueEncodings
                            .getEncodingValueMap(inputSchema.getTargetFeatureIndex());
                    String predictedCategoricalValue = targetEncodingToValue.get(targetEncodedValue);
                    double confidence = prediction.prob();
                    Preconditions.checkState(confidence >= 0.0 && confidence <= 1.0);
                    // Slightly faked 'record' count: the probability of the predicted class
                    // times the record count at the node
                    long pseudoSDRecordCount = Math.round(confidence * nodeCount);
                    ScoreDistribution distribution = new ScoreDistribution(predictedCategoricalValue,
                            pseudoSDRecordCount);
                    distribution.setConfidence(confidence);
                    modelNode.getScoreDistributions().add(distribution);
                } else {
                    // Regression: the raw double prediction is the score; don't truncate to int
                    modelNode.setScore(Double.toString(prediction.predict()));
                }

            } else {

                Split split = treeNode.split().get();

                Node positiveModelNode = new Node();
                positiveModelNode.setId(modelNode.getId() + '+');
                modelNode.getNodes().add(positiveModelNode);
                Node negativeModelNode = new Node();
                negativeModelNode.setId(modelNode.getId() + '-');
                modelNode.getNodes().add(negativeModelNode);

                org.apache.spark.mllib.tree.model.Node rightTreeNode = treeNode.rightNode().get();
                org.apache.spark.mllib.tree.model.Node leftTreeNode = treeNode.leftNode().get();

                boolean defaultRight = nodeIDCounts.get(rightTreeNode.id()) > nodeIDCounts.get(leftTreeNode.id());
                modelNode.setDefaultChild(defaultRight ? positiveModelNode.getId() : negativeModelNode.getId());

                // Right node is "positive", so carries the predicate. It must evaluate first
                // and therefore come first in the tree
                modelNodes.add(positiveModelNode);
                modelNodes.add(negativeModelNode);
                treeNodes.add(new Pair<>(rightTreeNode, split));
                treeNodes.add(new Pair<>(leftTreeNode, (Split) null));

            }

        }

        TreeModel treeModel = new TreeModel();
        treeModel.setNode(root);
        treeModel.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
        treeModel.setMissingValueStrategy(MissingValueStrategyType.DEFAULT_CHILD);
        return treeModel;
    }

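    /**
     * @param split the MLlib split leading to the node, or null for the root or a left child
     * @return PMML predicate selecting the node: {@link True} for the root and left children
     *  (evaluated second), a "not in" set predicate for categorical splits, or a
     *  greater-than predicate for numeric splits
     */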
    private Predicate buildPredicate(Split split, CategoricalValueEncodings categoricalValueEncodings) {
        if (split == null) {
            // Left child always applies, but is evaluated second
            return new True();
        }

        int featureIndex = inputSchema.predictorToFeatureIndex(split.feature());
        FieldName fieldName = FieldName.create(inputSchema.getFeatureNames().get(featureIndex));

        if (split.featureType().equals(FeatureType.Categorical())) {
            // Note that categories in MLlib model select the *left* child but the
            // convention here will be that the predicate selects the *right* child
            // So the predicate will evaluate "not in" this set
            // More ugly casting
            @SuppressWarnings("unchecked")
            List<Double> javaCategories = (List<Double>) (List<?>) JavaConversions
                    .seqAsJavaList(split.categories());
            Set<Integer> negativeEncodings = new HashSet<>(javaCategories.size());
            for (double category : javaCategories) {
                negativeEncodings.add((int) category);
            }

            Map<Integer, String> encodingToValue = categoricalValueEncodings.getEncodingValueMap(featureIndex);
            List<String> negativeValues = new ArrayList<>();
            for (int negativeEncoding : negativeEncodings) {
                negativeValues.add(encodingToValue.get(negativeEncoding));
            }

            String joinedValues = TextUtils.joinPMMLDelimited(negativeValues);
            return new SimpleSetPredicate(fieldName, SimpleSetPredicate.BooleanOperator.IS_NOT_IN,
                    new Array(Array.Type.STRING, joinedValues));

        } else {
            // For MLlib, left means <= threshold, so right means >
            SimplePredicate numericPredicate = new SimplePredicate(fieldName,
                    SimplePredicate.Operator.GREATER_THAN);
            numericPredicate.setValue(Double.toString(split.threshold()));
            return numericPredicate;
        }
    }

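    /**
     * @param predictorIndexCounts count of examples reaching a decision on each predictor
     * @return importances in [0,1], one per predictor: each predictor's count as a
     *  fraction of the total count
     */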
    private double[] countsToImportances(Map<Integer, Long> predictorIndexCounts) {
        double[] importances = new double[inputSchema.getNumPredictors()];
        long total = 0L;
        for (long count : predictorIndexCounts.values()) {
            total += count;
        }
        Preconditions.checkArgument(total > 0);
        for (Map.Entry<Integer, Long> e : predictorIndexCounts.entrySet()) {
            importances[e.getKey()] = (double) e.getValue() / total;
        }
        return importances;
    }

}
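
Example

As a footnote to the listing above, here is a minimal, hypothetical sketch (not part of the Oryx source) of supplying the oryx.rdf.* configuration keys that RDFUpdate's constructor reads. The key names are taken from the constructor; the values are illustrative, and a complete configuration would also need the input schema settings that InputSchema requires.

import com.typesafe.config.Config;
import com.typesafe.config.ConfigFactory;

public class RDFUpdateConfigSketch {
    public static void main(String[] args) {
        // Only the oryx.rdf.* key names below come from RDFUpdate's constructor;
        // the values are illustrative. "entropy" and "gini" are MLlib's
        // impurities for classification; "variance" is used for regression.
        Config config = ConfigFactory.parseString(
                "oryx.rdf.num-trees = 20\n"
                + "oryx.rdf.hyperparams.max-split-candidates = 100\n"
                + "oryx.rdf.hyperparams.max-depth = 8\n"
                + "oryx.rdf.hyperparams.impurity = entropy\n");
        // A real deployment would layer this over the full Oryx configuration,
        // e.g. config.withFallback(ConfigFactory.load()), before calling
        // new RDFUpdate(config), which also validates the input schema.
        System.out.println(config.getInt("oryx.rdf.num-trees")); // prints 20
    }
}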