com.facebook.presto.sql.planner.optimizations.MergeWindows.java Source code

Java tutorial

Introduction

Here is the source code for com.facebook.presto.sql.planner.optimizations.MergeWindows.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.DependencyExtractor;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.Multimap;

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkState;

/**
 * Merges together the functions of WindowNodes that have identical WindowNode.Specifications.
 * For example:
 * <pre>
 * OutputNode
 * `--...
 *    `--WindowNode(Specification: A, Functions: [sum(something)])
 *       `--WindowNode(Specification: B, Functions: [sum(something)])
 *          `--WindowNode(Specification: A, Functions: [avg(something)])
 *             `--...
 * </pre>
 * will be transformed into
 * <pre>
 * OutputNode
 * `--...
 *    `--WindowNode(Specification: B, Functions: [sum(something)])
 *       `--WindowNode(Specification: A, Functions: [avg(something), sum(something)])
 *          `--...
 * </pre>
 * This will NOT merge the functions of WindowNodes that have identical WindowNode.Specifications
 * but have a node between them that is not a WindowNode.
 * In the following example, the functions in the WindowNodes with specification `A' will not be
 * merged into a single WindowNode:
 * <pre>
 * OutputNode
 * `--...
 *    `--WindowNode(Specification: A, Functions: [sum(something)])
 *       `--WindowNode(Specification: B, Functions: [sum(something)])
 *          `-- ProjectNode(...)
 *             `--WindowNode(Specification: A, Functions: [avg(something)])
 *                `--...
 * </pre>
 * Merging also stops at a WindowNode whose created symbols are used (in partitioning, ordering or
 * function arguments) by a WindowNode collected above it, so that no window function is moved
 * below a window it reads from.
 * <p>
 * Within a merged WindowNode, functions keep the order in which their WindowNodes were collected
 * (lowest node in the original chain first).
 * <p>
 * This optimizer must run before HashGenerationOptimizer and AddExchanges; the rewriter checks
 * that no hash symbol or pre-partitioning/pre-sorting information is present yet.
 */
public class MergeWindows implements PlanOptimizer {
    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types,
            SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        // ImmutableListMultimap preserves order of window nodes
        return SimplePlanRewriter.rewriteWith(new Rewriter(), plan, ImmutableListMultimap.of());
    }

    /**
     * Rewrites the plan top-down. The context carries the WindowNodes collected so far in the current
     * chain, grouped by specification; they are re-emitted (one node per specification) when the
     * chain ends at a non-window node or at a window that one of them depends on.
     */
    private static class Rewriter extends SimplePlanRewriter<Multimap<WindowNode.Specification, WindowNode>> {
        /**
         * Ends the current window chain: children are rewritten with an empty context and the
         * windows collected above {@code node} are re-created on top of the rewritten node.
         */
        @Override
        protected PlanNode visitPlan(PlanNode node,
                RewriteContext<Multimap<WindowNode.Specification, WindowNode>> context) {
            PlanNode newNode = context.defaultRewrite(node, ImmutableListMultimap.of());
            return collapseWindowsWithinSpecification(context.get(), newNode);
        }

        /**
         * Adds {@code windowNode} to the collected windows and continues into its source, unless a
         * collected window depends on it, in which case the collected windows are emitted above the
         * rewritten subtree that starts a new chain with {@code windowNode}.
         */
        @Override
        public PlanNode visitWindow(WindowNode windowNode,
                RewriteContext<Multimap<WindowNode.Specification, WindowNode>> context) {
            checkState(!windowNode.getHashSymbol().isPresent(),
                    "MergeWindows should be run before HashGenerationOptimizer");
            checkState(windowNode.getPrePartitionedInputs().isEmpty() && windowNode.getPreSortedOrderPrefix() == 0,
                    "MergeWindows should be run before AddExchanges");
            // NOTE(review): this compares whole WindowNode.Function values, not only their frames, so it
            // effectively requires every incoming WindowNode to carry a single (or identical) function.
            checkState(windowNode.getWindowFunctions().values().stream().distinct().count() == 1,
                    "Frames expected to be identical");

            Multimap<WindowNode.Specification, WindowNode> collected = context.get();

            // A window collected above that reads this node's outputs must stay above it, so the
            // chain is cut here: this node starts a fresh collection for its own subtree.
            boolean someParentDependsOnThis = collected.values().stream()
                    .anyMatch(parent -> dependsOn(parent, windowNode));
            if (someParentDependsOnThis) {
                PlanNode rewrittenSubtree = context.rewrite(windowNode.getSource(),
                        ImmutableListMultimap.of(windowNode.getSpecification(), windowNode));
                return collapseWindowsWithinSpecification(collected, rewrittenSubtree);
            }

            return context.rewrite(windowNode.getSource(),
                    ImmutableListMultimap.<WindowNode.Specification, WindowNode>builder()
                            .put(windowNode.getSpecification(), windowNode) // Add the current window first so that it gets precedence in iteration order
                            .putAll(collected)
                            .build());
        }

        /**
         * Stacks one merged WindowNode per specification on top of {@code sourceNode}, in the
         * multimap's key order; the first specification ends up closest to {@code sourceNode}.
         * Returns {@code sourceNode} unchanged when {@code windowsMap} is empty.
         */
        private static PlanNode collapseWindowsWithinSpecification(
                Multimap<WindowNode.Specification, WindowNode> windowsMap, PlanNode sourceNode) {
            for (WindowNode.Specification specification : windowsMap.keySet()) {
                Collection<WindowNode> windows = windowsMap.get(specification);
                sourceNode = collapseWindows(sourceNode, specification, windows);
            }
            return sourceNode;
        }

        /**
         * Builds a single WindowNode over {@code source} holding the functions of all {@code windows}
         * (which must be non-empty). The id, hash symbol and pre-partitioning/pre-sorting info are
         * taken from the first window. Functions keep the iteration order of {@code windows}.
         *
         * @throws IllegalStateException if two windows produce the same output symbol
         */
        private static WindowNode collapseWindows(PlanNode source, WindowNode.Specification specification,
                Collection<WindowNode> windows) {
            WindowNode canonical = windows.iterator().next();
            return new WindowNode(canonical.getId(), source, specification,
                    windows.stream()
                            .flatMap(window -> window.getWindowFunctions().entrySet().stream())
                            // LinkedHashMap keeps the collection order (a plain toMap would yield a HashMap)
                            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue,
                                    (first, second) -> {
                                        throw new IllegalStateException(String.format(
                                                "Window functions %s and %s produce the same output symbol",
                                                first, second));
                                    },
                                    LinkedHashMap::new)),
                    canonical.getHashSymbol(), canonical.getPrePartitionedInputs(),
                    canonical.getPreSortedOrderPrefix());
        }

        /**
         * Returns true if {@code parent} references any symbol created by {@code child} in its
         * partitioning, ordering, or window function arguments.
         * NOTE(review): nothing else on WindowNode.Function (e.g. frame bounds) is inspected; confirm
         * no other part of a window function can reference another window's output.
         */
        private static boolean dependsOn(WindowNode parent, WindowNode child) {
            Set<Symbol> childOutputs = child.getCreatedSymbols();

            Stream<Symbol> arguments = parent.getWindowFunctions().values().stream()
                    .map(WindowNode.Function::getFunctionCall)
                    .flatMap(functionCall -> functionCall.getArguments().stream())
                    .map(DependencyExtractor::extractUnique)
                    .flatMap(Collection::stream);

            return parent.getPartitionBy().stream().anyMatch(childOutputs::contains)
                    || parent.getOrderBy().stream().anyMatch(childOutputs::contains)
                    || arguments.anyMatch(childOutputs::contains);
        }
    }
}