com.facebook.presto.operator.MergeHashSort.java Source code

Java tutorial

Introduction

Here is the source code for com.facebook.presto.operator.MergeHashSort.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator;

import com.facebook.presto.memory.AggregatedMemoryContext;
import com.facebook.presto.memory.LocalMemoryContext;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.type.Type;
import com.google.common.collect.Iterators;

import java.io.Closeable;
import java.util.Iterator;
import java.util.List;

import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

/**
 * Merges several streams of pages, each previously sorted by the hash of its
 * key columns, into a single stream of pages ordered by that hash.
 * <p>
 * Positions are compared using their hash value only. Two distinct key values
 * may share the same hash value, so the returned stream of pages can contain
 * interleaved positions that have the same hash but different keys.
 */
public class MergeHashSort implements Closeable {
    private final AggregatedMemoryContext memoryContext;

    /**
     * @param memoryContext aggregated context from which a local memory context
     *                      is created for every input channel and for the output
     *                      page builder; closed by {@link #close()}
     * @throws NullPointerException if {@code memoryContext} is null
     */
    public MergeHashSort(AggregatedMemoryContext memoryContext) {
        this.memoryContext = requireNonNull(memoryContext, "memoryContext is null");
    }

    /**
     * Merges the given hash sorted channels into one iterator of pages.
     * <p>
     * Rows with same hash value are guaranteed to be in the same result page.
     *
     * @param keyTypes types of the key columns; the hash is computed over the
     *                 first {@code keyTypes.size()} channels of every page
     * @param allTypes types of all columns of the input (and output) pages
     * @param channels input page streams, each already sorted by key hash
     * @return iterator over the merged pages
     * @throws NullPointerException if any argument is null
     */
    public Iterator<Page> merge(List<Type> keyTypes, List<Type> allTypes, List<Iterator<Page>> channels) {
        // Validate everything up front so that no local memory context is allocated for a call that cannot succeed.
        requireNonNull(keyTypes, "keyTypes is null");
        requireNonNull(allTypes, "allTypes is null");
        requireNonNull(channels, "channels is null");

        List<Iterator<PagePosition>> channelIterators = channels.stream()
                .map(channel -> new SingleChannelPagePositions(channel, memoryContext.newLocalMemoryContext()))
                .collect(toList());

        // Key columns occupy channels 0 .. keyTypes.size() - 1.
        int[] hashChannels = new int[keyTypes.size()];
        for (int i = 0; i < hashChannels.length; i++) {
            hashChannels[i] = i;
        }
        HashGenerator hashGenerator = new InterpretedHashGenerator(keyTypes, hashChannels);

        return new PageRewriteIterator(hashGenerator, allTypes,
                Iterators.mergeSorted(channelIterators,
                        (PagePosition left, PagePosition right) -> comparePages(hashGenerator, left, right)),
                memoryContext.newLocalMemoryContext());
    }

    /**
     * Orders two positions by the hash of their key columns.
     * <p>
     * A position that lies outside of its page (for example position 0 of an
     * empty page) sorts before every in-page position; two such positions are
     * considered equal. No hash is computed for out-of-page positions.
     */
    private static int comparePages(HashGenerator hashGenerator, PagePosition left, PagePosition right) {
        if (left.isPositionOutOfPage() && right.isPositionOutOfPage()) {
            return 0;
        }
        if (left.isPositionOutOfPage()) {
            return -1;
        }
        if (right.isPositionOutOfPage()) {
            return 1;
        }

        long leftHash = hashGenerator.hashPosition(left.getPosition(), left.getPage());
        long rightHash = hashGenerator.hashPosition(right.getPosition(), right.getPage());

        return Long.compare(leftHash, rightHash);
    }

    /**
     * Closes the aggregated memory context this instance was created with.
     */
    @Override
    public void close() {
        memoryContext.close();
    }

    /**
     * A single position (row index) within a page.
     */
    static class PagePosition {
        private final Page page;
        private final int position;

        /**
         * @param page     page the position refers to
         * @param position row index within {@code page}; may be equal to the
         *                 page's position count, see {@link #isPositionOutOfPage()}
         * @throws NullPointerException if {@code page} is null
         */
        public PagePosition(Page page, int position) {
            this.page = requireNonNull(page, "page is null");
            // position is a primitive and can never be null, so no null check (and no boxing) is needed.
            this.position = position;
        }

        public Page getPage() {
            return page;
        }

        public int getPosition() {
            return position;
        }

        /**
         * @return true when the position is not a valid row of the page, i.e. it
         *         is greater than or equal to the page's position count
         */
        public boolean isPositionOutOfPage() {
            return position >= page.getPositionCount();
        }
    }

    /**
     * Iterator over page positions.
     */
    public interface PagePositions extends Iterator<PagePosition> {
    }

    /**
     * Flattens one channel of pages into an iterator over the positions of
     * those pages, in order. The retained size of the page currently being
     * iterated is reported to the supplied memory context.
     */
    public static class SingleChannelPagePositions implements PagePositions {
        private final Iterator<Page> channel;
        private final LocalMemoryContext memoryContext;
        private PagePosition current;

        /**
         * @throws NullPointerException if any argument is null
         */
        public SingleChannelPagePositions(Iterator<Page> channel, LocalMemoryContext memoryContext) {
            this.channel = requireNonNull(channel, "channel is null");
            this.memoryContext = requireNonNull(memoryContext, "memoryContext is null");
        }

        @Override
        public boolean hasNext() {
            // More positions exist if another page is available or the current page has rows left.
            return channel.hasNext()
                    || (current != null && current.getPosition() + 1 < current.getPage().getPositionCount());
        }

        @Override
        public PagePosition next() {
            if (current == null || current.getPosition() + 1 >= current.getPage().getPositionCount()) {
                // Current page is exhausted (or there is none yet): advance to the next page.
                // For an empty page this yields position 0, which is out of page.
                current = new PagePosition(channel.next(), 0);
                memoryContext.setBytes(current.getPage().getRetainedSizeInBytes());
            } else {
                current = new PagePosition(current.getPage(), current.getPosition() + 1);
            }
            return current;
        }
    }

    /**
     * This class rewrites iterator over PagePosition to iterator over Pages.
     * <p>
     * A page is only cut once the builder is full and the hash of the next
     * position differs from the hash of the last appended position, so rows
     * with the same hash are never split across pages.
     */
    public static class PageRewriteIterator implements Iterator<Page> {
        private final List<Type> allTypes;
        private final Iterator<PagePosition> pagePositions;
        private final HashGenerator hashGenerator;
        private final PageBuilder builder;
        private final LocalMemoryContext memoryContext;
        // Position fetched from pagePositions but not yet written to an output page; null when there is none.
        private PagePosition currentPage = null;

        /**
         * @param hashGenerator computes the key hash used to detect hash boundaries
         * @param allTypes      types of all columns of the produced pages
         * @param pagePositions hash ordered positions to rewrite into pages
         * @param memoryContext receives the retained size of the page being built
         * @throws NullPointerException if any argument is null
         */
        public PageRewriteIterator(HashGenerator hashGenerator, List<Type> allTypes,
                Iterator<PagePosition> pagePositions, LocalMemoryContext memoryContext) {
            this.hashGenerator = requireNonNull(hashGenerator, "hashGenerator is null");
            this.allTypes = requireNonNull(allTypes, "allTypes is null");
            this.pagePositions = requireNonNull(pagePositions, "pagePositions is null");
            this.builder = new PageBuilder(allTypes);
            this.memoryContext = requireNonNull(memoryContext, "memoryContext is null");
        }

        @Override
        public boolean hasNext() {
            return currentPage != null || pagePositions.hasNext();
        }

        @Override
        public Page next() {
            builder.reset();

            if (currentPage == null) {
                currentPage = pagePositions.next();
            }

            PagePosition previousPage = currentPage;

            // Keep appending while the builder has room, or while the current position still has the
            // same hash as the last appended one. The cheap isFull() test comes first so that hashes
            // are only computed once the builder is full; both operands are side-effect free.
            while (!builder.isFull() || comparePages(hashGenerator, currentPage, previousPage) == 0) {
                // Out-of-page positions (produced for empty input pages) carry no row and are skipped.
                if (!currentPage.isPositionOutOfPage()) {
                    builder.declarePosition();
                    for (int column = 0; column < allTypes.size(); column++) {
                        Type type = allTypes.get(column);
                        type.appendTo(currentPage.getPage().getBlock(column), currentPage.getPosition(),
                                builder.getBlockBuilder(column));
                    }
                    previousPage = currentPage;
                    memoryContext.setBytes(builder.getRetainedSizeInBytes());
                }
                if (pagePositions.hasNext()) {
                    currentPage = pagePositions.next();
                } else {
                    // Input exhausted: nothing is carried over to the next call.
                    currentPage = null;
                    break;
                }
            }
            // When the loop ends because the condition became false, currentPage holds the first
            // position of the next output page and is consumed by the following call.
            return builder.build();
        }
    }
}