com.facebook.presto.operator.HashAggregationOperator.java Source code

Java tutorial

Introduction

Here is the source code for com.facebook.presto.operator.HashAggregationOperator.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator;

import com.facebook.presto.operator.aggregation.AccumulatorFactory;
import com.facebook.presto.operator.aggregation.GroupedAccumulator;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.plan.AggregationNode.Step;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
import io.airlift.units.DataSize;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;

import static com.facebook.presto.operator.GroupByHash.createGroupByHash;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

public class HashAggregationOperator implements Operator {
    /**
     * Factory producing {@link HashAggregationOperator} instances that all share one configuration.
     *
     * <p>After {@link #close()} has been called, {@link #createOperator(DriverContext)} fails its
     * state check.
     */
    public static class HashAggregationOperatorFactory implements OperatorFactory {
        private final int operatorId;
        private final Optional<Integer> maskChannel;
        private final List<Type> groupByTypes;
        private final List<Integer> groupByChannels;
        private final Step step;
        private final List<AccumulatorFactory> accumulatorFactories;
        private final Optional<Integer> hashChannel;

        private final int expectedGroups;
        private final List<Type> types;
        private boolean closed;
        // Value of the maxPartialMemory constructor argument, converted to bytes.
        private final long maxPartialMemory;

        /**
         * Captures the configuration handed to every operator this factory creates.
         *
         * <p>All reference arguments are validated eagerly so that a bad configuration fails here,
         * with a descriptive message, rather than later inside {@link #createOperator(DriverContext)}.
         *
         * @param operatorId id passed to the driver context when an operator context is added
         * @param groupByTypes types of the group-by columns; copied defensively
         * @param groupByChannels channels of the group-by columns; copied defensively
         * @param step aggregation step; decides which operator-context overload is used
         * @param accumulatorFactories one factory per aggregate; copied defensively
         * @param maskChannel optional mask channel, forwarded to created operators
         * @param hashChannel optional hash channel; when present a BIGINT column is part of the output types
         * @param expectedGroups forwarded unchanged to created operators
         * @param maxPartialMemory stored in bytes and used only when the step outputs partial results
         * @throws NullPointerException if any reference argument is null
         */
        public HashAggregationOperatorFactory(int operatorId, List<? extends Type> groupByTypes,
                List<Integer> groupByChannels, Step step, List<AccumulatorFactory> accumulatorFactories,
                Optional<Integer> maskChannel, Optional<Integer> hashChannel, int expectedGroups,
                DataSize maxPartialMemory) {
            this.operatorId = operatorId;
            this.maskChannel = requireNonNull(maskChannel, "maskChannel is null");
            this.hashChannel = requireNonNull(hashChannel, "hashChannel is null");
            this.groupByTypes = ImmutableList.copyOf(requireNonNull(groupByTypes, "groupByTypes is null"));
            this.groupByChannels = ImmutableList.copyOf(requireNonNull(groupByChannels, "groupByChannels is null"));
            this.step = requireNonNull(step, "step is null");
            this.accumulatorFactories = ImmutableList.copyOf(
                    requireNonNull(accumulatorFactories, "accumulatorFactories is null"));
            this.expectedGroups = expectedGroups;
            this.maxPartialMemory = requireNonNull(maxPartialMemory, "maxPartialMemory is null").toBytes();

            // Derive the output types from the immutable copies so they cannot drift from the
            // configuration actually stored in this factory.
            this.types = toTypes(this.groupByTypes, this.step, this.accumulatorFactories, this.hashChannel);
        }

        /**
         * @return the output types computed at construction time
         */
        @Override
        public List<Type> getTypes() {
            return types;
        }

        /**
         * Creates a new operator registered with the given driver context.
         *
         * <p>When the step outputs partial results the operator context is added with
         * {@code maxPartialMemory} (bytes) as an extra argument; otherwise the two-argument
         * overload is used.
         * NOTE(review): the extra argument presumably caps the operator's memory - confirm
         * against {@code DriverContext}.
         *
         * @param driverContext driver context the new operator is attached to
         * @return a freshly constructed {@link HashAggregationOperator}
         * @throws IllegalStateException if the factory has been closed
         */
        @Override
        public Operator createOperator(DriverContext driverContext) {
            checkState(!closed, "Factory is already closed");

            String operatorName = HashAggregationOperator.class.getSimpleName();
            OperatorContext operatorContext = step.isOutputPartial()
                    ? driverContext.addOperatorContext(operatorId, operatorName, maxPartialMemory)
                    : driverContext.addOperatorContext(operatorId, operatorName);

            return new HashAggregationOperator(operatorContext, groupByTypes, groupByChannels, step,
                    accumulatorFactories, maskChannel, hashChannel, expectedGroups);
        }

        /**
         * Marks the factory closed; later {@link #createOperator(DriverContext)} calls fail.
         */
        @Override
        public void close() {
            closed = true;
        }
    }

    private final OperatorContext operatorContext;
    private final List<Type> groupByTypes;
    private final List<Integer> groupByChannels;
    private final Step step;
    private final List<AccumulatorFactory> accumulatorFactories;
    private final Optional<Integer> maskChannel;
    private final Optional<Integer> hashChannel;
    private final int expectedGroups;

    private final List<Type> types;

    private GroupByHashAggregationBuilder aggregationBuilder;
    private Iterator<Page> outputIterator;
    private boolean finishing;

    /**
     * Creates an operator that groups incoming pages and aggregates them per group.
     *
     * <p>Every reference argument is null-checked with a descriptive message, and the list
     * arguments are copied so later changes by the caller cannot affect this operator.
     *
     * @param operatorContext context passed to the aggregation builder for memory accounting
     * @param groupByTypes types of the group-by columns
     * @param groupByChannels channels of the group-by columns
     * @param step aggregation step, forwarded to each aggregator
     * @param accumulatorFactories one factory per aggregate
     * @param maskChannel optional mask channel, forwarded to the group-by hash
     * @param hashChannel optional hash channel; when present a BIGINT column is part of the output types
     * @param expectedGroups forwarded to the group-by hash when the builder is created
     * @throws NullPointerException if any reference argument is null
     */
    public HashAggregationOperator(OperatorContext operatorContext, List<Type> groupByTypes,
            List<Integer> groupByChannels, Step step, List<AccumulatorFactory> accumulatorFactories,
            Optional<Integer> maskChannel, Optional<Integer> hashChannel, int expectedGroups) {
        this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
        this.step = requireNonNull(step, "step is null");
        this.accumulatorFactories = ImmutableList.copyOf(
                requireNonNull(accumulatorFactories, "accumulatorFactories is null"));
        this.groupByTypes = ImmutableList.copyOf(requireNonNull(groupByTypes, "groupByTypes is null"));
        this.groupByChannels = ImmutableList.copyOf(requireNonNull(groupByChannels, "groupByChannels is null"));
        this.maskChannel = requireNonNull(maskChannel, "maskChannel is null");
        this.hashChannel = requireNonNull(hashChannel, "hashChannel is null");
        this.expectedGroups = expectedGroups;

        // Compute the output types from the defensive copies rather than the caller's lists.
        this.types = toTypes(this.groupByTypes, this.step, this.accumulatorFactories, this.hashChannel);
    }

    /**
     * @return the operator context supplied at construction
     */
    @Override
    public OperatorContext getOperatorContext() {
        return this.operatorContext;
    }

    /**
     * @return the output types computed in the constructor
     */
    @Override
    public List<Type> getTypes() {
        return this.types;
    }

    /**
     * Marks the operator as finishing. After this, {@code addInput} rejects pages and
     * {@code getOutput} flushes whatever the aggregation builder holds.
     */
    @Override
    public void finish() {
        this.finishing = true;
    }

    /**
     * Reports whether the operator is done: finishing was requested, no aggregation builder is
     * pending, and the output iterator (if any) has no more pages.
     *
     * @return {@code true} once no further output can be produced
     */
    @Override
    public boolean isFinished() {
        if (!finishing || aggregationBuilder != null) {
            return false;
        }
        return outputIterator == null || !outputIterator.hasNext();
    }

    /**
     * Reports whether the operator can accept another page.
     *
     * <p>Input is refused while finishing or while an output iterator is set. Otherwise it is
     * accepted when no builder exists yet or the existing builder is not full. Note that the
     * builder's {@code isFull()} also updates the operator's memory reservation.
     *
     * @return {@code true} if {@code addInput} may be called
     */
    @Override
    public boolean needsInput() {
        if (finishing || outputIterator != null) {
            return false;
        }
        return aggregationBuilder == null || !aggregationBuilder.isFull();
    }

    /**
     * Feeds one page into the current aggregation builder, creating the builder on first use.
     *
     * @param page the page to aggregate; must not be null
     * @throws IllegalStateException if the operator is finishing or the existing builder is full
     * @throws NullPointerException if {@code page} is null
     */
    @Override
    public void addInput(Page page) {
        checkState(!finishing, "Operator is already finishing");
        requireNonNull(page, "page is null");

        if (aggregationBuilder != null) {
            // An existing builder must still have room for more input.
            checkState(!aggregationBuilder.isFull(), "Aggregation buffer is full");
        } else {
            // A freshly created builder is assumed not to be full, so no check is made.
            aggregationBuilder = new GroupByHashAggregationBuilder(accumulatorFactories, step, expectedGroups,
                    groupByTypes, groupByChannels, maskChannel, hashChannel, operatorContext);
        }

        aggregationBuilder.processPage(page);
    }

    /**
     * Returns the next output page, or {@code null} when none is available right now.
     *
     * <p>Pages are drained from the current output iterator first. When it is exhausted, the
     * aggregation builder is turned into a new iterator, but only if the operator is finishing
     * or the builder reports that it is full.
     *
     * @return the next page, or {@code null} if there is nothing to emit
     */
    @Override
    public Page getOutput() {
        // Fast path: the current iterator still has pages.
        if (outputIterator != null && outputIterator.hasNext()) {
            return outputIterator.next();
        }

        // The current iterator, if any, is exhausted.
        outputIterator = null;

        // Nothing has been aggregated since the last flush.
        if (aggregationBuilder == null) {
            return null;
        }

        // Flush only when finishing or when the builder is full.
        if (!(finishing || aggregationBuilder.isFull())) {
            return null;
        }

        outputIterator = aggregationBuilder.build();
        aggregationBuilder = null;

        if (outputIterator.hasNext()) {
            return outputIterator.next();
        }

        // The builder produced no pages at all.
        outputIterator = null;
        return null;
    }

    /**
     * Builds the output type list: the group-by types, then BIGINT if a hash channel is present,
     * then one type per accumulator factory as reported by an {@link Aggregator} for the step.
     *
     * @param groupByType types of the group-by columns
     * @param step aggregation step used to pick each aggregator's output type
     * @param factories accumulator factories, one per aggregate
     * @param hashChannel optional hash channel
     * @return an immutable list of output types in the order described above
     */
    private static List<Type> toTypes(List<? extends Type> groupByType, Step step,
            List<AccumulatorFactory> factories, Optional<Integer> hashChannel) {
        ImmutableList.Builder<Type> outputTypes = ImmutableList.builder();

        outputTypes.addAll(groupByType);
        hashChannel.ifPresent(ignored -> outputTypes.add(BIGINT));

        for (AccumulatorFactory accumulatorFactory : factories) {
            Aggregator aggregator = new Aggregator(accumulatorFactory, step);
            outputTypes.add(aggregator.getType());
        }

        return outputTypes.build();
    }

    /**
     * Collects pages into a {@link GroupByHash} plus one {@link Aggregator} per accumulator
     * factory, and later emits the per-group results as an iterator of pages.
     */
    private static class GroupByHashAggregationBuilder {
        private final GroupByHash groupByHash;
        private final List<Aggregator> aggregators;
        private final OperatorContext operatorContext;
        // True when the step outputs partial results; selects the reservation strategy in isFull().
        private final boolean partial;

        /**
         * Creates the group-by hash and wraps each accumulator factory in an aggregator.
         */
        private GroupByHashAggregationBuilder(List<AccumulatorFactory> accumulatorFactories, Step step,
                int expectedGroups, List<Type> groupByTypes, List<Integer> groupByChannels,
                Optional<Integer> maskChannel, Optional<Integer> hashChannel, OperatorContext operatorContext) {
            int[] channels = Ints.toArray(groupByChannels);
            this.groupByHash = createGroupByHash(groupByTypes, channels, maskChannel, hashChannel,
                    expectedGroups);
            this.operatorContext = operatorContext;
            this.partial = step.isOutputPartial();

            // Wrap every accumulator factory in an aggregator for this step.
            requireNonNull(accumulatorFactories, "accumulatorFactories is null");
            ImmutableList.Builder<Aggregator> wrapped = ImmutableList.builder();
            for (AccumulatorFactory factory : accumulatorFactories) {
                wrapped.add(new Aggregator(factory, step));
            }
            this.aggregators = wrapped.build();
        }

        /**
         * Adds a page. Without aggregators the page is only added to the group-by hash;
         * otherwise its group ids are computed once and passed to every aggregator.
         */
        private void processPage(Page page) {
            if (!aggregators.isEmpty()) {
                GroupByIdBlock groupIds = groupByHash.getGroupIds(page);
                for (Aggregator aggregator : aggregators) {
                    aggregator.processPage(groupIds, page);
                }
            } else {
                groupByHash.addPage(page);
            }
        }

        /**
         * Computes the estimated size of the hash and all aggregators, subtracts the operator's
         * pre-allocated memory (clamped at zero), and records the result as the memory reservation.
         *
         * @return for partial steps, {@code true} when the reservation attempt is refused;
         *         for other steps always {@code false}
         */
        public boolean isFull() {
            long estimatedBytes = groupByHash.getEstimatedSize();
            for (Aggregator aggregator : aggregators) {
                estimatedBytes += aggregator.getEstimatedSize();
            }

            long preAllocatedBytes = operatorContext.getOperatorPreAllocatedMemory().toBytes();
            long reservation = Math.max(0L, estimatedBytes - preAllocatedBytes);

            if (!partial) {
                operatorContext.setMemoryReservation(reservation);
                return false;
            }
            return !operatorContext.trySetMemoryReservation(reservation);
        }

        /**
         * Returns an iterator that emits the aggregated groups page by page.
         *
         * <p>Each page starts with the values appended by the group-by hash, followed by one
         * column per aggregator. The group count is read when the iterator is created.
         */
        public Iterator<Page> build() {
            List<Type> outputTypes = new ArrayList<>(groupByHash.getTypes());
            for (Aggregator aggregator : aggregators) {
                outputTypes.add(aggregator.getType());
            }

            final PageBuilder pageBuilder = new PageBuilder(outputTypes);
            return new AbstractIterator<Page>() {
                private final int groupCount = groupByHash.getGroupCount();
                private int groupId;

                @Override
                protected Page computeNext() {
                    if (groupId >= groupCount) {
                        return endOfData();
                    }

                    pageBuilder.reset();

                    // Aggregator columns start right after the columns owned by the group-by hash.
                    int firstAggregatorChannel = groupByHash.getTypes().size();
                    for (; !pageBuilder.isFull() && groupId < groupCount; groupId++) {
                        groupByHash.appendValuesTo(groupId, pageBuilder, 0);
                        pageBuilder.declarePosition();

                        int channel = firstAggregatorChannel;
                        for (Aggregator aggregator : aggregators) {
                            BlockBuilder output = pageBuilder.getBlockBuilder(channel);
                            aggregator.evaluate(groupId, output);
                            channel++;
                        }
                    }

                    return pageBuilder.build();
                }
            };
        }
    }

    /**
     * Pairs a {@link GroupedAccumulator} with the aggregation step, so callers need not know
     * whether input is raw or intermediate, or whether output is partial or final.
     */
    private static class Aggregator {
        private final GroupedAccumulator aggregation;
        private final Step step;
        // Channel holding intermediate input; -1 when the step consumes raw input.
        private final int intermediateChannel;

        /**
         * Creates the accumulator that matches the step's input kind.
         *
         * @throws IllegalArgumentException if the step takes intermediate input and the factory
         *         does not declare exactly one input channel
         */
        private Aggregator(AccumulatorFactory accumulatorFactory, Step step) {
            this.step = step;
            if (!step.isInputRaw()) {
                checkArgument(accumulatorFactory.getInputChannels().size() == 1,
                        "expected 1 input channel for intermediate aggregation");
                this.intermediateChannel = accumulatorFactory.getInputChannels().get(0);
                this.aggregation = accumulatorFactory.createGroupedIntermediateAccumulator();
            } else {
                this.intermediateChannel = -1;
                this.aggregation = accumulatorFactory.createGroupedAccumulator();
            }
        }

        /**
         * @return the estimated size reported by the underlying accumulator
         */
        public long getEstimatedSize() {
            return aggregation.getEstimatedSize();
        }

        /**
         * @return the intermediate type when the step outputs partial results, else the final type
         */
        public Type getType() {
            return step.isOutputPartial()
                    ? aggregation.getIntermediateType()
                    : aggregation.getFinalType();
        }

        /**
         * Feeds a page to the accumulator: the whole page for raw input, or only the block at
         * the intermediate channel otherwise.
         */
        public void processPage(GroupByIdBlock groupIds, Page page) {
            if (!step.isInputRaw()) {
                aggregation.addIntermediate(groupIds, page.getBlock(intermediateChannel));
                return;
            }
            aggregation.addInput(groupIds, page);
        }

        /**
         * Writes the result for one group into {@code output}, as an intermediate value for
         * partial-output steps and as a final value otherwise.
         */
        public void evaluate(int groupId, BlockBuilder output) {
            if (!step.isOutputPartial()) {
                aggregation.evaluateFinal(groupId, output);
                return;
            }
            aggregation.evaluateIntermediate(groupId, output);
        }
    }
}