sklearn_pandas.DataFrameMapper.java Source code

Java tutorial

Introduction

Here is the source code for sklearn_pandas.DataFrameMapper.java

Source

/*
 * Copyright (c) 2015 Villu Ruusmann
 *
 * This file is part of JPMML-SkLearn
 *
 * JPMML-SkLearn is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SkLearn is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-SkLearn.  If not, see <http://www.gnu.org/licenses/>.
 */
package sklearn_pandas;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import net.razorvine.pickle.objects.ClassDict;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.jpmml.converter.Feature;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.FeatureMapper;
import org.jpmml.sklearn.TupleUtil;
import sklearn.Transformer;

/**
 * Unpickled form of a {@code sklearn_pandas.DataFrameMapper} object.
 *
 * <p>The mapper's state is read from its {@code "features"} entry, which is a list of rows. In each row:</p>
 * <ul>
 *   <li>element 0 holds the column key(s), either a single String or a List of Strings;</li>
 *   <li>element 1 holds the transformer specification, which is one of: {@code null}, a single Transformer,
 *   a List of Transformers, or a {@link TransformerPipeline}.</li>
 * </ul>
 */
public class DataFrameMapper extends ClassDict {

    /**
     * Converts one mapping-key element to a column name, rejecting anything that is not a {@link String}.
     * Stateless, so a single shared instance is used instead of allocating one per call.
     */
    private static final Function<Object, String> KEY_FUNCTION = new Function<Object, String>() {

        @Override
        public String apply(Object object) {

            if (object instanceof String) {
                return (String) object;
            }

            throw new IllegalArgumentException(
                    "The key object (" + ClassDictUtil.formatClass(object) + ") is not a String");
        }
    };

    /**
     * Converts one mapping-value element to a {@link Transformer}, rejecting anything that is not one.
     * Stateless, so a single shared instance is used instead of allocating one per call.
     */
    private static final Function<Object, Transformer> VALUE_FUNCTION = new Function<Object, Transformer>() {

        @Override
        public Transformer apply(Object object) {

            if (object instanceof Transformer) {
                return (Transformer) object;
            }

            throw new IllegalArgumentException("The value object (" + ClassDictUtil.formatClass(object)
                    + ") is not a Transformer or is not a supported Transformer subclass");
        }
    };

    /**
     * @param module module name, passed unchanged to {@link ClassDict}
     * @param name class name, passed unchanged to {@link ClassDict}
     */
    public DataFrameMapper(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes every mapping row of this mapper into {@code featureMapper}.
     *
     * <p>For each row, this method performs the following steps in order:</p>
     * <ol>
     *   <li>Creates one {@link DataField} per column name and wraps it in a {@link WildcardFeature}.</li>
     *   <li>Applies the row's transformers one after another. Before each transformer runs, every current
     *   feature's type is updated with that transformer's op type and data type.</li>
     *   <li>Registers the final feature list via {@link FeatureMapper#addRow}.</li>
     * </ol>
     *
     * @param featureMapper receives the created data fields, type updates and encoded rows
     * @throws IllegalArgumentException if a row's key or value is malformed
     */
    public void encodeFeatures(FeatureMapper featureMapper) {
        List<Object[]> rows = getFeatures();

        for (Object[] row : rows) {
            List<String> names = getNameList(row);

            List<Feature> features = new ArrayList<>(names.size());
            for (String name : names) {
                DataField dataField = featureMapper.createDataField(FieldName.create(name));

                features.add(new WildcardFeature(dataField));
            }

            // The column names are handed to every transformer in the chain as the feature ids
            List<String> ids = new ArrayList<>(names);

            List<Transformer> transformers = getTransformerList(row);
            for (Transformer transformer : transformers) {

                // Record this transformer's op/data type against the fields currently feeding into it
                for (Feature feature : features) {
                    featureMapper.updateType(feature.getName(), transformer.getOpType(), transformer.getDataType());
                }

                // Each transformer consumes the previous transformer's output features
                features = transformer.encodeFeatures(ids, features, featureMapper);
            }

            featureMapper.addRow(features);
        }
    }

    /**
     * Returns the raw {@code "features"} entry of this mapper.
     *
     * <p>The entry is cast without element checks; each element is expected to be an
     * {@code Object[]} row (see the class documentation).</p>
     *
     * @return the mapping rows as stored in this object
     */
    @SuppressWarnings("unchecked")
    public List<Object[]> getFeatures() {
        // Unpickled data is untyped; the element type can only be assumed, not checked, here
        Object features = get("features");

        return (List<Object[]>) features;
    }

    /**
     * Extracts the column name(s) from element 0 of a mapping row.
     *
     * @param feature the mapping row
     * @return a fresh List of the names when element 0 is a List, otherwise a singleton list
     * @throws IllegalArgumentException "Invalid mapping key", wrapping any RuntimeException raised while reading
     *     element 0 (for example, a non-String key or a too-short row)
     */
    private static List<String> getNameList(Object[] feature) {

        try {
            Object key = feature[0];

            if (key instanceof List) {
                // Copy eagerly so that Lists.transform's lazy conversion (and its errors) happens inside this try
                return new ArrayList<>(Lists.transform((List<?>) key, KEY_FUNCTION));
            }

            return Collections.singletonList(KEY_FUNCTION.apply(key));
        } catch (RuntimeException re) {
            throw new IllegalArgumentException("Invalid mapping key", re);
        }
    }

    /**
     * Extracts the transformer chain from element 1 of a mapping row.
     *
     * <p>Element 1 is handled as follows:</p>
     * <ul>
     *   <li>{@code null}: an empty list is returned;</li>
     *   <li>a {@link TransformerPipeline}: element 1 of every pipeline step is returned;</li>
     *   <li>a List: each of its elements is returned;</li>
     *   <li>anything else: it is returned as a single-element list.</li>
     * </ul>
     *
     * @param feature the mapping row
     * @return the transformers to apply, in order
     * @throws IllegalArgumentException "Invalid mapping value", wrapping any RuntimeException raised while reading
     *     element 1 (for example, a non-Transformer value or a too-short row)
     */
    private static List<Transformer> getTransformerList(Object[] feature) {

        try {
            Object value = feature[1];

            if (value == null) {
                return Collections.emptyList();
            } // End if

            if (value instanceof TransformerPipeline) {
                TransformerPipeline transformerPipeline = (TransformerPipeline) value;

                List<Object[]> steps = transformerPipeline.getSteps();

                // Element 1 of each step (presumably the transformer of a (name, transformer) pair)
                List<?> stepValues = (List<?>) TupleUtil.extractElement(steps, 1);

                return new ArrayList<>(Lists.transform(stepValues, VALUE_FUNCTION));
            } // End if

            if (value instanceof List) {
                return new ArrayList<>(Lists.transform((List<?>) value, VALUE_FUNCTION));
            }

            return Collections.singletonList(VALUE_FUNCTION.apply(value));
        } catch (RuntimeException re) {
            throw new IllegalArgumentException("Invalid mapping value", re);
        }
    }
}