co.cask.cdap.etl.batch.spark.ETLSparkProgram.java Source code

Java tutorial

Introduction

Here is the source code for the class co.cask.cdap.etl.batch.spark.ETLSparkProgram (file ETLSparkProgram.java).

Source

/*
 * Copyright © 2015-2016 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package co.cask.cdap.etl.batch.spark;

import co.cask.cdap.api.TxRunnable;
import co.cask.cdap.api.data.DatasetContext;
import co.cask.cdap.api.dataset.lib.KeyValue;
import co.cask.cdap.api.metrics.Metrics;
import co.cask.cdap.api.plugin.PluginContext;
import co.cask.cdap.api.spark.JavaSparkExecutionContext;
import co.cask.cdap.api.spark.JavaSparkMain;
import co.cask.cdap.etl.api.Transform;
import co.cask.cdap.etl.api.batch.BatchAggregator;
import co.cask.cdap.etl.api.batch.SparkCompute;
import co.cask.cdap.etl.api.batch.SparkSink;
import co.cask.cdap.etl.batch.BatchPhaseSpec;
import co.cask.cdap.etl.batch.PipelinePluginInstantiator;
import co.cask.cdap.etl.batch.TransformExecutorFactory;
import co.cask.cdap.etl.common.Constants;
import co.cask.cdap.etl.common.PipelinePhase;
import co.cask.cdap.etl.common.SetMultimapCodec;
import co.cask.cdap.etl.common.TransformExecutor;
import co.cask.cdap.etl.common.TransformResponse;
import co.cask.cdap.etl.planner.StageInfo;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.SetMultimap;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import scala.Tuple2;

import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;

/**
 * Spark program to run an ETL pipeline.
 */
public class ETLSparkProgram implements JavaSparkMain, TxRunnable {

    private static final Gson GSON = new GsonBuilder()
            .registerTypeAdapter(SetMultimap.class, new SetMultimapCodec<>()).create();

    // Set in run(JavaSparkExecutionContext); transient so they are skipped if this instance is serialized.
    private transient JavaSparkContext jsc;
    private transient JavaSparkExecutionContext sec;

    /**
     * Entry point of the Spark program. Creates the Spark context and runs the pipeline phase inside a
     * transaction by executing {@link #run(DatasetContext)}.
     *
     * @param sec the CDAP Spark execution context
     * @throws Exception if the transactional execution fails
     */
    @Override
    public void run(final JavaSparkExecutionContext sec) throws Exception {
        this.jsc = new JavaSparkContext();
        this.sec = sec;

        // Execute the whole pipeline in one long transaction. This is because the Spark execution
        // currently shares the same contract and API as the MapReduce one.
        // The API needs to expose DatasetContext, hence it needs to be executed inside a transaction.
        sec.execute(this);
    }

    /**
     * Runs the pipeline phase inside the transaction started by {@link #run(JavaSparkExecutionContext)}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Read the phase spec from the program properties.</li>
     *   <li>Read the source factory, sink factory and partition count from the localized config.</li>
     *   <li>Create the source RDD and transform it.</li>
     *   <li>Write the records tagged for each sink either through its {@link SparkSink} plugin or through the
     *       sink factory.</li>
     * </ol>
     *
     * @param datasetContext dataset context of the enclosing transaction
     * @throws Exception if reading the localized config or running any stage fails
     */
    @Override
    public void run(DatasetContext datasetContext) throws Exception {
        BatchPhaseSpec phaseSpec = GSON.fromJson(sec.getSpecification().getProperty(Constants.PIPELINEID),
                BatchPhaseSpec.class);

        // Only the first aggregator found is used by this phase.
        Set<StageInfo> aggregators = phaseSpec.getPhase().getStagesOfType(BatchAggregator.PLUGIN_TYPE);
        String aggregatorName = aggregators.isEmpty() ? null : aggregators.iterator().next().getName();

        // The localized config is read in order: serialized source factory, serialized sink factory,
        // then the partition count used when grouping for an aggregator.
        SparkBatchSourceFactory sourceFactory;
        SparkBatchSinkFactory sinkFactory;
        int numPartitions;
        try (InputStream is = new FileInputStream(sec.getLocalizationContext().getLocalFile("ETLSpark.config"))) {
            sourceFactory = SparkBatchSourceFactory.deserialize(is);
            sinkFactory = SparkBatchSinkFactory.deserialize(is);
            numPartitions = new DataInputStream(is).readInt();
        }

        JavaPairRDD<Object, Object> rdd = sourceFactory.createRDD(sec, jsc, Object.class, Object.class);
        // Each output tuple is (sink name, record destined for that sink).
        JavaPairRDD<String, Object> resultRDD = doTransform(sec, jsc, datasetContext, phaseSpec, rdd,
                aggregatorName, numPartitions);

        Set<String> sparkSinkNames = new HashSet<>();
        for (StageInfo stageInfo : phaseSpec.getPhase().getStagesOfType(SparkSink.PLUGIN_TYPE)) {
            sparkSinkNames.add(stageInfo.getName());
        }

        for (String sinkName : phaseSpec.getPhase().getSinks()) {
            JavaPairRDD<String, Object> sinkRecords = resultRDD.filter(new SinkNameFilter(sinkName));

            if (sparkSinkNames.contains(sinkName)) {
                SparkSink sparkSink = sec.getPluginContext().newPluginInstance(sinkName);
                sparkSink.run(new BasicSparkExecutionPluginContext(sec, jsc, datasetContext, sinkName),
                        sinkRecords.values());
            } else {
                // Records for non-SparkSink sinks are KeyValue pairs; unpack them into tuples for the sink factory.
                JavaPairRDD<Object, Object> sinkRDD = sinkRecords.flatMapToPair(new KeyValueToTupleFunction());
                sinkFactory.writeFromRDD(sinkRDD, sec, sinkName, Object.class, Object.class);
            }
        }
    }

    /**
     * Runs every non-sink stage of the phase over {@code input}.
     *
     * <p>If the phase has no {@link SparkCompute}, the transforms run in a single pass. When an aggregator is
     * present, they run in two passes around a {@code groupByKey}. Otherwise the phase is split at its single
     * SparkCompute:
     * <ol>
     *   <li>the part of the phase returned by {@code subsetTo(sparkCompute)};</li>
     *   <li>the SparkCompute plugin itself;</li>
     *   <li>the part returned by {@code subsetFrom(sparkCompute)}.</li>
     * </ol>
     *
     * @param aggregatorName name of the aggregator stage, or null if the phase has none
     * @param numPartitions partition count for the aggregator's grouping; a non-positive value uses Spark's
     *                      default partitioning
     * @return cached RDD of (sink name, record) tuples
     * @throws IllegalArgumentException if a SparkCompute phase does not have the shape this method expects
     */
    private JavaPairRDD<String, Object> doTransform(JavaSparkExecutionContext sec, JavaSparkContext jsc,
            DatasetContext datasetContext, BatchPhaseSpec phaseSpec, JavaPairRDD<Object, Object> input,
            @Nullable String aggregatorName, int numPartitions) throws Exception {

        PipelinePhase phase = phaseSpec.getPhase();
        Set<StageInfo> sparkComputes = phase.getStagesOfType(SparkCompute.PLUGIN_TYPE);
        if (sparkComputes.isEmpty()) {
            // if this is not a phase with SparkCompute, do regular transform logic
            if (aggregatorName == null) {
                return input.flatMapToPair(new MapFunction<>(sec, null, null, false)).cache();
            }
            JavaPairRDD<Object, Object> preGroupRDD = input
                    .flatMapToPair(new PreGroupFunction(sec, aggregatorName));
            // Zero or negative is not a usable partition count for groupByKey(int); use Spark's default instead.
            JavaPairRDD<Object, Iterable<Object>> groupedRDD = numPartitions > 0
                    ? preGroupRDD.groupByKey(numPartitions)
                    : preGroupRDD.groupByKey();
            return groupedRDD.flatMapToPair(new MapFunction<Iterable<Object>>(sec, null, aggregatorName, false))
                    .cache();
        }

        // otherwise, special casing for SparkCompute type:

        // there should be no plugins of type Transform, because of how Smart Workflow breaks up the phases
        Set<StageInfo> stagesOfTypeTransform = phase.getStagesOfType(Transform.PLUGIN_TYPE);
        Preconditions.checkArgument(stagesOfTypeTransform.isEmpty(),
                "Found non-empty set of transform plugins when expecting none: %s", stagesOfTypeTransform);

        // Smart Workflow should guarantee that only 1 SparkCompute exists per phase. This can be improved in the
        // future for efficiency.
        Preconditions.checkArgument(sparkComputes.size() == 1, "Expected only 1 SparkCompute: %s", sparkComputes);
        String sparkComputeName = Iterables.getOnlyElement(sparkComputes).getName();

        Set<String> sourceStages = phase.getSources();
        Preconditions.checkArgument(sourceStages.size() == 1, "Expected only 1 source stage: %s", sourceStages);
        String sourceStageName = Iterables.getOnlyElement(sourceStages);

        Set<String> sourceNextStages = phase.getStageOutputs(sourceStageName);
        Preconditions.checkArgument(sourceNextStages.size() == 1, "Expected only 1 stage after source stage: %s",
                sourceNextStages);
        Preconditions.checkArgument(sparkComputeName.equals(Iterables.getOnlyElement(sourceNextStages)),
                "Expected the single stage after the source stage to be the spark compute: %s", sparkComputeName);

        // phase starting from source to SparkCompute
        PipelinePhase sourcePhase = phase.subsetTo(ImmutableSet.of(sparkComputeName));
        String sourcePipelineStr = serializeSubPhase(phaseSpec, sourcePhase);
        JavaPairRDD<String, Object> sourceTransformed = input
                .flatMapToPair(new MapFunction<>(sec, sourcePipelineStr, null, true)).cache();

        SparkCompute sparkCompute = new PipelinePluginInstantiator(sec.getPluginContext(), phaseSpec)
                .newPluginInstance(sparkComputeName);
        JavaRDD<Object> sparkComputed = sparkCompute.transform(
                new BasicSparkExecutionPluginContext(sec, jsc, datasetContext, sparkComputeName),
                sourceTransformed.values());

        // phase starting from SparkCompute to sink(s)
        PipelinePhase sinkPhase = phase.subsetFrom(ImmutableSet.of(sparkComputeName));
        String sinkPipelineStr = serializeSubPhase(phaseSpec, sinkPhase);
        JavaPairRDD<String, Object> sinkTransformedValues = sparkComputed
                .flatMapToPair(new SingleTypeRDDMapFunction(sec, sinkPipelineStr)).cache();
        return sinkTransformedValues;
    }

    /**
     * Serializes a copy of {@code phaseSpec} whose pipeline is replaced by {@code subPhase}.
     * All other spec settings are kept.
     */
    private static String serializeSubPhase(BatchPhaseSpec phaseSpec, PipelinePhase subPhase) {
        return GSON.toJson(new BatchPhaseSpec(phaseSpec.getPhaseName(), subPhase, phaseSpec.getResources(),
                phaseSpec.isStageLoggingEnabled(), phaseSpec.getConnectorDatasets()));
    }

    /**
     * Keeps only the (sink name, record) tuples tagged with a given sink name.
     * Static so the Spark closure does not capture the enclosing program instance.
     */
    private static final class SinkNameFilter implements Function<Tuple2<String, Object>, Boolean> {
        private final String sinkName;

        SinkNameFilter(String sinkName) {
            this.sinkName = sinkName;
        }

        @Override
        public Boolean call(Tuple2<String, Object> taggedRecord) throws Exception {
            return taggedRecord._1().equals(sinkName);
        }
    }

    /**
     * Unpacks the {@link KeyValue} record of a (sink name, record) tuple into a single (key, value) tuple.
     * Static so the Spark closure does not capture the enclosing program instance.
     */
    private static final class KeyValueToTupleFunction
            implements PairFlatMapFunction<Tuple2<String, Object>, Object, Object> {

        @Override
        public Iterable<Tuple2<Object, Object>> call(Tuple2<String, Object> input) throws Exception {
            @SuppressWarnings("unchecked")
            KeyValue<Object, Object> keyValue = (KeyValue<Object, Object>) input._2();
            List<Tuple2<Object, Object>> result = new ArrayList<>(1);
            result.add(new Tuple2<>(keyValue.getKey(), keyValue.getValue()));
            return result;
        }
    }

    /**
     * Base function that knows how to set up a transform executor and run it.
     * Subclasses are responsible for massaging the output of the transform executor into the expected output,
     * and for configuring the transform executor with the right part of the pipeline.
     *
     * @param <IN> type of the input
     * @param <EXECUTOR_IN> type of the executor input
     * @param <KEY_OUT> type of the output key
     * @param <VAL_OUT> type of the output value
     */
    public abstract static class TransformExecutorFunction<IN, EXECUTOR_IN, KEY_OUT, VAL_OUT>
            implements PairFlatMapFunction<IN, KEY_OUT, VAL_OUT> {

        protected final PluginContext pluginContext;
        protected final Metrics metrics;
        protected final long logicalStartTime;
        protected final Map<String, String> runtimeArgs;
        protected final String pipelineStr;
        // Built lazily on the first call(); transient, so it is rebuilt after the function is deserialized.
        private transient TransformExecutor<EXECUTOR_IN> transformExecutor;

        /**
         * @param sec execution context whose plugin context, metrics, logical start time and runtime arguments
         *            are captured
         * @param pipelineStr serialized {@link BatchPhaseSpec} to run, or null to use the program's
         *                    {@link Constants#PIPELINEID} property
         */
        public TransformExecutorFunction(JavaSparkExecutionContext sec, @Nullable String pipelineStr) {
            this.pluginContext = sec.getPluginContext();
            this.metrics = sec.getMetrics();
            this.logicalStartTime = sec.getLogicalStartTime();
            this.runtimeArgs = sec.getRuntimeArguments();
            this.pipelineStr = pipelineStr != null ? pipelineStr
                    : sec.getSpecification().getProperty(Constants.PIPELINEID);
        }

        /**
         * Runs one input through the transform executor and returns the converted output.
         * The executor is created on the first invocation.
         */
        @Override
        public Iterable<Tuple2<KEY_OUT, VAL_OUT>> call(IN input) throws Exception {
            if (transformExecutor == null) {
                // TODO: There is no way to call destroy() method on Transform
                // In fact, we can structure transform in a way that it doesn't need destroy
                // All current usage of destroy() in transform is actually for Source/Sink, which is actually
                // better do it in prepareRun and onRunFinish, which happen outside of the Job execution (true for both
                // Spark and MapReduce).
                BatchPhaseSpec phaseSpec = GSON.fromJson(pipelineStr, BatchPhaseSpec.class);
                PipelinePluginInstantiator pluginInstantiator = new PipelinePluginInstantiator(pluginContext,
                        phaseSpec);
                transformExecutor = initialize(phaseSpec, pluginInstantiator);
            }
            TransformResponse response = transformExecutor.runOneIteration(computeInputForExecutor(input));
            // getOutput copies the results into a new collection, so the emitter can be reset afterwards.
            Iterable<Tuple2<KEY_OUT, VAL_OUT>> output = getOutput(response);
            transformExecutor.resetEmitter();
            return output;
        }

        /** Converts the executor's per-sink results into output tuples. */
        protected abstract Iterable<Tuple2<KEY_OUT, VAL_OUT>> getOutput(TransformResponse transformResponse);

        /** Creates the transform executor for the relevant part of the pipeline. */
        protected abstract TransformExecutor<EXECUTOR_IN> initialize(BatchPhaseSpec phaseSpec,
                PipelinePluginInstantiator pluginInstantiator) throws Exception;

        /** Converts one Spark input element into the executor's input type. */
        protected abstract EXECUTOR_IN computeInputForExecutor(IN input);
    }

    /**
     * Performs all transforms before an aggregator plugin. Outputs tuples whose keys are the group key and values
     * are the group values that result by calling the aggregator's groupBy method.
     */
    public static final class PreGroupFunction
            extends TransformExecutorFunction<Tuple2<Object, Object>, KeyValue<Object, Object>, Object, Object> {
        private final String aggregatorName;

        /**
         * @param sec execution context
         * @param aggregatorName name of the aggregator stage; required, because the pipeline is cut at this stage
         * @throws NullPointerException if {@code aggregatorName} is null
         */
        public PreGroupFunction(JavaSparkExecutionContext sec, String aggregatorName) {
            super(sec, null);
            // Fail fast on the driver instead of inside ImmutableSet.of(null) on an executor.
            this.aggregatorName = Preconditions.checkNotNull(aggregatorName, "aggregatorName must not be null");
        }

        @Override
        protected Iterable<Tuple2<Object, Object>> getOutput(TransformResponse transformResponse) {
            List<Tuple2<Object, Object>> result = new ArrayList<>();
            for (Collection<Object> sinkOutputs : transformResponse.getSinksResults().values()) {
                for (Object output : sinkOutputs) {
                    @SuppressWarnings("unchecked")
                    Tuple2<Object, Object> groupTuple = (Tuple2<Object, Object>) output;
                    result.add(groupTuple);
                }
            }
            return result;
        }

        @Override
        protected TransformExecutor<KeyValue<Object, Object>> initialize(BatchPhaseSpec phaseSpec,
                PipelinePluginInstantiator pluginInstantiator) throws Exception {

            TransformExecutorFactory<KeyValue<Object, Object>> transformExecutorFactory =
                    new SparkTransformExecutorFactory<>(pluginContext, pluginInstantiator, metrics,
                            logicalStartTime, runtimeArgs, true);
            PipelinePhase pipelinePhase = phaseSpec.getPhase().subsetTo(ImmutableSet.of(aggregatorName));
            return transformExecutorFactory.create(pipelinePhase);
        }

        @Override
        protected KeyValue<Object, Object> computeInputForExecutor(Tuple2<Object, Object> input) {
            return new KeyValue<>(input._1(), input._2());
        }
    }

    /**
     * Performs all transforms that happen after an aggregator, or if there is no aggregator at all.
     * Outputs tuples whose first item is the name of the sink that is being written to, and whose second item
     * is the record that should be written to that sink.
     *
     * @param <T> type of the input value (for example, the grouped values after an aggregator)
     */
    public static final class MapFunction<T>
            extends SingleTypeRDDMapFunction<Tuple2<Object, T>, KeyValue<Object, T>> {

        @Nullable
        private final String aggregatorName;
        private final boolean isBeforeBreak;

        /**
         * @param sec execution context
         * @param pipelineStr serialized {@link BatchPhaseSpec} to run, or null to use the program's pipeline spec
         * @param aggregatorName if non-null, only the part of the pipeline from this aggregator onward is run
         * @param isBeforeBreak forwarded to {@code SparkTransformExecutorFactory}; doTransform sets it to true only
         *                      for the stages that feed a SparkCompute
         */
        public MapFunction(JavaSparkExecutionContext sec, @Nullable String pipelineStr,
                @Nullable String aggregatorName, boolean isBeforeBreak) {
            super(sec, pipelineStr);
            this.aggregatorName = aggregatorName;
            this.isBeforeBreak = isBeforeBreak;
        }

        @Override
        protected TransformExecutor<KeyValue<Object, T>> initialize(BatchPhaseSpec phaseSpec,
                PipelinePluginInstantiator pluginInstantiator) throws Exception {
            TransformExecutorFactory<KeyValue<Object, T>> transformExecutorFactory =
                    new SparkTransformExecutorFactory<>(pluginContext, pluginInstantiator, metrics,
                            logicalStartTime, runtimeArgs, isBeforeBreak);

            PipelinePhase pipelinePhase = phaseSpec.getPhase();
            if (aggregatorName != null) {
                pipelinePhase = pipelinePhase.subsetFrom(ImmutableSet.of(aggregatorName));
            }

            return transformExecutorFactory.create(pipelinePhase);
        }

        @Override
        protected KeyValue<Object, T> computeInputForExecutor(Tuple2<Object, T> input) {
            return new KeyValue<>(input._1(), input._2());
        }
    }

    /**
     * Used for the transform after a SparkCompute. In contrast, MapFunction only operates on a JavaPairRDD and
     * translates each Tuple into a KeyValue. This function does not do that translation; it sends each RDD
     * element directly to the TransformExecutor.
     * This allows operations on a JavaRDD of a single type. It handles no aggregation functionality, because it
     * should not be used in a phase with aggregations.
     *
     * @param <IN> type of the input
     * @param <EXECUTOR_IN> type of the input to the executor
     */
    public static class SingleTypeRDDMapFunction<IN, EXECUTOR_IN>
            extends TransformExecutorFunction<IN, EXECUTOR_IN, String, Object> {

        /**
         * @param sec execution context
         * @param pipelineStr serialized {@link BatchPhaseSpec} to run, or null to use the program's pipeline spec
         */
        public SingleTypeRDDMapFunction(JavaSparkExecutionContext sec, @Nullable String pipelineStr) {
            super(sec, pipelineStr);
        }

        /** Flattens the per-sink results into (sink name, record) tuples. */
        @Override
        protected Iterable<Tuple2<String, Object>> getOutput(TransformResponse transformResponse) {
            List<Tuple2<String, Object>> result = new ArrayList<>();
            for (Map.Entry<String, Collection<Object>> sinkEntry : transformResponse.getSinksResults().entrySet()) {
                String sinkName = sinkEntry.getKey();
                for (Object outputRecord : sinkEntry.getValue()) {
                    result.add(new Tuple2<>(sinkName, outputRecord));
                }
            }
            return result;
        }

        @Override
        protected TransformExecutor<EXECUTOR_IN> initialize(BatchPhaseSpec phaseSpec,
                PipelinePluginInstantiator pluginInstantiator) throws Exception {

            TransformExecutorFactory<EXECUTOR_IN> transformExecutorFactory = new SparkTransformExecutorFactory<>(
                    pluginContext, pluginInstantiator, metrics, logicalStartTime, runtimeArgs, false);

            PipelinePhase pipelinePhase = phaseSpec.getPhase();
            return transformExecutorFactory.create(pipelinePhase);
        }

        /** By default, IN is assumed to be the same as EXECUTOR_IN and is passed through unchanged. */
        @Override
        @SuppressWarnings("unchecked")
        protected EXECUTOR_IN computeInputForExecutor(IN input) {
            return (EXECUTOR_IN) input;
        }
    }
}