com.facebook.presto.operator.aggregation.AggregationCompiler.java Source code

Java tutorial

Introduction

Here is the source code for com.facebook.presto.operator.aggregation.AggregationCompiler.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.byteCode.DynamicClassLoader;
import com.facebook.presto.operator.aggregation.state.AccumulatorState;
import com.facebook.presto.operator.aggregation.state.AccumulatorStateFactory;
import com.facebook.presto.operator.aggregation.state.AccumulatorStateSerializer;
import com.facebook.presto.operator.aggregation.state.StateCompiler;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.type.SqlType;
import com.facebook.presto.type.TypeRegistry;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import javax.annotation.Nullable;

import java.lang.annotation.Annotation;
import java.lang.invoke.MethodHandle;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.List;
import java.util.Set;

import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.fromAnnotations;
import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.util.ImmutableCollectors.toImmutableList;
import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.invoke.MethodHandles.lookup;
import static java.util.Objects.requireNonNull;

/**
 * Builds {@link InternalAggregationFunction} instances from a class annotated with
 * {@link AggregationFunction}.
 *
 * <p>The annotated class is inspected reflectively. Its public static methods annotated with
 * {@link InputFunction}, {@link IntermediateInputFunction}, {@link CombineFunction} and
 * {@link OutputFunction} are grouped by state class, which is the type of each method's first
 * parameter. One aggregation is produced for every combination of state class, output function,
 * input function and function name.
 */
public class AggregationCompiler {
    private final TypeManager typeManager;

    /**
     * Creates a compiler that resolves SQL types through a new {@link TypeRegistry}.
     */
    public AggregationCompiler() {
        this(new TypeRegistry());
    }

    /**
     * Creates a compiler that resolves SQL types through the given type manager.
     *
     * @param typeManager used to resolve {@code @SqlType} signatures and output types
     * @throws NullPointerException if {@code typeManager} is null
     */
    public AggregationCompiler(TypeManager typeManager) {
        this.typeManager = requireNonNull(typeManager, "typeManager is null");
    }

    /**
     * Returns the methods of {@code clazz} that carry an annotation of type {@code annotationClass}.
     *
     * <p>{@link Class#getMethods()} only reports public members, so annotated non-public methods are
     * never seen here; the modifier check below therefore effectively enforces {@code static}.
     * Each matching method is added exactly once.
     *
     * @throws IllegalArgumentException if a matching method is not both static and public
     */
    private static List<Method> findPublicStaticMethodsWithAnnotation(Class<?> clazz, Class<?> annotationClass) {
        ImmutableList.Builder<Method> methods = ImmutableList.builder();
        for (Method method : clazz.getMethods()) {
            if (Arrays.stream(method.getAnnotations()).noneMatch(annotationClass::isInstance)) {
                continue;
            }
            int modifiers = method.getModifiers();
            checkArgument(
                    Modifier.isStatic(modifiers) && Modifier.isPublic(modifiers),
                    "%s annotated with %s must be static and public", method.getName(),
                    annotationClass.getSimpleName());
            methods.add(method);
        }
        return methods.build();
    }

    /**
     * Generates the single aggregation function defined by {@code clazz}.
     *
     * @param clazz class annotated with {@link AggregationFunction}
     * @return the only aggregation function the class defines
     * @throws IllegalArgumentException if the class does not define exactly one aggregation function
     * @throws NullPointerException if the class is not annotated with {@link AggregationFunction}
     */
    public InternalAggregationFunction generateAggregationFunction(Class<?> clazz) {
        List<InternalAggregationFunction> aggregations = generateAggregationFunctions(clazz);
        checkArgument(aggregations.size() == 1, "More than one aggregation function found");
        return aggregations.get(0);
    }

    /**
     * Generates all aggregation functions defined by {@code clazz} and returns the first one whose
     * final type and parameter types equal the requested ones.
     *
     * @param clazz class annotated with {@link AggregationFunction}
     * @param returnType required final type of the aggregation
     * @param argumentTypes required parameter types of the aggregation, in order
     * @return the first matching aggregation function
     * @throws IllegalArgumentException if no generated function matches
     * @throws NullPointerException if {@code returnType} or {@code argumentTypes} is null
     */
    public InternalAggregationFunction generateAggregationFunction(Class<?> clazz, Type returnType,
            List<Type> argumentTypes) {
        requireNonNull(returnType, "returnType is null");
        requireNonNull(argumentTypes, "argumentTypes is null");
        for (InternalAggregationFunction aggregation : generateAggregationFunctions(clazz)) {
            if (aggregation.getFinalType().equals(returnType)
                    && aggregation.getParameterTypes().equals(argumentTypes)) {
                return aggregation;
            }
        }
        throw new IllegalArgumentException(
                String.format("No method with return type %s and arguments %s", returnType, argumentTypes));
    }

    /**
     * Generates every aggregation function defined by {@code clazz}.
     *
     * <p>For each state class used by an input function, a state serializer and state factory are
     * generated, and one {@link InternalAggregationFunction} is emitted per combination of output
     * function, input function and name (see {@link #getNames}).
     *
     * @param clazz class annotated with {@link AggregationFunction}
     * @return the generated aggregation functions
     * @throws NullPointerException if the class is not annotated with {@link AggregationFunction}
     * @throws IllegalArgumentException if the class has no usable input or output functions, or an
     *         annotated function is malformed
     */
    public List<InternalAggregationFunction> generateAggregationFunctions(Class<?> clazz) {
        AggregationFunction aggregationAnnotation = clazz.getAnnotation(AggregationFunction.class);
        requireNonNull(aggregationAnnotation, "aggregationAnnotation is null");

        DynamicClassLoader classLoader = new DynamicClassLoader(clazz.getClassLoader());

        ImmutableList.Builder<InternalAggregationFunction> builder = ImmutableList.builder();
        for (Class<?> stateClass : getStateClasses(clazz)) {
            AccumulatorStateSerializer<?> stateSerializer = new StateCompiler().generateStateSerializer(stateClass,
                    classLoader);
            Type intermediateType = stateSerializer.getSerializedType();
            Method intermediateInputFunction = getIntermediateInputFunction(clazz, stateClass);
            Method combineFunction = getCombineFunction(clazz, stateClass);
            AccumulatorStateFactory<?> stateFactory = new StateCompiler().generateStateFactory(stateClass,
                    classLoader);

            // Both lists depend only on the state class, so resolve them once per state class.
            List<Method> outputFunctions = getOutputFunctions(clazz, stateClass);
            List<Method> inputFunctions = getInputFunctions(clazz, stateClass);

            for (Method outputFunction : outputFunctions) {
                for (Method inputFunction : inputFunctions) {
                    // Everything below is independent of the function name, so it is computed once
                    // here instead of once per alias.
                    List<Type> inputTypes = getInputTypes(inputFunction);
                    Type outputType = AggregationUtils.getOutputType(outputFunction, stateSerializer,
                            typeManager);

                    MethodHandle inputHandle = unreflectOrNull(inputFunction);
                    MethodHandle intermediateInputHandle = unreflectOrNull(intermediateInputFunction);
                    MethodHandle combineHandle = unreflectOrNull(combineFunction);
                    MethodHandle outputHandle = unreflectOrNull(outputFunction);

                    List<ParameterMetadata> inputParameters = getParameterMetadata(inputFunction,
                            aggregationAnnotation.approximate());
                    List<ParameterMetadata> intermediateInputParameters = getParameterMetadata(
                            intermediateInputFunction, false);

                    for (String name : getNames(outputFunction, aggregationAnnotation)) {
                        AggregationMetadata metadata = new AggregationMetadata(
                                generateAggregationName(name, outputType, inputTypes),
                                inputParameters,
                                inputHandle, intermediateInputParameters,
                                intermediateInputHandle, combineHandle, outputHandle, stateClass,
                                stateSerializer, stateFactory, outputType, aggregationAnnotation.approximate());

                        AccumulatorFactoryBinder factory = new LazyAccumulatorFactoryBinder(metadata, classLoader);
                        builder.add(new InternalAggregationFunction(name, inputTypes, intermediateType, outputType,
                                aggregationAnnotation.decomposable(), aggregationAnnotation.approximate(),
                                factory));
                    }
                }
            }
        }

        return builder.build();
    }

    /**
     * Converts a reflective method into a method handle, passing {@code null} through unchanged.
     *
     * <p>Kept in this class so the caller-sensitive {@code lookup()} has this class as its lookup class.
     *
     * @throws RuntimeException wrapping the {@link IllegalAccessException} if the method is inaccessible
     */
    @Nullable
    private static MethodHandle unreflectOrNull(@Nullable Method method) {
        if (method == null) {
            return null;
        }
        try {
            return lookup().unreflect(method);
        } catch (IllegalAccessException e) {
            throw Throwables.propagate(e);
        }
    }

    /**
     * Describes the parameters of {@code method}: a STATE entry for parameter 0 followed by one entry
     * per remaining parameter, derived from that parameter's annotations.
     *
     * @param method the function to describe, or null when the aggregation does not define it
     * @param sampleWeightAllowed forwarded to {@code ParameterMetadata.fromAnnotations}
     * @return the parameter metadata, or null if {@code method} is null
     */
    @Nullable
    private List<ParameterMetadata> getParameterMetadata(@Nullable Method method, boolean sampleWeightAllowed) {
        if (method == null) {
            return null;
        }

        ImmutableList.Builder<ParameterMetadata> builder = ImmutableList.builder();
        builder.add(new ParameterMetadata(STATE));

        Annotation[][] annotations = method.getParameterAnnotations();
        // Start at 1 because 0 is the STATE
        for (int i = 1; i < annotations.length; i++) {
            builder.add(fromAnnotations(annotations[i], method.getDeclaringClass() + "." + method.getName(),
                    typeManager, sampleWeightAllowed));
        }
        return builder.build();
    }

    /**
     * Returns the names an aggregation is registered under: the annotation's value followed by its
     * aliases. An {@link AggregationFunction} annotation on the output function, when present,
     * replaces the class-level one.
     */
    private static List<String> getNames(@Nullable Method outputFunction,
            AggregationFunction aggregationAnnotation) {
        AggregationFunction annotation = outputFunction == null
                ? null
                : outputFunction.getAnnotation(AggregationFunction.class);
        AggregationFunction effective = annotation == null ? aggregationAnnotation : annotation;

        return ImmutableList.<String>builder().add(effective.value())
                .addAll(Arrays.asList(effective.alias())).build();
    }

    /**
     * Tells whether the first parameter of {@code method} is exactly {@code stateClass}.
     *
     * @throws IllegalArgumentException if the method declares no parameters; without this check the
     *         lookup of parameter 0 would fail with an uninformative ArrayIndexOutOfBoundsException
     */
    private static boolean acceptsState(Method method, Class<?> annotationClass, Class<?> stateClass) {
        Class<?>[] parameterTypes = method.getParameterTypes();
        checkArgument(parameterTypes.length > 0,
                "%s annotated with %s must take the state as its first parameter", method.getName(),
                annotationClass.getSimpleName());
        return parameterTypes[0] == stateClass;
    }

    /**
     * Returns the first {@link IntermediateInputFunction} method operating on {@code stateClass},
     * or null if there is none.
     */
    @Nullable
    private static Method getIntermediateInputFunction(Class<?> clazz, Class<?> stateClass) {
        for (Method method : findPublicStaticMethodsWithAnnotation(clazz, IntermediateInputFunction.class)) {
            if (acceptsState(method, IntermediateInputFunction.class, stateClass)) {
                return method;
            }
        }
        return null;
    }

    /**
     * Returns the first {@link CombineFunction} method operating on {@code stateClass}, or null if
     * there is none.
     */
    @Nullable
    private static Method getCombineFunction(Class<?> clazz, Class<?> stateClass) {
        for (Method method : findPublicStaticMethodsWithAnnotation(clazz, CombineFunction.class)) {
            if (acceptsState(method, CombineFunction.class, stateClass)) {
                return method;
            }
        }
        return null;
    }

    /**
     * Returns the {@link OutputFunction} methods operating on {@code stateClass}.
     *
     * @throws IllegalArgumentException if there is no such method
     */
    private static List<Method> getOutputFunctions(Class<?> clazz, Class<?> stateClass) {
        // Only include methods that match this state class
        List<Method> methods = findPublicStaticMethodsWithAnnotation(clazz, OutputFunction.class).stream()
                .filter(method -> acceptsState(method, OutputFunction.class, stateClass))
                .collect(toImmutableList());

        checkArgument(!methods.isEmpty(), "Aggregation has no output functions");
        return methods;
    }

    /**
     * Returns the {@link InputFunction} methods operating on {@code stateClass}.
     *
     * @throws IllegalArgumentException if there is no such method
     */
    private static List<Method> getInputFunctions(Class<?> clazz, Class<?> stateClass) {
        // Only include methods that match this state class
        List<Method> inputFunctions = findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class).stream()
                .filter(method -> acceptsState(method, InputFunction.class, stateClass))
                .collect(toImmutableList());

        checkArgument(!inputFunctions.isEmpty(), "Aggregation has no input functions");
        return inputFunctions;
    }

    /**
     * Resolves the SQL type of every {@code @SqlType}-annotated parameter of {@code inputFunction},
     * in declaration order. Parameters without that annotation contribute nothing.
     */
    private List<Type> getInputTypes(Method inputFunction) {
        ImmutableList.Builder<Type> builder = ImmutableList.builder();
        Annotation[][] parameterAnnotations = inputFunction.getParameterAnnotations();
        for (Annotation[] annotations : parameterAnnotations) {
            for (Annotation annotation : annotations) {
                if (annotation instanceof SqlType) {
                    String typeName = ((SqlType) annotation).value();
                    builder.add(typeManager.getType(parseTypeSignature(typeName)));
                }
            }
        }

        return builder.build();
    }

    /**
     * Collects the distinct state classes used by the {@link InputFunction} methods of {@code clazz},
     * i.e. the types of their first parameters.
     *
     * @throws IllegalArgumentException if there are no input functions, an input function has no
     *         parameters, or a first parameter type is not an {@link AccumulatorState}
     */
    private static Set<Class<?>> getStateClasses(Class<?> clazz) {
        ImmutableSet.Builder<Class<?>> builder = ImmutableSet.builder();
        for (Method inputFunction : findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class)) {
            checkArgument(inputFunction.getParameterTypes().length > 0, "Input function has no parameters");
            Class<?> stateClass = inputFunction.getParameterTypes()[0];
            checkArgument(AccumulatorState.class.isAssignableFrom(stateClass),
                    "stateClass is not a subclass of AccumulatorState");
            builder.add(stateClass);
        }
        ImmutableSet<Class<?>> stateClasses = builder.build();
        checkArgument(!stateClasses.isEmpty(), "No input functions found");

        return stateClasses;
    }
}