Source listing: org.apache.hadoop.hive.ql.optimizer.calcite.translator.ExprNodeConverter (Apache Hive).

This class translates Calcite RexNode expressions back into Hive ExprNodeDesc trees after plan optimization.

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.translator;

import java.math.BigDecimal;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexFieldCollation;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexOver;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.rex.RexWindow;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.common.type.HiveIntervalDayTime;
import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.optimizer.ConstantPropagateProcFactory;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.ASTConverter.RexVisitor;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.ASTConverter.Schema;
import org.apache.hadoop.hive.ql.parse.ASTNode;
import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.NullOrder;
import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order;
import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.OrderExpression;
import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.OrderSpec;
import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.PartitionExpression;
import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.PartitionSpec;
import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.PartitioningSpec;
import org.apache.hadoop.hive.ql.parse.WindowingSpec.BoundarySpec;
import org.apache.hadoop.hive.ql.parse.WindowingSpec.Direction;
import org.apache.hadoop.hive.ql.parse.WindowingSpec.WindowFrameSpec;
import org.apache.hadoop.hive.ql.parse.WindowingSpec.WindowFunctionSpec;
import org.apache.hadoop.hive.ql.parse.WindowingSpec.WindowSpec;
import org.apache.hadoop.hive.ql.parse.WindowingSpec.WindowType;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeFieldDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.ImmutableSet;

/*
 * convert a RexNode to an ExprNodeDesc
 */
/**
 * Converts a Calcite {@link RexNode} expression tree into an equivalent Hive
 * {@link ExprNodeDesc} tree, used when translating an optimized Calcite plan
 * back into Hive's operator representation.
 *
 * <p>Window (OVER) expressions are not converted inline: {@link #visitOver}
 * records a {@link WindowFunctionSpec} (retrievable afterwards via
 * {@link #getWindowFunctionSpec()}) and returns a column reference to a
 * synthetic alias that the caller is expected to wire up.
 *
 * <p>Instances are not thread-safe (the collected window specs are mutable
 * per-instance state), but the synthetic window-column counter is shared
 * across instances and incremented atomically so generated aliases are
 * unique process-wide.
 */
public class ExprNodeConverter extends RexVisitorImpl<ExprNodeDesc> {

    /** Whether constant sub-expressions should be folded during conversion. */
    private final boolean foldExpr;
    /** Table alias attached to every column reference emitted by this converter. */
    private final String tabAlias;
    /** Row type describing the input columns that {@link RexInputRef}s index into. */
    private final RelDataType inputRowType;
    /** Indexes (into inputRowType) of columns that are Hive virtual columns. */
    private final ImmutableSet<Integer> inputVCols;
    /** Window function specs collected by visitOver(), in visit order. */
    private final List<WindowFunctionSpec> windowFunctionSpecs = new ArrayList<>();
    private final RelDataTypeFactory dTFactory;
    protected final Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
    // Shared across all converter instances; AtomicLong makes the increment
    // thread-safe so concurrently compiled queries cannot produce colliding
    // "$win$_col_<n>" aliases (the previous plain long++ was a data race).
    private static final AtomicLong UNIQUE_COUNTER = new AtomicLong(0);

    public ExprNodeConverter(String tabAlias, RelDataType inputRowType, Set<Integer> vCols,
            RelDataTypeFactory dTFactory) {
        this(tabAlias, null, inputRowType, null, vCols, dTFactory, false);
    }

    public ExprNodeConverter(String tabAlias, RelDataType inputRowType, Set<Integer> vCols,
            RelDataTypeFactory dTFactory, boolean foldExpr) {
        this(tabAlias, null, inputRowType, null, vCols, dTFactory, foldExpr);
    }

    public ExprNodeConverter(String tabAlias, String columnAlias, RelDataType inputRowType,
            RelDataType outputRowType, Set<Integer> inputVCols, RelDataTypeFactory dTFactory) {
        this(tabAlias, columnAlias, inputRowType, outputRowType, inputVCols, dTFactory, false);
    }

    /**
     * Primary constructor.
     *
     * NOTE(review): {@code columnAlias} and {@code outputRowType} are accepted
     * but never stored or read anywhere in this class; they appear to be
     * retained only for caller/API compatibility — confirm before removing.
     */
    public ExprNodeConverter(String tabAlias, String columnAlias, RelDataType inputRowType,
            RelDataType outputRowType, Set<Integer> inputVCols, RelDataTypeFactory dTFactory, boolean foldExpr) {
        super(true);
        this.tabAlias = tabAlias;
        this.inputRowType = inputRowType;
        this.inputVCols = ImmutableSet.copyOf(inputVCols);
        this.dTFactory = dTFactory;
        this.foldExpr = foldExpr;
    }

    /**
     * Returns the window function specs accumulated by visitOver() calls on
     * this instance. The returned list is live, not a copy.
     */
    public List<WindowFunctionSpec> getWindowFunctionSpec() {
        return this.windowFunctionSpecs;
    }

    /**
     * Converts an input column reference into an ExprNodeColumnDesc, marking
     * it as a virtual column when its index is in {@link #inputVCols}.
     */
    @Override
    public ExprNodeDesc visitInputRef(RexInputRef inputRef) {
        RelDataTypeField f = inputRowType.getFieldList().get(inputRef.getIndex());
        return new ExprNodeColumnDesc(TypeConverter.convert(f.getType()), f.getName(), tabAlias,
                inputVCols.contains(inputRef.getIndex()));
    }

    /**
     * TODO: Handle 1) cast 2), Windowing Agg Call
     */

    @Override
    /*
     * Handles expr like struct(key,value).key
     * Follows same rules as TypeCheckProcFactory::getXpathOrFuncExprNodeDesc()
     * which is equivalent version of parsing such an expression from AST
     */
    public ExprNodeDesc visitFieldAccess(RexFieldAccess fieldAccess) {
        ExprNodeDesc parent = fieldAccess.getReferenceExpr().accept(this);
        String child = fieldAccess.getField().getName();
        TypeInfo parentType = parent.getTypeInfo();
        // Allow accessing a field of list element structs directly from a list
        boolean isList = (parentType.getCategory() == ObjectInspector.Category.LIST);
        if (isList) {
            parentType = ((ListTypeInfo) parentType).getListElementTypeInfo();
        }
        TypeInfo t = ((StructTypeInfo) parentType).getStructFieldTypeInfo(child);
        return new ExprNodeFieldDesc(t, parent, child, isList);
    }

    /**
     * Converts a function call into an ExprNodeGenericFuncDesc, resolving the
     * Hive UDF via {@link SqlFunctionConverter}. Redundant casts collapse to
     * their operand; constant expressions are folded when {@link #foldExpr}
     * is set.
     *
     * @throws RuntimeException if no Hive UDF matches the Calcite operator,
     *         or if UDF instantiation fails (original UDFArgumentException is
     *         preserved as the cause)
     */
    @Override
    public ExprNodeDesc visitCall(RexCall call) {
        ExprNodeDesc gfDesc = null;

        if (!deep) {
            return null;
        }

        // ArrayList: args is appended to then read by index (args.get(0));
        // LinkedList offered no benefit here.
        List<ExprNodeDesc> args = new ArrayList<>();
        if (call.getKind() == SqlKind.EXTRACT) {
            // Extract on date: special handling since function in Hive does
            // include <time_unit>. Observe that <time_unit> information
            // is implicit in the function name, thus translation will
            // proceed correctly if we just ignore the <time_unit>
            args.add(call.operands.get(1).accept(this));
        } else if (call.getKind() == SqlKind.FLOOR && call.operands.size() == 2) {
            // Floor on date: special handling since function in Hive does
            // include <time_unit>. Observe that <time_unit> information
            // is implicit in the function name, thus translation will
            // proceed correctly if we just ignore the <time_unit>
            args.add(call.operands.get(0).accept(this));
        } else {
            for (RexNode operand : call.operands) {
                args.add(operand.accept(this));
            }
        }

        // If Call is a redundant cast then bail out. Ex: cast(true)BOOLEAN
        if (call.isA(SqlKind.CAST) && (call.operands.size() == 1)
                && SqlTypeUtil.equalSansNullability(dTFactory, call.getType(), call.operands.get(0).getType())) {
            return args.get(0);
        } else {
            GenericUDF hiveUdf = SqlFunctionConverter.getHiveUDF(call.getOperator(), call.getType(), args.size());
            if (hiveUdf == null) {
                throw new RuntimeException("Cannot find UDF for " + call.getType() + " " + call.getOperator() + "["
                        + call.getOperator().getKind() + "]/" + args.size());
            }
            try {
                gfDesc = ExprNodeGenericFuncDesc.newInstance(hiveUdf, args);
            } catch (UDFArgumentException e) {
                LOG.error("Failed to instantiate udf: ", e);
                throw new RuntimeException(e);
            }
        }

        // Try to fold if it is a constant expression
        if (foldExpr && RexUtil.isConstant(call)) {
            // gfDesc is always an ExprNodeGenericFuncDesc here: the redundant
            // CAST case returned early above.
            ExprNodeDesc constantExpr = ConstantPropagateProcFactory.foldExpr((ExprNodeGenericFuncDesc) gfDesc);
            if (constantExpr != null) {
                gfDesc = constantExpr;
            }
        }

        return gfDesc;
    }

    /**
     * Converts a Calcite literal into an ExprNodeConstantDesc of the matching
     * Hive type. NULL literals keep their declared SQL type (falling back to
     * void for unrecognized types); non-null values are converted through the
     * literal's Java representation.
     */
    @Override
    public ExprNodeDesc visitLiteral(RexLiteral literal) {
        RelDataType lType = literal.getType();

        if (RexLiteral.value(literal) == null) {
            // Typed NULL: preserve the declared SQL type so downstream type
            // checking still works.
            switch (literal.getType().getSqlTypeName()) {
            case BOOLEAN:
                return new ExprNodeConstantDesc(TypeInfoFactory.booleanTypeInfo, null);
            case TINYINT:
                return new ExprNodeConstantDesc(TypeInfoFactory.byteTypeInfo, null);
            case SMALLINT:
                return new ExprNodeConstantDesc(TypeInfoFactory.shortTypeInfo, null);
            case INTEGER:
                return new ExprNodeConstantDesc(TypeInfoFactory.intTypeInfo, null);
            case BIGINT:
                return new ExprNodeConstantDesc(TypeInfoFactory.longTypeInfo, null);
            case FLOAT:
            case REAL:
                return new ExprNodeConstantDesc(TypeInfoFactory.floatTypeInfo, null);
            case DOUBLE:
                return new ExprNodeConstantDesc(TypeInfoFactory.doubleTypeInfo, null);
            case DATE:
                return new ExprNodeConstantDesc(TypeInfoFactory.dateTypeInfo, null);
            case TIME:
            case TIMESTAMP:
                return new ExprNodeConstantDesc(TypeInfoFactory.timestampTypeInfo, null);
            case BINARY:
                return new ExprNodeConstantDesc(TypeInfoFactory.binaryTypeInfo, null);
            case DECIMAL:
                return new ExprNodeConstantDesc(
                        TypeInfoFactory.getDecimalTypeInfo(lType.getPrecision(), lType.getScale()), null);
            case VARCHAR:
            case CHAR:
                return new ExprNodeConstantDesc(TypeInfoFactory.stringTypeInfo, null);
            case INTERVAL_YEAR:
            case INTERVAL_MONTH:
            case INTERVAL_YEAR_MONTH:
                return new ExprNodeConstantDesc(TypeInfoFactory.intervalYearMonthTypeInfo, null);
            case INTERVAL_DAY:
            case INTERVAL_DAY_HOUR:
            case INTERVAL_DAY_MINUTE:
            case INTERVAL_DAY_SECOND:
            case INTERVAL_HOUR:
            case INTERVAL_HOUR_MINUTE:
            case INTERVAL_HOUR_SECOND:
            case INTERVAL_MINUTE:
            case INTERVAL_MINUTE_SECOND:
            case INTERVAL_SECOND:
                return new ExprNodeConstantDesc(TypeInfoFactory.intervalDayTimeTypeInfo, null);
            case OTHER:
            default:
                return new ExprNodeConstantDesc(TypeInfoFactory.voidTypeInfo, null);
            }
        } else {
            switch (literal.getType().getSqlTypeName()) {
            case BOOLEAN:
                return new ExprNodeConstantDesc(TypeInfoFactory.booleanTypeInfo,
                        Boolean.valueOf(RexLiteral.booleanValue(literal)));
            case TINYINT:
                return new ExprNodeConstantDesc(TypeInfoFactory.byteTypeInfo,
                        Byte.valueOf(((Number) literal.getValue3()).byteValue()));
            case SMALLINT:
                return new ExprNodeConstantDesc(TypeInfoFactory.shortTypeInfo,
                        Short.valueOf(((Number) literal.getValue3()).shortValue()));
            case INTEGER:
                return new ExprNodeConstantDesc(TypeInfoFactory.intTypeInfo,
                        Integer.valueOf(((Number) literal.getValue3()).intValue()));
            case BIGINT:
                return new ExprNodeConstantDesc(TypeInfoFactory.longTypeInfo,
                        Long.valueOf(((Number) literal.getValue3()).longValue()));
            case FLOAT:
            case REAL:
                return new ExprNodeConstantDesc(TypeInfoFactory.floatTypeInfo,
                        Float.valueOf(((Number) literal.getValue3()).floatValue()));
            case DOUBLE:
                return new ExprNodeConstantDesc(TypeInfoFactory.doubleTypeInfo,
                        Double.valueOf(((Number) literal.getValue3()).doubleValue()));
            case DATE: {
                final Calendar c = (Calendar) literal.getValue();
                return new ExprNodeConstantDesc(TypeInfoFactory.dateTypeInfo,
                        new java.sql.Date(c.getTimeInMillis()));
            }
            case TIME:
            case TIMESTAMP: {
                // Interpret the calendar's millis in its own time zone before
                // building the Hive timestamp.
                final Calendar c = (Calendar) literal.getValue();
                final DateTime dt = new DateTime(c.getTimeInMillis(), DateTimeZone.forTimeZone(c.getTimeZone()));
                return new ExprNodeConstantDesc(TypeInfoFactory.timestampTypeInfo, new Timestamp(dt.getMillis()));
            }
            case BINARY:
                return new ExprNodeConstantDesc(TypeInfoFactory.binaryTypeInfo, literal.getValue3());
            case DECIMAL:
                return new ExprNodeConstantDesc(
                        TypeInfoFactory.getDecimalTypeInfo(lType.getPrecision(), lType.getScale()),
                        HiveDecimal.create((BigDecimal) literal.getValue3()));
            case VARCHAR:
            case CHAR: {
                return new ExprNodeConstantDesc(TypeInfoFactory.stringTypeInfo, literal.getValue3());
            }
            case INTERVAL_YEAR:
            case INTERVAL_MONTH:
            case INTERVAL_YEAR_MONTH: {
                // Calcite represents year-month intervals as a month count.
                BigDecimal monthsBd = (BigDecimal) literal.getValue();
                return new ExprNodeConstantDesc(TypeInfoFactory.intervalYearMonthTypeInfo,
                        new HiveIntervalYearMonth(monthsBd.intValue()));
            }
            case INTERVAL_DAY:
            case INTERVAL_DAY_HOUR:
            case INTERVAL_DAY_MINUTE:
            case INTERVAL_DAY_SECOND:
            case INTERVAL_HOUR:
            case INTERVAL_HOUR_MINUTE:
            case INTERVAL_HOUR_SECOND:
            case INTERVAL_MINUTE:
            case INTERVAL_MINUTE_SECOND:
            case INTERVAL_SECOND: {
                BigDecimal millisBd = (BigDecimal) literal.getValue();
                // Calcite literal is in millis, we need to convert to seconds
                BigDecimal secsBd = millisBd.divide(BigDecimal.valueOf(1000));
                return new ExprNodeConstantDesc(TypeInfoFactory.intervalDayTimeTypeInfo,
                        new HiveIntervalDayTime(secsBd));
            }
            case OTHER:
            default:
                return new ExprNodeConstantDesc(TypeInfoFactory.voidTypeInfo, literal.getValue3());
            }
        }
    }

    /**
     * Converts a windowed aggregate call. Rather than producing a function
     * expression, this records a WindowFunctionSpec (partitioning, ordering,
     * frame, function AST) in {@link #windowFunctionSpecs} and returns a
     * column reference to a freshly generated "$win$_col_<n>" alias; the
     * caller is responsible for materializing that column via PTF.
     */
    @Override
    public ExprNodeDesc visitOver(RexOver over) {
        if (!deep) {
            return null;
        }

        final RexWindow window = over.getWindow();

        final WindowSpec windowSpec = new WindowSpec();
        final PartitioningSpec partitioningSpec = getPSpec(window);
        windowSpec.setPartitioning(partitioningSpec);
        final WindowFrameSpec windowFrameSpec = getWindowRange(window);
        windowSpec.setWindowFrame(windowFrameSpec);

        WindowFunctionSpec wfs = new WindowFunctionSpec();
        wfs.setWindowSpec(windowSpec);
        final Schema schema = new Schema(tabAlias, inputRowType.getFieldList());
        final ASTNode wUDAFAst = new ASTConverter.RexVisitor(schema).visitOver(over);
        wfs.setExpression(wUDAFAst);
        // Child 0 is the function name; children 1..n-2 are its arguments
        // (the last child is the window spec, which is handled above).
        ASTNode nameNode = (ASTNode) wUDAFAst.getChild(0);
        wfs.setName(nameNode.getText());
        for (int i = 1; i < wUDAFAst.getChildCount() - 1; i++) {
            ASTNode child = (ASTNode) wUDAFAst.getChild(i);
            wfs.addArg(child);
        }
        if (wUDAFAst.getText().equals("TOK_FUNCTIONSTAR")) {
            wfs.setStar(true);
        }
        String columnAlias = getWindowColumnAlias();
        wfs.setAlias(columnAlias);

        this.windowFunctionSpecs.add(wfs);

        return new ExprNodeColumnDesc(TypeConverter.convert(over.getType()), columnAlias, tabAlias, false);
    }

    /**
     * Builds the PARTITION BY / ORDER BY spec for a window. When the window
     * carries no explicit NULLS FIRST/LAST, Hive's default is applied:
     * NULLS FIRST for ascending, NULLS LAST for descending.
     */
    private PartitioningSpec getPSpec(RexWindow window) {
        PartitioningSpec partitioning = new PartitioningSpec();

        Schema schema = new Schema(tabAlias, inputRowType.getFieldList());

        if (window.partitionKeys != null && !window.partitionKeys.isEmpty()) {
            PartitionSpec pSpec = new PartitionSpec();
            for (RexNode pk : window.partitionKeys) {
                PartitionExpression exprSpec = new PartitionExpression();
                ASTNode astNode = pk.accept(new RexVisitor(schema));
                exprSpec.setExpression(astNode);
                pSpec.addExpression(exprSpec);
            }
            partitioning.setPartSpec(pSpec);
        }

        if (window.orderKeys != null && !window.orderKeys.isEmpty()) {
            OrderSpec oSpec = new OrderSpec();
            for (RexFieldCollation ok : window.orderKeys) {
                OrderExpression exprSpec = new OrderExpression();
                Order order = ok.getDirection() == RelFieldCollation.Direction.ASCENDING ? Order.ASC : Order.DESC;
                NullOrder nullOrder;
                if (ok.right.contains(SqlKind.NULLS_FIRST)) {
                    nullOrder = NullOrder.NULLS_FIRST;
                } else if (ok.right.contains(SqlKind.NULLS_LAST)) {
                    nullOrder = NullOrder.NULLS_LAST;
                } else {
                    // Default
                    nullOrder = ok.getDirection() == RelFieldCollation.Direction.ASCENDING ? NullOrder.NULLS_FIRST
                            : NullOrder.NULLS_LAST;
                }
                exprSpec.setOrder(order);
                exprSpec.setNullOrder(nullOrder);
                ASTNode astNode = ok.left.accept(new RexVisitor(schema));
                exprSpec.setExpression(astNode);
                oSpec.addExpression(exprSpec);
            }
            partitioning.setOrderSpec(oSpec);
        }

        return partitioning;
    }

    /**
     * Maps a Calcite window frame onto Hive's WindowFrameSpec. Note the
     * bound naming flip: Calcite's upper bound becomes Hive's frame start
     * and Calcite's lower bound becomes the frame end.
     */
    private WindowFrameSpec getWindowRange(RexWindow window) {
        // NOTE: in Hive AST Rows->Range(Physical) & Range -> Values (logical)
        BoundarySpec start = null;
        RexWindowBound ub = window.getUpperBound();
        if (ub != null) {
            start = getWindowBound(ub);
        }

        BoundarySpec end = null;
        RexWindowBound lb = window.getLowerBound();
        if (lb != null) {
            end = getWindowBound(lb);
        }

        return new WindowFrameSpec(window.isRows() ? WindowType.ROWS : WindowType.RANGE, start, end);
    }

    /**
     * Converts a single Calcite window bound (CURRENT ROW / n PRECEDING /
     * n FOLLOWING / UNBOUNDED) into a Hive BoundarySpec.
     */
    private BoundarySpec getWindowBound(RexWindowBound wb) {
        BoundarySpec boundarySpec;

        if (wb.isCurrentRow()) {
            boundarySpec = new BoundarySpec(Direction.CURRENT);
        } else {
            final Direction direction;
            final int amt;
            if (wb.isPreceding()) {
                direction = Direction.PRECEDING;
            } else {
                direction = Direction.FOLLOWING;
            }
            if (wb.isUnbounded()) {
                amt = BoundarySpec.UNBOUNDED_AMOUNT;
            } else {
                amt = RexLiteral.intValue(wb.getOffset());
            }

            boundarySpec = new BoundarySpec(direction, amt);
        }

        return boundarySpec;
    }

    /** Generates a process-wide unique alias for a synthetic window column. */
    private String getWindowColumnAlias() {
        return "$win$_col_" + UNIQUE_COUNTER.getAndIncrement();
    }

}