io.prestosql.sql.gen.InCodeGenerator.java Source code

Java tutorial

Introduction

Here is the source code for io.prestosql.sql.gen.InCodeGenerator.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.prestosql.sql.gen;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableSet;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.control.SwitchStatement.SwitchBuilder;
import io.airlift.bytecode.instruction.LabelNode;
import io.prestosql.metadata.FunctionRegistry;
import io.prestosql.metadata.Signature;
import io.prestosql.operator.scalar.ScalarFunctionImplementation;
import io.prestosql.spi.function.OperatorType;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.DateType;
import io.prestosql.spi.type.IntegerType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.relational.ConstantExpression;
import io.prestosql.sql.relational.RowExpression;
import io.prestosql.util.FastutilSetHelper;

import java.lang.invoke.MethodHandle;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Throwables.throwIfUnchecked;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue;
import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic;
import static io.airlift.bytecode.instruction.JumpInstruction.jump;
import static io.prestosql.spi.function.OperatorType.HASH_CODE;
import static io.prestosql.spi.function.OperatorType.INDETERMINATE;
import static io.prestosql.sql.gen.BytecodeUtils.ifWasNullPopAndGoto;
import static io.prestosql.sql.gen.BytecodeUtils.invoke;
import static io.prestosql.sql.gen.BytecodeUtils.loadConstant;
import static io.prestosql.util.FastutilSetHelper.toFastutilHashSet;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

/**
 * Bytecode generator for the {@code IN} predicate.
 *
 * <p>The first argument is the value being probed; the remaining arguments are the candidate values.
 * Candidates that are determinate constants are dispatched through one of three strategies (see
 * {@link SwitchGenerationCase}); every other candidate (non-constant, null, or indeterminate constant)
 * is compared one by one with the EQUAL operator in a trailing "default" block.
 */
public class InCodeGenerator implements BytecodeGenerator {
    private final FunctionRegistry registry;

    /**
     * @param registry registry handed to {@code toFastutilHashSet} when the SET_CONTAINS strategy is
     *        selected; must not be null
     */
    public InCodeGenerator(FunctionRegistry registry) {
        this.registry = requireNonNull(registry, "registry is null");
    }

    /** Strategies for dispatching on the determinate constant candidates. */
    enum SwitchGenerationCase {
        DIRECT_SWITCH, HASH_SWITCH, SET_CONTAINS
    }

    /**
     * Chooses the dispatch strategy for the given probe type and candidate list.
     *
     * @param type type of the probed value
     * @param values candidate expressions
     * @return SET_CONTAINS for more than 32 candidates; otherwise DIRECT_SWITCH when the type is
     *         integer, bigint or date and every non-null constant fits in an {@code int};
     *         otherwise HASH_SWITCH
     */
    @VisibleForTesting
    static SwitchGenerationCase checkSwitchGenerationCase(Type type, List<RowExpression> values) {
        // 32 is chosen because
        // * SET_CONTAINS performs worst when smaller than but close to power of 2
        // * Benchmark shows performance of SET_CONTAINS is better at 50, but similar at 25.
        if (values.size() > 32) {
            return SwitchGenerationCase.SET_CONTAINS;
        }

        boolean directSwitchCandidate = type instanceof IntegerType
                || type instanceof BigintType
                || type instanceof DateType;
        if (!directSwitchCandidate) {
            return SwitchGenerationCase.HASH_SWITCH;
        }

        // Non-constant expressions and nulls only ever land in the default case of the generated switch,
        // so they have no say in the choice between DIRECT_SWITCH and HASH_SWITCH.
        for (RowExpression candidate : values) {
            if (candidate instanceof ConstantExpression) {
                Object literal = ((ConstantExpression) candidate).getValue();
                if (literal != null && !isInteger(((Number) literal).longValue())) {
                    return SwitchGenerationCase.HASH_SWITCH;
                }
            }
        }
        return SwitchGenerationCase.DIRECT_SWITCH;
    }

    /**
     * Generates the bytecode of {@code arguments[0] IN (arguments[1..])}.
     *
     * @param signature not read by this generator
     * @param generatorContext supplies the scope, registry, call-site binder and sub-expression generation
     * @param returnType not read by this generator
     * @param arguments probe expression followed by at least one candidate expression
     * @return the generated block
     * @throws IllegalArgumentException if there is no candidate, or if the hash code of a constant
     *         candidate cannot be computed
     */
    @Override
    public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext generatorContext,
            Type returnType, List<RowExpression> arguments) {
        List<RowExpression> values = arguments.subList(1, arguments.size());
        // empty IN statements are not allowed by the standard, and not possible here
        // the implementation assumes this condition is always met
        checkArgument(values.size() > 0, "values must not be empty");

        RowExpression probe = arguments.get(0);
        Type type = probe.getType();
        Class<?> javaType = type.getJavaType();

        SwitchGenerationCase strategy = checkSwitchGenerationCase(type, values);

        Signature hashCodeSignature = generatorContext.getRegistry()
                .resolveOperator(HASH_CODE, ImmutableList.of(type));
        MethodHandle hashCodeFunction = generatorContext.getRegistry()
                .getScalarFunctionImplementation(hashCodeSignature)
                .getMethodHandle();
        Signature isIndeterminateSignature = generatorContext.getRegistry()
                .resolveOperator(INDETERMINATE, ImmutableList.of(type));
        ScalarFunctionImplementation isIndeterminateFunction = generatorContext.getRegistry()
                .getScalarFunctionImplementation(isIndeterminateSignature);

        ImmutableListMultimap.Builder<Integer, BytecodeNode> hashBucketsBuilder = ImmutableListMultimap.builder();
        ImmutableList.Builder<BytecodeNode> defaultBucket = ImmutableList.builder();
        ImmutableSet.Builder<Object> constantValuesBuilder = ImmutableSet.builder();

        // Sort every candidate into the structure its strategy needs; anything that is not a
        // determinate constant is deferred to the default bucket.
        for (RowExpression candidate : values) {
            // Bytecode is generated for every candidate, even when only its constant value is kept.
            BytecodeNode candidateBytecode = generatorContext.generate(candidate);

            if (!isDeterminateConstant(candidate, isIndeterminateFunction.getMethodHandle())) {
                defaultBucket.add(candidateBytecode);
                continue;
            }

            Object object = ((ConstantExpression) candidate).getValue();
            switch (strategy) {
            case HASH_SWITCH:
                try {
                    int hashCode = Long.hashCode((Long) hashCodeFunction.invoke(object));
                    hashBucketsBuilder.put(hashCode, candidateBytecode);
                } catch (Throwable throwable) {
                    throw new IllegalArgumentException(
                            "Error processing IN statement: error calculating hash code for " + object,
                            throwable);
                }
                break;
            case DIRECT_SWITCH:
            case SET_CONTAINS:
                constantValuesBuilder.add(object);
                break;
            default:
                throw new IllegalArgumentException(
                        "Not supported switch generation case: " + strategy);
            }
        }
        ImmutableListMultimap<Integer, BytecodeNode> hashBuckets = hashBucketsBuilder.build();
        ImmutableSet<Object> constantValues = constantValuesBuilder.build();

        LabelNode end = new LabelNode("end");
        LabelNode match = new LabelNode("match");
        LabelNode noMatch = new LabelNode("noMatch");

        LabelNode defaultLabel = new LabelNode("default");

        Scope scope = generatorContext.getScope();
        Variable value = scope.createTempVariable(javaType);

        Variable expression = scope.createTempVariable(int.class);
        SwitchBuilder switchBuilder = new SwitchBuilder().expression(expression);

        BytecodeNode switchBlock;
        switch (strategy) {
        case DIRECT_SWITCH: {
            // A white-list is used to select types eligible for DIRECT_SWITCH.
            // For these types, it's safe to not use presto HASH_CODE and EQUAL operator.
            for (Object constantValue : constantValues) {
                switchBuilder.addCase(toIntExact((Long) constantValue), jump(match));
            }
            switchBuilder.defaultCase(jump(defaultLabel));

            // Probe values outside the int range cannot hit any case; send them to the default block.
            IfStatement fitsInInt = new IfStatement();
            fitsInInt.condition(invokeStatic(InCodeGenerator.class, "isInteger", boolean.class, value));
            fitsInInt.ifFalse(new BytecodeBlock().gotoLabel(defaultLabel));

            BytecodeBlock directSwitch = new BytecodeBlock();
            directSwitch.comment("lookupSwitch(<stackValue>))");
            directSwitch.append(fitsInInt);
            directSwitch.append(expression.set(value.cast(int.class)));
            directSwitch.append(switchBuilder.build());
            switchBlock = directSwitch;
            break;
        }
        case HASH_SWITCH: {
            // One switch case per hash code; candidates sharing a hash are compared in sequence.
            for (Map.Entry<Integer, Collection<BytecodeNode>> bucket : hashBuckets.asMap().entrySet()) {
                BytecodeBlock bucketBlock = buildInCase(generatorContext, scope, type, match, defaultLabel, value,
                        bucket.getValue(), false, isIndeterminateSignature, isIndeterminateFunction);
                switchBuilder.addCase(bucket.getKey(), bucketBlock);
            }
            switchBuilder.defaultCase(jump(defaultLabel));
            Binding hashCodeBinding = generatorContext.getCallSiteBinder().bind(hashCodeFunction);

            BytecodeBlock hashSwitch = new BytecodeBlock();
            hashSwitch.comment("lookupSwitch(hashCode(<stackValue>))");
            hashSwitch.getVariable(value);
            hashSwitch.append(invoke(hashCodeBinding, hashCodeSignature));
            hashSwitch.invokeStatic(Long.class, "hashCode", int.class, long.class);
            hashSwitch.putVariable(expression);
            hashSwitch.append(switchBuilder.build());
            switchBlock = hashSwitch;
            break;
        }
        case SET_CONTAINS: {
            Set<?> constantValuesSet = toFastutilHashSet(constantValues, type, registry);
            Binding constant = generatorContext.getCallSiteBinder().bind(constantValuesSet,
                    constantValuesSet.getClass());

            BytecodeBlock membershipTest = new BytecodeBlock();
            membershipTest.comment("value");
            membershipTest.getVariable(value);
            membershipTest.comment("set");
            membershipTest.append(loadConstant(constant));
            // TODO: use invokeVirtual on the set instead. This requires swapping the two elements in the stack
            membershipTest.invokeStatic(FastutilSetHelper.class, "in", boolean.class,
                    javaType.isPrimitive() ? javaType : Object.class, constantValuesSet.getClass());

            IfStatement ifContained = new IfStatement();
            ifContained.condition(membershipTest);
            ifContained.ifTrue(jump(match));

            BytecodeBlock setContains = new BytecodeBlock();
            setContains.comment("inListSet.contains(<stackValue>)");
            setContains.append(ifContained);
            switchBlock = setContains;
            break;
        }
        default:
            throw new IllegalArgumentException("Not supported switch generation case: " + strategy);
        }

        BytecodeBlock defaultCaseBlock = buildInCase(generatorContext, scope, type, match, noMatch, value,
                defaultBucket.build(), true, isIndeterminateSignature, isIndeterminateFunction);
        defaultCaseBlock.setDescription("default");

        // Layout: evaluate probe -> null short-circuit -> store -> dispatch -> default -> match -> noMatch -> end
        BytecodeBlock block = new BytecodeBlock();
        block.comment("IN");
        block.append(generatorContext.generate(probe));
        block.append(ifWasNullPopAndGoto(scope, end, boolean.class, javaType));
        block.putVariable(value);
        block.append(switchBlock);
        block.visitLabel(defaultLabel);
        block.append(defaultCaseBlock);

        BytecodeBlock matchBlock = new BytecodeBlock();
        matchBlock.setDescription("match");
        matchBlock.visitLabel(match);
        matchBlock.append(generatorContext.wasNull().set(constantFalse()));
        matchBlock.push(true);
        matchBlock.gotoLabel(end);
        block.append(matchBlock);

        BytecodeBlock noMatchBlock = new BytecodeBlock();
        noMatchBlock.setDescription("noMatch");
        noMatchBlock.visitLabel(noMatch);
        noMatchBlock.push(false);
        noMatchBlock.gotoLabel(end);
        block.append(noMatchBlock);

        block.visitLabel(end);

        return block;
    }

    /**
     * Tells whether {@code value} survives a round trip through {@code int}.
     * Referenced by name from the bytecode generated for the DIRECT_SWITCH strategy.
     *
     * @param value value to test
     * @return {@code true} when {@code value} lies within the {@code int} range
     */
    public static boolean isInteger(long value) {
        return value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE;
    }

    /**
     * Builds a chain of EQUAL comparisons of {@code value} against each entry of {@code testValues}.
     * A successful comparison jumps to {@code matchLabel}; exhausting the chain jumps to
     * {@code noMatchLabel}.
     *
     * @param checkForNulls when true, a null comparison result is remembered and copied into the
     *        context's wasNull variable before falling through; with no test values, wasNull is
     *        instead set from the INDETERMINATE operator applied to {@code value}
     * @return the block containing the whole comparison chain
     */
    private static BytecodeBlock buildInCase(BytecodeGeneratorContext generatorContext, Scope scope, Type type,
            LabelNode matchLabel, LabelNode noMatchLabel, Variable value, Collection<BytecodeNode> testValues,
            boolean checkForNulls, Signature isIndeterminateSignature,
            ScalarFunctionImplementation isIndeterminateFunction) {
        BytecodeBlock caseBlock = new BytecodeBlock();

        // caseWasNull flips to true the first time comparing against an entry of `testValues` yields null
        Variable caseWasNull = null;
        if (checkForNulls) {
            caseWasNull = scope.createTempVariable(boolean.class);
            caseBlock.putVariable(caseWasNull, false);
        }

        LabelNode nextLabel = new LabelNode("else");
        BytecodeBlock fallThrough = new BytecodeBlock().visitLabel(nextLabel);

        Variable wasNull = generatorContext.wasNull();
        if (checkForNulls) {
            if (testValues.isEmpty()) {
                // Consider the following expression: "ARRAY[null] IN (ARRAY[1], ARRAY[2], ARRAY[3]) => NULL"
                // All lookup values will go to the SET_CONTAINS, since none of them is indeterminate.
                // As ARRAY[null] is not among them, the code will fall through to the defaultCaseBlock.
                // With no values in the defaultCaseBlock it would return FALSE, which is incorrect.
                // An explicit check for indeterminate is required to correctly return NULL.
                BytecodeBlock indeterminateCheck = new BytecodeBlock();
                indeterminateCheck.append(generatorContext.generateCall(isIndeterminateSignature.getName(),
                        isIndeterminateFunction, ImmutableList.of(value)));
                indeterminateCheck.putVariable(wasNull);
                fallThrough.append(indeterminateCheck);
            } else {
                fallThrough.append(wasNull.set(caseWasNull));
            }
        }

        fallThrough.gotoLabel(noMatchLabel);

        Signature equalsSignature = generatorContext.getRegistry().resolveOperator(OperatorType.EQUAL,
                ImmutableList.of(type, type));
        ScalarFunctionImplementation equalsFunction = generatorContext.getRegistry()
                .getScalarFunctionImplementation(equalsSignature);

        // Each iteration wraps the chain built so far: the new test becomes the outermost node and the
        // previous chain becomes its false branch.
        BytecodeNode chain = fallThrough;
        for (BytecodeNode testNode : testValues) {
            LabelNode testLabel = new LabelNode("test");
            IfStatement test = new IfStatement();

            BytecodeNode equalsCall = generatorContext.generateCall(equalsSignature.getName(), equalsFunction,
                    ImmutableList.of(value, testNode));

            test.condition().visitLabel(testLabel).append(equalsCall);

            if (checkForNulls) {
                BytecodeBlock onNull = new BytecodeBlock();
                onNull.append(caseWasNull.set(constantTrue()));
                onNull.append(wasNull.set(constantFalse()));
                onNull.pop(boolean.class);
                onNull.gotoLabel(nextLabel);

                IfStatement wasNullCheck = new IfStatement(
                        "if wasNull, set caseWasNull to true, clear wasNull, pop boolean, and goto next test value");
                wasNullCheck.condition(wasNull);
                wasNullCheck.ifTrue(onNull);
                test.condition().append(wasNullCheck);
            }

            test.ifTrue().gotoLabel(matchLabel);
            test.ifFalse(chain);

            chain = test;
            nextLabel = testLabel;
        }
        caseBlock.append(chain);
        return caseBlock;
    }

    /**
     * Tells whether {@code expression} is a non-null constant that the INDETERMINATE operator
     * reports as determinate.
     *
     * @param expression candidate expression
     * @param isIndeterminateFunction handle invoked as {@code (value, false)} and expected to yield a boolean
     * @return {@code true} only for non-null constants for which the handle returns {@code false}
     */
    private static boolean isDeterminateConstant(RowExpression expression, MethodHandle isIndeterminateFunction) {
        if (!(expression instanceof ConstantExpression)) {
            return false;
        }
        Object value = ((ConstantExpression) expression).getValue();
        if (value == null) {
            return false;
        }
        try {
            boolean indeterminate = (boolean) isIndeterminateFunction.invoke(value, false);
            return !indeterminate;
        } catch (Throwable failure) {
            throwIfUnchecked(failure);
            throw new RuntimeException(failure);
        }
    }
}