org.apache.beam.runners.spark.stateful.StateSpecFunctions.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.beam.runners.spark.stateful.StateSpecFunctions.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.beam.runners.spark.stateful;

import com.google.common.base.Stopwatch;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import java.io.Closeable;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.io.EmptyCheckpointMark;
import org.apache.beam.runners.spark.io.MicrobatchSource;
import org.apache.beam.runners.spark.io.SparkUnboundedSource.Metadata;
import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.io.Source;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.metrics.MetricsContainer;
import org.apache.beam.sdk.metrics.MetricsEnvironment;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.spark.streaming.State;
import org.apache.spark.streaming.StateSpec;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;
import scala.Tuple2;
import scala.runtime.AbstractFunction3;

/**
 * Holder for the mapping functions that are handed to
 * {@link org.apache.spark.streaming.StateSpec}.
 */
public class StateSpecFunctions {
    private static final Logger LOG = LoggerFactory.getLogger(StateSpecFunctions.class);

    /**
     * An {@link AbstractFunction3} that is also {@link Serializable}; it adds no behavior of
     * its own.
     */
    private abstract static class SerializableFunction3<T1, T2, T3, T4> extends AbstractFunction3<T1, T2, T3, T4>
            implements Serializable {
    }

    /**
     * Builds the {@link org.apache.spark.streaming.StateSpec} mapping function used to read a
     * micro-batch from an {@link UnboundedSource}.
     *
     * <p>The returned function is invoked with:
     * <ul>
     * <li>Key: the (partitioned) Source to read from; it is cast to {@link MicrobatchSource}.</li>
     * <li>Value: an optional {@link UnboundedSource.CheckpointMark} to start from.</li>
     * <li>State: the encoded CheckpointMark persisted by the previous read, paired with the
     * watermark that read ended on.</li>
     * </ul>
     * It returns every element read in this micro-batch, each one encoded as a globally-windowed
     * {@link WindowedValue}, together with the {@link Metadata} describing the read.
     *
     * <p>Conceptually this is a flatMap over a single-element stream that emits everything read
     * from the {@link UnboundedSource} for the micro-batch. Because a micro-batch is bounded, the
     * UnboundedSource is expected to arrive wrapped in a {@link MicrobatchSource}, which applies
     * bounds in the form of duration and max records (per micro-batch).
     *
     * <p>The {@link StateSpec#function(scala.Function3)} signature, based on scala's native
     * {@link scala.Option}, is used on purpose: the alternative
     * {@link StateSpec#function(org.apache.spark.api.java.function.Function3)} signature relies on
     * Guava's {@link com.google.common.base.Optional}, and Spark's Guava classes pollute the
     * classpath.
     *
     * <p>See also <a href="https://issues.apache.org/jira/browse/SPARK-4819">SPARK-4819</a>.</p>
     *
     * @param runtimeContext    A serializable {@link SparkRuntimeContext}; its pipeline options
     *                          are passed to the source when the reader is obtained.
     * @param stepName          The step name used to look up the {@link MetricsContainer} that is
     *                          put in scope while the source is being read.
     * @param <T>               The type of the input stream elements.
     * @param <CheckpointMarkT> The type of the {@link UnboundedSource.CheckpointMark}.
     * @return The appropriate {@link org.apache.spark.streaming.StateSpec} function.
     */
    public static <T, CheckpointMarkT extends UnboundedSource.CheckpointMark> scala.Function3<Source<T>, scala.Option<CheckpointMarkT>, State<Tuple2<byte[], Instant>>, Tuple2<Iterable<byte[]>, Metadata>> mapSourceFunction(
            final SparkRuntimeContext runtimeContext, final String stepName) {

        return new SerializableFunction3<Source<T>, Option<CheckpointMarkT>, State<Tuple2<byte[], Instant>>, Tuple2<Iterable<byte[]>, Metadata>>() {

            @Override
            public Tuple2<Iterable<byte[]>, Metadata> apply(Source<T> source,
                    scala.Option<CheckpointMarkT> startCheckpointMark, State<Tuple2<byte[], Instant>> state) {

                final MetricsContainerStepMap stepMetrics = new MetricsContainerStepMap();
                final MetricsContainer readerMetrics = stepMetrics.getContainer(stepName);

                // The reader methods of org.apache.beam.sdk.io.Source may report metrics, so the
                // container stays in scope for the whole read.
                try (Closeable ignored = MetricsEnvironment.scopedMetricsContainer(readerMetrics)) {
                    final MicrobatchSource<T, CheckpointMarkT> boundedView = (MicrobatchSource<T, CheckpointMarkT>) source;
                    final Coder<CheckpointMarkT> markCoder = boundedView.getCheckpointMarkCoder();

                    // Low watermark defaults to the minimum timestamp until a previous read says otherwise.
                    Instant lowWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE;

                    // Persisted state wins; otherwise this is the first read and the provided
                    // startCheckpointMark is used. That mark may be EmptyCheckpointMark (the Spark
                    // Java API tries to apply Optional(null)), in which case the reader is asked
                    // to start from a null mark, which the UnboundedSource implementation handles.
                    final CheckpointMarkT resumeFrom;
                    if (state.exists()) {
                        final Tuple2<byte[], Instant> persisted = state.get();
                        // the watermark the previous read ended on is this read's low watermark.
                        lowWatermark = persisted._2();
                        resumeFrom = CoderHelpers.fromByteArray(persisted._1(), markCoder);
                        LOG.info("Continue reading from an existing CheckpointMark.");
                    } else if (startCheckpointMark.isDefined()
                            && !startCheckpointMark.get().equals(EmptyCheckpointMark.get())) {
                        resumeFrom = startCheckpointMark.get();
                        LOG.info("Start reading from a provided CheckpointMark.");
                    } else {
                        resumeFrom = null;
                        LOG.info("No CheckpointMark provided, start reading from default.");
                    }

                    // The timer is started before the reader is obtained, so the reported read
                    // duration includes that step.
                    final Stopwatch timer = Stopwatch.createStarted();
                    final MicrobatchSource.Reader/*<T>*/ reader;
                    try {
                        reader = (MicrobatchSource.Reader) boundedView
                                .getOrCreateReader(runtimeContext.getPipelineOptions(), resumeFrom);
                    } catch (IOException e) {
                        throw new RuntimeException(e);
                    }

                    // Elements are collected already serialized, each wrapped in the global window.
                    final WindowedValue.FullWindowedValueCoder<T> elementCoder = WindowedValue.FullWindowedValueCoder
                            .of(source.getDefaultOutputCoder(), GlobalWindow.Coder.INSTANCE);
                    final List<byte[]> encodedRecords = new ArrayList<>();
                    final Instant highWatermark;
                    final long readMillis;
                    try {
                        for (boolean available = reader.start(); available; available = reader.advance()) {
                            final WindowedValue<T> element = WindowedValue.of((T) reader.getCurrent(),
                                    reader.getCurrentTimestamp(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING);
                            encodedRecords.add(CoderHelpers.toByteArray(element, elementCoder));
                        }

                        // The watermark at the end of the read becomes the high watermark, unless
                        // that would move it backwards relative to the low watermark.
                        final Instant sourceWatermark = reader.getWatermark();
                        if (sourceWatermark.isAfter(lowWatermark)) {
                            highWatermark = sourceWatermark;
                        } else {
                            highWatermark = lowWatermark;
                        }

                        readMillis = timer.stop().elapsed(TimeUnit.MILLISECONDS);

                        LOG.info("Source id {} spent {} millis on reading.", boundedView.getId(), readMillis);

                        // A reader that returns no CheckpointMark is only logged; the state is
                        // still updated, with empty checkpoint bytes.
                        // NOTE(review): those empty bytes are decoded with the checkpoint coder on
                        // the next invocation - confirm the coder tolerates an empty array.
                        @SuppressWarnings("unchecked")
                        final CheckpointMarkT endOfReadMark = (CheckpointMarkT) reader.getCheckpointMark();
                        final byte[] encodedMark;
                        if (endOfReadMark == null) {
                            LOG.info("Skipping checkpoint marking because the reader failed to supply one.");
                            encodedMark = new byte[0];
                        } else {
                            encodedMark = CoderHelpers.toByteArray(endOfReadMark, markCoder);
                        }

                        // Persist the high watermark alongside the checkpoint; the following read
                        // picks it up as its low watermark.
                        state.update(new Tuple2<>(encodedMark, highWatermark));
                    } catch (IOException e) {
                        throw new RuntimeException("Failed to read from reader.", e);
                    }

                    // Hand back a fresh list holding the same encoded records.
                    final Iterable<byte[]> payload = Lists
                            .newArrayList(Iterators.unmodifiableIterator(encodedRecords.iterator()));
                    final Metadata metadata = new Metadata(encodedRecords.size(), lowWatermark, highWatermark,
                            readMillis, stepMetrics);

                    return new Tuple2<>(payload, metadata);

                } catch (IOException e) {
                    // raised when closing the scoped metrics container fails.
                    throw new RuntimeException(e);
                }
            }
        };
    }
}