org.jpmml.evaluator.ArgumentUtil.java Source code

Java tutorial

Introduction

Here is the source code for org.jpmml.evaluator.ArgumentUtil.java

Source

/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.Collections;
import java.util.List;

import com.google.common.base.Function;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableRangeSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Range;
import com.google.common.collect.RangeSet;
import com.google.common.collect.TreeRangeSet;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Interval;
import org.dmg.pmml.InvalidValueTreatmentMethodType;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutlierTreatmentMethodType;
import org.dmg.pmml.TypeDefinitionField;
import org.dmg.pmml.Value;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.UnsupportedFeatureException;

/**
 * Static helpers that prepare raw input values for a PMML {@link DataField} according to the treatments configured
 * on the corresponding {@link MiningField} (outlier, missing value and invalid value treatment), plus queries over
 * the field's declared {@link Value} elements and {@link Interval} elements.
 */
public class ArgumentUtil {

    private ArgumentUtil() {
    }

    /**
     * Converts a raw input value into a {@link FieldValue}, applying the treatments configured on the mining field.
     *
     * <p>The steps, in order:
     * <ol>
     * <li>Best-effort conversion of {@code value} to the data field's data type. If conversion fails, the raw value
     * is kept so that the missing and invalid value checks below can classify it.</li>
     * <li>Outlier treatment, when {@link #isOutlier} reports an outlier.</li>
     * <li>Missing value treatment. A missing value is replaced by the mining field's missing value replacement, or
     * {@code null} is returned when no replacement is configured.</li>
     * <li>Invalid value treatment, when {@link #isInvalid} reports an invalid value.</li>
     * </ol>
     *
     * @param dataField the field definition
     * @param miningField the model's usage of the field, which carries the treatment settings
     * @param value the raw input value, may be {@code null}
     * @return the prepared value, or {@code null} if the value is treated as missing and no replacement is configured
     * @throws InvalidFeatureException if extreme-value outlier treatment is requested without a consistent
     *         {@code lowValue}/{@code highValue} pair
     * @throws InvalidResultException if the value is invalid and the invalid value treatment is "return invalid"
     * @throws UnsupportedFeatureException if a treatment method is not supported
     */
    static public FieldValue prepare(DataField dataField, MiningField miningField, Object value) {

        if (value != null) {
            DataType dataType = dataField.getDataType();

            try {
                value = TypeUtil.parseOrCast(dataType, value);
            } catch (IllegalArgumentException iae) {
                // Ignored: the raw value is kept and classified by the missing/invalid value checks below
            }
        }

        if (isOutlier(dataField, miningField, value)) {
            value = treatOutlier(miningField, value);
        }

        if (isMissing(dataField, value)) {
            value = miningField.getMissingValueReplacement();

            if (value == null) {
                return null;
            }
        }

        if (isInvalid(dataField, miningField, value)) {
            InvalidValueTreatmentMethodType invalidValueTreatmentMethod = miningField.getInvalidValueTreatment();

            switch (invalidValueTreatmentMethod) {
            case RETURN_INVALID:
                throw new InvalidResultException(miningField);
            case AS_IS:
                break;
            case AS_MISSING:
                value = miningField.getMissingValueReplacement();

                if (value == null) {
                    return null;
                }
                break;
            default:
                throw new UnsupportedFeatureException(miningField, invalidValueTreatmentMethod);
            }
        }

        return FieldValueUtil.create(dataField, miningField, value);
    }

    /**
     * Applies the mining field's outlier treatment to a value that has already been classified as an outlier.
     *
     * @return the value unchanged (as-is), {@code null} (as missing values), or the value clamped into
     *         {@code [lowValue, highValue]} (as extreme values)
     * @throws InvalidFeatureException if extreme-value treatment lacks bounds or has {@code lowValue > highValue}
     * @throws UnsupportedFeatureException if the treatment method is not supported
     */
    static private Object treatOutlier(MiningField miningField, Object value) {
        OutlierTreatmentMethodType outlierTreatmentMethod = miningField.getOutlierTreatment();

        switch (outlierTreatmentMethod) {
        case AS_IS:
            return value;
        case AS_MISSING_VALUES:
            return null;
        case AS_EXTREME_VALUES: {
            Double lowValue = miningField.getLowValue();
            Double highValue = miningField.getHighValue();

            if (lowValue == null || highValue == null) {
                throw new InvalidFeatureException(miningField);
            }

            if (lowValue.compareTo(highValue) > 0) {
                throw new InvalidFeatureException(miningField);
            }

            Double doubleValue = (Double) TypeUtil.parseOrCast(DataType.DOUBLE, value);

            if (TypeUtil.compare(DataType.DOUBLE, doubleValue, lowValue) < 0) {
                return lowValue;
            } else if (TypeUtil.compare(DataType.DOUBLE, doubleValue, highValue) > 0) {
                return highValue;
            }

            return value;
        }
        default:
            throw new UnsupportedFeatureException(miningField, outlierTreatmentMethod);
        }
    }

    /**
     * Determines whether a value is an outlier.
     *
     * <p>Only continuous fields that declare intervals can have outliers. Such a value is an outlier when it falls
     * outside the span (lowest lower bound to highest upper bound) of the field's intervals. A value that cannot be
     * read as a number is not reported as an outlier; the missing and invalid value checks deal with it instead.
     *
     * @return {@code true} if the value is an outlier; {@code false} for {@code null} and for non-continuous fields
     * @throws UnsupportedFeatureException if the effective operational type is not supported
     */
    static public boolean isOutlier(DataField dataField, MiningField miningField, Object value) {

        if (value == null) {
            return false;
        }

        List<Interval> intervals = dataField.getIntervals();

        OpType opType = getOpType(dataField, miningField);

        switch (opType) {
        case CONTINUOUS: {
            if (intervals.isEmpty()) {
                return false;
            }

            RangeSet<Double> validRanges = CacheUtil.getValue(dataField, ArgumentUtil.validRangeCache);

            // RangeSet#span() throws on an empty set (e.g. when every interval is an empty range)
            if (validRanges.isEmpty()) {
                return false;
            }

            Double doubleValue = toDouble(value);
            if (doubleValue == null) {
                return false;
            }

            return !validRanges.span().contains(doubleValue);
        }
        case CATEGORICAL:
        case ORDINAL:
            return false;
        default:
            throw new UnsupportedFeatureException(miningField, opType);
        }
    }

    /**
     * Determines whether a value is missing: either {@code null}, or equal to one of the field's {@code missing}
     * Value elements.
     */
    static public boolean isMissing(DataField dataField, Object value) {

        if (value == null) {
            return true;
        }

        DataType dataType = dataField.getDataType();

        List<Value> fieldValues = dataField.getValues();
        for (Value fieldValue : fieldValues) {
            Value.Property property = fieldValue.getProperty();

            switch (property) {
            case MISSING: {
                boolean equals = equals(dataType, value, fieldValue.getValue());
                if (equals) {
                    return true;
                }
            }
                break;
            default:
                break;
            }
        }

        return false;
    }

    /**
     * Determines whether a non-{@code null} value is invalid, i.e. not {@link #isValid valid}.
     *
     * @return {@code false} for {@code null}, otherwise the negation of {@link #isValid}
     */
    static public boolean isInvalid(DataField dataField, MiningField miningField, Object value) {

        if (value == null) {
            return false;
        }

        return !isValid(dataField, miningField, value);
    }

    /**
     * Determines whether a value is valid for the field.
     *
     * <p>For a continuous field with intervals, the value is valid when it lies inside one of the intervals; a value
     * that cannot be read as a number is invalid. Otherwise the value is checked against the field's Value elements
     * (see {@link #isDeclaredValid}).
     *
     * @return {@code false} for {@code null}
     * @throws InvalidFeatureException if a categorical or ordinal field declares intervals
     * @throws UnsupportedFeatureException if the effective operational type or a Value property is not supported
     */
    static public boolean isValid(DataField dataField, MiningField miningField, Object value) {

        if (value == null) {
            return false;
        }

        List<Interval> intervals = dataField.getIntervals();

        OpType opType = getOpType(dataField, miningField);

        switch (opType) {
        case CONTINUOUS: {
            // "If intervals are present, then a value that is outside the intervals is considered invalid"
            if (!intervals.isEmpty()) {
                RangeSet<Double> validRanges = CacheUtil.getValue(dataField, ArgumentUtil.validRangeCache);

                Double doubleValue = toDouble(value);
                if (doubleValue == null) {
                    // A value that cannot be read as a number lies outside every interval
                    return false;
                }

                return validRanges.contains(doubleValue);
            }

            // Without intervals, a continuous field is validated against its Value elements
            return isDeclaredValid(dataField, value);
        }
        case CATEGORICAL:
        case ORDINAL: {
            // "Intervals are not allowed for non-continuous fields"
            if (!intervals.isEmpty()) {
                throw new InvalidFeatureException(dataField);
            }

            return isDeclaredValid(dataField, value);
        }
        default:
            throw new UnsupportedFeatureException(miningField, opType);
        }
    }

    /**
     * Checks a value against the Value elements of the data field.
     *
     * <p>A match with a {@code valid} Value returns {@code true}; a match with an {@code invalid} or {@code missing}
     * Value returns {@code false}. When nothing matches, the value is valid only if the field declares no
     * {@code valid} Value at all.
     *
     * @throws UnsupportedFeatureException if a Value element has an unsupported property
     */
    static private boolean isDeclaredValid(DataField dataField, Object value) {
        DataType dataType = dataField.getDataType();

        int validValueCount = 0;

        List<Value> fieldValues = dataField.getValues();
        for (Value fieldValue : fieldValues) {
            Value.Property property = fieldValue.getProperty();

            switch (property) {
            case VALID: {
                validValueCount += 1;

                boolean equals = equals(dataType, value, fieldValue.getValue());
                if (equals) {
                    return true;
                }
            }
                break;
            case INVALID:
            case MISSING: {
                boolean equals = equals(dataType, value, fieldValue.getValue());
                if (equals) {
                    return false;
                }
            }
                break;
            default:
                throw new UnsupportedFeatureException(fieldValue, property);
            }
        }

        // "If a field contains at least one Value element where the value of property is valid, then the set of Value elements completely defines the set of valid values"
        if (validValueCount > 0) {
            return false;
        }

        // "Any value is valid by default"
        return true;
    }

    /**
     * Resolves the effective operational type: the mining field's optype when set, otherwise the data field's.
     */
    static private OpType getOpType(DataField dataField, MiningField miningField) {
        OpType opType = miningField.getOptype();
        if (opType == null) {
            opType = dataField.getOptype();
        }

        return opType;
    }

    /**
     * Converts a value to {@link Double}.
     *
     * @return the converted value, or {@code null} if {@link TypeUtil#parseOrCast} rejects it with an
     *         {@link IllegalArgumentException}
     */
    static private Double toDouble(Object value) {

        try {
            return (Double) TypeUtil.parseOrCast(DataType.DOUBLE, value);
        } catch (IllegalArgumentException iae) {
            return null;
        }
    }

    /**
     * Finds the first {@code valid} Value element of the field that equals the given value.
     *
     * @return the matching Value element, or {@code null} if there is none
     */
    static public Value getValidValue(TypeDefinitionField field, Object value) {
        DataType dataType = field.getDataType();

        List<Value> fieldValues = field.getValues();
        for (Value fieldValue : fieldValues) {
            Value.Property property = fieldValue.getProperty();

            switch (property) {
            case VALID: {
                boolean equals = equals(dataType, value, fieldValue.getValue());
                if (equals) {
                    return fieldValue;
                }
            }
                break;
            default:
                break;
            }
        }

        return null;
    }

    /**
     * Collects the {@code valid} Value elements of the field, in declaration order.
     *
     * @return an empty immutable list if the field has no Value elements, otherwise a new list of the valid ones
     */
    static public List<Value> getValidValues(TypeDefinitionField field) {
        List<Value> fieldValues = field.getValues();
        if (fieldValues.isEmpty()) {
            return Collections.emptyList();
        }

        List<Value> result = Lists.newArrayList();

        for (Value fieldValue : fieldValues) {
            Value.Property property = fieldValue.getProperty();

            switch (property) {
            case VALID:
                result.add(fieldValue);
                break;
            default:
                break;
            }
        }

        return result;
    }

    /**
     * Compares a value with the string form of a Value element.
     *
     * <p>The reference value is first converted to {@code dataType}. If that fails, the comparison falls back to a
     * string comparison. If the string comparison also fails with a {@link TypeCheckException}, the original
     * conversion failure is rethrown.
     *
     * @throws IllegalArgumentException if the reference value cannot be converted and the string comparison fails
     */
    static private boolean equals(DataType dataType, Object value, String referenceValue) {

        try {
            return TypeUtil.equals(dataType, value, TypeUtil.parseOrCast(dataType, referenceValue));
        } catch (IllegalArgumentException iae) {

            // The String representation of invalid or missing values (eg. "N/A") may not be parseable to the requested representation
            try {
                return TypeUtil.equals(DataType.STRING, value, referenceValue);
            } catch (TypeCheckException tce) {
                // Ignored
            }

            throw iae;
        }
    }

    /**
     * Returns the string values of the field's {@code valid} Value elements, cached per field.
     *
     * @throws InvalidFeatureException (possibly wrapped by the cache) if a valid Value element has no value
     */
    static public List<String> getTargetCategories(TypeDefinitionField field) {
        return CacheUtil.getValue(field, ArgumentUtil.targetCategoryCache);
    }

    /**
     * Builds the union of the data field's intervals as a set of {@code Double} ranges.
     */
    static private RangeSet<Double> parseValidRanges(DataField dataField) {
        RangeSet<Double> result = TreeRangeSet.create();

        List<Interval> intervals = dataField.getIntervals();
        for (Interval interval : intervals) {
            Range<Double> range = DiscretizationUtil.toRange(interval);

            result.add(range);
        }

        return result;
    }

    /**
     * Caches the immutable list of valid category strings per field. Keys are held weakly, so an entry does not keep
     * its field alive.
     */
    private static final LoadingCache<TypeDefinitionField, List<String>> targetCategoryCache = CacheBuilder
            .newBuilder().weakKeys().build(new CacheLoader<TypeDefinitionField, List<String>>() {

                @Override
                public List<String> load(TypeDefinitionField field) {
                    List<Value> values = getValidValues(field);

                    Function<Value, String> function = new Function<Value, String>() {

                        @Override
                        public String apply(Value value) {
                            String result = value.getValue();
                            if (result == null) {
                                throw new InvalidFeatureException(value);
                            }

                            return result;
                        }
                    };

                    return ImmutableList.copyOf(Iterables.transform(values, function));
                }
            });

    /**
     * Caches the immutable set of valid ranges per data field. Keys are held weakly, so an entry does not keep its
     * field alive.
     */
    private static final LoadingCache<DataField, RangeSet<Double>> validRangeCache = CacheBuilder.newBuilder()
            .weakKeys().build(new CacheLoader<DataField, RangeSet<Double>>() {

                @Override
                public RangeSet<Double> load(DataField dataField) {
                    return ImmutableRangeSet.copyOf(parseValidRanges(dataField));
                }
            });
}