org.apache.beam.runners.spark.translation.EvaluationContext.java Source code

Introduction

Here is the source code for org.apache.beam.runners.spark.translation.EvaluationContext.java, the class the Apache Beam Spark runner uses to hold translation state (RDDs, views, and aggregator results) while a pipeline executes.

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.beam.runners.spark.translation;

import static com.google.common.base.Preconditions.checkArgument;

import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.beam.runners.spark.EvaluationResult;
import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.sdk.AggregatorRetrievalException;
import org.apache.beam.sdk.AggregatorValues;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.metrics.MetricResults;
import org.apache.beam.sdk.transforms.Aggregator;
import org.apache.beam.sdk.transforms.AppliedPTransform;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.PValue;
import org.apache.spark.api.java.JavaRDDLike;
import org.apache.spark.api.java.JavaSparkContext;
import org.joda.time.Duration;

/**
 * The evaluation context allows us to define how pipeline instructions are translated
 * into Spark operations, and keeps track of the RDDs produced along the way.
 */
public class EvaluationContext implements EvaluationResult {
    private final JavaSparkContext jsc;
    private final Pipeline pipeline;
    private final SparkRuntimeContext runtime;
    private final Map<PValue, RDDHolder<?>> pcollections = new LinkedHashMap<>();
    private final Set<RDDHolder<?>> leafRdds = new LinkedHashSet<>();
    private final Set<PValue> multireads = new LinkedHashSet<>();
    private final Map<PValue, Object> pobjects = new LinkedHashMap<>();
    private final Map<PValue, Iterable<? extends WindowedValue<?>>> pview = new LinkedHashMap<>();
    protected AppliedPTransform<?, ?, ?> currentTransform;

    public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline) {
        this.jsc = jsc;
        this.pipeline = pipeline;
        this.runtime = new SparkRuntimeContext(pipeline, jsc);
    }

    /**
     * Holds an RDD or values for deferred conversion to an RDD if needed. PCollections are
     * sometimes created from a collection of objects (using RDD parallelize) and then
     * only used to create view objects, in which case they do not need to be
     * converted to bytes since they are not transferred across the network until they are
     * broadcast.
     */
    private class RDDHolder<T> {

        private Iterable<WindowedValue<T>> windowedValues;
        private Coder<T> coder;
        private JavaRDDLike<WindowedValue<T>, ?> rdd;

        RDDHolder(Iterable<T> values, Coder<T> coder) {
            this.windowedValues = Iterables.transform(values, WindowingHelpers.<T>windowValueFunction());
            this.coder = coder;
        }

        RDDHolder(JavaRDDLike<WindowedValue<T>, ?> rdd) {
            this.rdd = rdd;
        }

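        /**
         * Lazily creates the RDD on first access: the deferred values are encoded with a
         * value-only windowed-value coder, parallelized on the Spark context, and decoded back.
         */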
        JavaRDDLike<WindowedValue<T>, ?> getRDD() {
            if (rdd == null) {
                WindowedValue.ValueOnlyWindowedValueCoder<T> windowCoder = WindowedValue.getValueOnlyCoder(coder);
                rdd = jsc.parallelize(CoderHelpers.toByteArrays(windowedValues, windowCoder))
                        .map(CoderHelpers.fromByteFunction(windowCoder));
            }
            return rdd;
        }

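        /**
         * Materializes the windowed values on the driver: chooses a windowed-value coder to
         * match the collection's windowing, encodes the RDD to bytes, collects it, and
         * decodes the bytes lazily on iteration.
         */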
        Iterable<WindowedValue<T>> getValues(PCollection<T> pcollection) {
            if (windowedValues == null) {
                WindowFn<?, ?> windowFn = pcollection.getWindowingStrategy().getWindowFn();
                Coder<? extends BoundedWindow> windowCoder = windowFn.windowCoder();
                final WindowedValue.WindowedValueCoder<T> windowedValueCoder;
                if (windowFn instanceof GlobalWindows) {
                    windowedValueCoder = WindowedValue.ValueOnlyWindowedValueCoder.of(pcollection.getCoder());
                } else {
                    windowedValueCoder = WindowedValue.FullWindowedValueCoder.of(pcollection.getCoder(),
                            windowCoder);
                }
                JavaRDDLike<byte[], ?> bytesRDD = rdd.map(CoderHelpers.toByteFunction(windowedValueCoder));
                List<byte[]> clientBytes = bytesRDD.collect();
                windowedValues = Iterables.transform(clientBytes, new Function<byte[], WindowedValue<T>>() {
                    @Override
                    public WindowedValue<T> apply(byte[] bytes) {
                        return CoderHelpers.fromByteArray(bytes, windowedValueCoder);
                    }
                });
            }
            return windowedValues;
        }
    }

    protected JavaSparkContext getSparkContext() {
        return jsc;
    }

    public Pipeline getPipeline() {
        return pipeline;
    }

    protected SparkRuntimeContext getRuntimeContext() {
        return runtime;
    }

    public void setCurrentTransform(AppliedPTransform<?, ?, ?> transform) {
        this.currentTransform = transform;
    }

    protected AppliedPTransform<?, ?, ?> getCurrentTransform() {
        return currentTransform;
    }

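    /** Returns the input of the given transform, which must be the transform currently being evaluated. */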
    protected <T extends PInput> T getInput(PTransform<T, ?> transform) {
        checkArgument(currentTransform != null && currentTransform.getTransform() == transform,
                "can only be called with current transform");
        @SuppressWarnings("unchecked")
        T input = (T) currentTransform.getInput();
        return input;
    }

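    /** Returns the output of the given transform, which must be the transform currently being evaluated. */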
    protected <T extends POutput> T getOutput(PTransform<?, T> transform) {
        checkArgument(currentTransform != null && currentTransform.getTransform() == transform,
                "can only be called with current transform");
        @SuppressWarnings("unchecked")
        T output = (T) currentTransform.getOutput();
        return output;
    }

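    /** Registers the given RDD as the output of the currently evaluated transform. */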
    protected <T> void setOutputRDD(PTransform<?, ?> transform, JavaRDDLike<WindowedValue<T>, ?> rdd) {
        setRDD((PValue) getOutput(transform), rdd);
    }

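    /** Registers in-memory values as the output of the current transform, deferring RDD creation until needed. */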
    protected <T> void setOutputRDDFromValues(PTransform<?, ?> transform, Iterable<T> values, Coder<T> coder) {
        pcollections.put((PValue) getOutput(transform), new RDDHolder<>(values, coder));
    }

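    /** Stores the materialized windowed values backing a view so they can later be broadcast. */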
    public void setPView(PValue view, Iterable<? extends WindowedValue<?>> value) {
        pview.put(view, value);
    }

    protected boolean hasOutputRDD(PTransform<? extends PInput, ?> transform) {
        PValue pvalue = (PValue) getOutput(transform);
        return pcollections.containsKey(pvalue);
    }

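    /**
     * Returns the RDD associated with the given {@link PValue}, removing its holder from the
     * set of leaves. A repeated read marks the RDD as cached so that downstream consumers
     * do not recompute it.
     */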
    public JavaRDDLike<?, ?> getRDD(PValue pvalue) {
        RDDHolder<?> rddHolder = pcollections.get(pvalue);
        JavaRDDLike<?, ?> rdd = rddHolder.getRDD();
        leafRdds.remove(rddHolder);
        if (multireads.contains(pvalue)) {
            // Ensure the RDD is marked as cached
            rdd.rdd().cache();
        } else {
            multireads.add(pvalue);
        }
        return rdd;
    }

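    /** Names and stores the RDD for the given {@link PValue}, tracking it as a leaf of the DAG. */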
    protected <T> void setRDD(PValue pvalue, JavaRDDLike<WindowedValue<T>, ?> rdd) {
        try {
            rdd.rdd().setName(pvalue.getName());
        } catch (IllegalStateException e) {
            // name not set, ignore
        }
        RDDHolder<T> rddHolder = new RDDHolder<>(rdd);
        pcollections.put(pvalue, rddHolder);
        leafRdds.add(rddHolder);
    }

    protected JavaRDDLike<?, ?> getInputRDD(PTransform<? extends PInput, ?> transform) {
        return getRDD((PValue) getInput(transform));
    }

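    /** Returns the windowed values previously registered for the given view via {@link #setPView}. */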
    <T> Iterable<? extends WindowedValue<?>> getPCollectionView(PCollectionView<T> view) {
        return pview.get(view);
    }

    /**
     * Computes the outputs for all RDDs that are leaves in the DAG and do not have any
     * actions (like saving to a file) registered on them (i.e. they are performed for side
     * effects).
     */
    public void computeOutputs() {
        for (RDDHolder<?> rddHolder : leafRdds) {
            JavaRDDLike<?, ?> rdd = rddHolder.getRDD();
            rdd.rdd().cache(); // cache so that any subsequent get() is cheap
            rdd.count(); // force the RDD to be computed
        }
    }

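    /**
     * Returns the single element of the given {@link PValue}, collecting its RDD on first
     * access and memoizing the result.
     */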
    @Override
    public <T> T get(PValue value) {
        if (pobjects.containsKey(value)) {
            @SuppressWarnings("unchecked")
            T result = (T) pobjects.get(value);
            return result;
        }
        if (pcollections.containsKey(value)) {
            JavaRDDLike<?, ?> rdd = pcollections.get(value).getRDD();
            @SuppressWarnings("unchecked")
            T res = (T) Iterables.getOnlyElement(rdd.collect());
            pobjects.put(value, res);
            return res;
        }
        throw new IllegalStateException("Cannot resolve unknown PObject: " + value);
    }

    @Override
    public <T> T getAggregatorValue(String named, Class<T> resultType) {
        return runtime.getAggregatorValue(AccumulatorSingleton.getInstance(jsc), named, resultType);
    }

    @Override
    public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator)
            throws AggregatorRetrievalException {
        return runtime.getAggregatorValues(AccumulatorSingleton.getInstance(jsc), aggregator);
    }

    @Override
    public MetricResults metrics() {
        throw new UnsupportedOperationException("The SparkRunner does not currently support metrics.");
    }

    @Override
    public <T> Iterable<T> get(PCollection<T> pcollection) {
        @SuppressWarnings("unchecked")
        RDDHolder<T> rddHolder = (RDDHolder<T>) pcollections.get(pcollection);
        Iterable<WindowedValue<T>> windowedValues = rddHolder.getValues(pcollection);
        return Iterables.transform(windowedValues, WindowingHelpers.<T>unwindowValueFunction());
    }

    <T> Iterable<WindowedValue<T>> getWindowedValues(PCollection<T> pcollection) {
        @SuppressWarnings("unchecked")
        RDDHolder<T> rddHolder = (RDDHolder<T>) pcollections.get(pcollection);
        return rddHolder.getValues(pcollection);
    }

    @Override
    public void close(boolean gracefully) {
        // graceful stop is used for streaming.
        SparkContextFactory.stopSparkContext(jsc);
    }

    /** The runner is blocking. */
    @Override
    public State getState() {
        return State.DONE;
    }

    @Override
    public State cancel() throws IOException {
        throw new UnsupportedOperationException("Spark runner EvaluationContext does not support cancel.");
    }

    @Override
    public State waitUntilFinish() {
        return waitUntilFinish(Duration.millis(-1));
    }

    @Override
    public State waitUntilFinish(Duration duration) {
        // This is no-op, since Spark runner in batch is blocking.
        // It needs to be updated once SparkRunner supports non-blocking execution:
        // https://issues.apache.org/jira/browse/BEAM-595
        return State.DONE;
    }
}
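
Example

The following is a minimal usage sketch, not part of the original listing, showing how a driver might construct and tear down the context around a translated pipeline. The class name EvaluationContextSketch, the local Spark master string, and the elided translation step are assumptions for illustration; only the EvaluationContext calls mirror the source above.

import org.apache.beam.runners.spark.translation.EvaluationContext;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.spark.api.java.JavaSparkContext;

public class EvaluationContextSketch {
    public static void main(String[] args) {
        // "local[2]" is an assumption for illustration; a real runner derives this from options.
        JavaSparkContext jsc = new JavaSparkContext("local[2]", "beam-example");
        PipelineOptions options = PipelineOptionsFactory.create();
        Pipeline pipeline = Pipeline.create(options);

        EvaluationContext ctxt = new EvaluationContext(jsc, pipeline);

        // ... a pipeline translator would visit each transform here, calling
        // ctxt.setCurrentTransform(...) and the setOutputRDD* methods as it goes ...

        ctxt.computeOutputs(); // force every leaf RDD so side-effecting transforms run
        ctxt.close(false);     // stop the underlying SparkContext (batch, non-graceful)
    }
}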