Java tutorial
/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.prestosql.sql.planner; import com.google.common.collect.ImmutableList; import io.prestosql.sql.ExpressionUtils; import io.prestosql.sql.tree.AstVisitor; import io.prestosql.sql.tree.ComparisonExpression; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.Node; import io.prestosql.sql.tree.SymbolReference; import java.util.List; import java.util.Optional; import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Collections.singletonList; import static java.util.Comparator.comparing; import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; import static java.util.stream.Collectors.toMap; /** * Extracts sort expression to be used for creating {@link io.prestosql.operator.SortedPositionLinks} from join filter expression. * Currently this class can extract sort and search expressions from filter function conjuncts of shape: * <p> * {@code A.a < f(B.x, B.y, B.z)} or {@code f(B.x, B.y, B.z) < A.a} * <p> * where {@code a} is the build side symbol reference and {@code x,y,z} are probe * side symbol references. Any of inequality operators ({@code <,<=,>,>=}) can be used. * Same build side symbol need to be used in all conjuncts. */ public final class SortExpressionExtractor { /* TODO: This class could be extended to handle any expressions like: A.a * sin(A.b) / log(B.x) < cos(B.z) by transforming it to: f(A.a, A.b) < g(B.x, B.z) Where f(...) and g(...) would be some functions/expressions. That would allow us to perform binary search on arbitrary complex expressions by sorting position links according to the result of f(...) function. */ private SortExpressionExtractor() { } public static Optional<SortExpressionContext> extractSortExpression(Set<Symbol> buildSymbols, Expression filter) { List<Expression> filterConjuncts = ExpressionUtils.extractConjuncts(filter); SortExpressionVisitor visitor = new SortExpressionVisitor(buildSymbols); List<SortExpressionContext> sortExpressionCandidates = filterConjuncts.stream() .filter(DeterminismEvaluator::isDeterministic).map(visitor::process).filter(Optional::isPresent) .map(Optional::get) .collect( toMap(SortExpressionContext::getSortExpression, identity(), SortExpressionExtractor::merge)) .values().stream().collect(toImmutableList()); // For now heuristically pick sort expression which has most search expressions assigned to it. // TODO: make it cost based decision based on symbol statistics return sortExpressionCandidates.stream() .sorted(comparing(context -> -1 * context.getSearchExpressions().size())).findFirst(); } private static SortExpressionContext merge(SortExpressionContext left, SortExpressionContext right) { checkArgument(left.getSortExpression().equals(right.getSortExpression())); ImmutableList.Builder<Expression> searchExpressions = ImmutableList.builder(); searchExpressions.addAll(left.getSearchExpressions()); searchExpressions.addAll(right.getSearchExpressions()); return new SortExpressionContext(left.getSortExpression(), searchExpressions.build()); } private static class SortExpressionVisitor extends AstVisitor<Optional<SortExpressionContext>, Void> { private final Set<Symbol> buildSymbols; public SortExpressionVisitor(Set<Symbol> buildSymbols) { this.buildSymbols = buildSymbols; } @Override protected Optional<SortExpressionContext> visitExpression(Expression expression, Void context) { return Optional.empty(); } @Override protected Optional<SortExpressionContext> visitComparisonExpression(ComparisonExpression comparison, Void context) { switch (comparison.getOperator()) { case GREATER_THAN: case GREATER_THAN_OR_EQUAL: case LESS_THAN: case LESS_THAN_OR_EQUAL: Optional<SymbolReference> sortChannel = asBuildSymbolReference(buildSymbols, comparison.getRight()); boolean hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.getLeft()); if (!sortChannel.isPresent()) { sortChannel = asBuildSymbolReference(buildSymbols, comparison.getLeft()); hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.getRight()); } if (sortChannel.isPresent() && !hasBuildReferencesOnOtherSide) { return sortChannel.map(symbolReference -> new SortExpressionContext(symbolReference, singletonList(comparison))); } return Optional.empty(); default: return Optional.empty(); } } } private static Optional<SymbolReference> asBuildSymbolReference(Set<Symbol> buildLayout, Expression expression) { // Currently only we support only symbol as sort expression on build side if (expression instanceof SymbolReference) { SymbolReference symbolReference = (SymbolReference) expression; if (buildLayout.contains(new Symbol(symbolReference.getName()))) { return Optional.of(symbolReference); } } return Optional.empty(); } private static boolean hasBuildSymbolReference(Set<Symbol> buildSymbols, Expression expression) { return new BuildSymbolReferenceFinder(buildSymbols).process(expression); } private static class BuildSymbolReferenceFinder extends AstVisitor<Boolean, Void> { private final Set<String> buildSymbols; public BuildSymbolReferenceFinder(Set<Symbol> buildSymbols) { this.buildSymbols = requireNonNull(buildSymbols, "buildSymbols is null").stream().map(Symbol::getName) .collect(toImmutableSet()); } @Override protected Boolean visitNode(Node node, Void context) { for (Node child : node.getChildren()) { if (process(child, context)) { return true; } } return false; } @Override protected Boolean visitSymbolReference(SymbolReference symbolReference, Void context) { return buildSymbols.contains(symbolReference.getName()); } } }