org.apache.beam.runners.dataflow.worker.windmill.GrpcWindmillServer.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.beam.runners.dataflow.worker.windmill.GrpcWindmillServer.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.dataflow.worker.windmill;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.SequenceInputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils;
import org.apache.beam.runners.dataflow.worker.options.StreamingDataflowWorkerOptions;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ReportStatsRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ReportStatsResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitRequestChunk;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitWorkRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequestExtension;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkResponseChunk;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.sdk.util.BackOffUtils;
import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.sdk.util.Sleeper;
import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.CallCredentials;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.Channel;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.Status;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.StatusRuntimeException;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.auth.MoreCallCredentials;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.inprocess.InProcessChannelBuilder;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.netty.GrpcSslContexts;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.netty.NegotiationType;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.netty.NettyChannelBuilder;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Verify;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.net.HostAndPort;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** gRPC client for communicating with Windmill Service. */
// Very likely real potential for bugs - https://issues.apache.org/jira/browse/BEAM-6562
// Very likely real potential for bugs - https://issues.apache.org/jira/browse/BEAM-6564
@SuppressFBWarnings({ "JLM_JSR166_UTILCONCURRENT_MONITORENTER", "IS2_INCONSISTENT_SYNC" })
public class GrpcWindmillServer extends WindmillServerStub {
    private static final Logger LOG = LoggerFactory.getLogger(GrpcWindmillServer.class);

    // If a connection cannot be established, gRPC will fail fast so this deadline can be relatively
    // high.
    private static final long DEFAULT_UNARY_RPC_DEADLINE_SECONDS = 300;
    // Deadline applied to each underlying streaming RPC (see GrpcGetWorkStream's constructor).
    private static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300;
    // Stream clean close seconds must be set lower than the stream deadline seconds.
    // Used by AbstractWindmillStream.closeAfterDefaultTimeout() as the wait before half-closing.
    private static final long DEFAULT_STREAM_CLEAN_CLOSE_SECONDS = 180;

    // Bounds for the retry backoff built by grpcBackoff(); MAX_BACKOFF is only the default for the
    // mutable maxBackoff field below.
    private static final Duration MIN_BACKOFF = Duration.millis(1);
    private static final Duration MAX_BACKOFF = Duration.standardSeconds(30);
    // Default gRPC streams to 2MB chunks, which has shown to be a large enough chunk size to reduce
    // per-chunk overhead, and small enough that we can still granularly flow-control.
    private static final int COMMIT_STREAM_CHUNK_SIZE = 2 << 20;
    private static final int GET_DATA_STREAM_CHUNK_SIZE = 2 << 20;

    // Counter backing uniqueId().
    private static final AtomicLong nextId = new AtomicLong(0);

    private final StreamingDataflowWorkerOptions options;
    private final int streamingRpcBatchLimit;
    // Async and blocking service stubs. stub() and syncStub() pick one at random when several
    // are present.
    private final List<CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub> stubList = new ArrayList<>();
    private final List<CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1BlockingStub> syncStubList = new ArrayList<>();
    // Only assigned by initializeLocalHost() when streaming engine is disabled. When non-null the
    // unary RPC methods route to it instead of syncStub().
    private WindmillApplianceGrpc.WindmillApplianceBlockingStub syncApplianceStub = null;
    // Deadline for unary RPCs; initializeLocalHost() lowers it to 10 seconds.
    private long unaryDeadlineSeconds = DEFAULT_UNARY_RPC_DEADLINE_SECONDS;
    // Endpoints the current stubs were created for. Left unassigned (null) by the private
    // test-only constructor.
    private ImmutableSet<HostAndPort> endpoints;
    // A stream logs a warning every time its error count is a multiple of this value.
    private int logEveryNStreamFailures = 20;
    // Upper bound handed to grpcBackoff(); initializeLocalHost() lowers it to 500ms.
    private Duration maxBackoff = MAX_BACKOFF;
    // Throttle timers; getAndResetThrottleTime() reports the sum of all three.
    private final ThrottleTimer getWorkThrottleTimer = new ThrottleTimer();
    private final ThrottleTimer getDataThrottleTimer = new ThrottleTimer();
    private final ThrottleTimer commitWorkThrottleTimer = new ThrottleTimer();
    // Chooses among multiple stubs in stub() and syncStub().
    Random rand = new Random();

    // Streams are added in AbstractWindmillStream.startStream() and removed once they finish after
    // a client close; appendSummaryHtml() iterates this set to render the status page.
    private final Set<AbstractWindmillStream<?, ?>> streamRegistry = Collections
            .newSetFromMap(new ConcurrentHashMap<AbstractWindmillStream<?, ?>, Boolean>());

    /**
     * Creates a client configured from the supplied worker options.
     *
     * <p>If a windmill service endpoint is configured it is split on commas, each entry is given
     * the configured service port as its default port, and stubs are created for the resulting
     * set. Otherwise, when streaming engine is disabled and a local windmill hostport is
     * configured, a client for that localhost port is created. If neither applies, no stubs are
     * created here.
     *
     * @param options worker options supplying endpoints, ports and job identity
     * @throws IOException if creating a channel fails
     */
    public GrpcWindmillServer(StreamingDataflowWorkerOptions options) throws IOException {
        this.options = options;
        this.streamingRpcBatchLimit = options.getWindmillServiceStreamingRpcBatchLimit();
        this.endpoints = ImmutableSet.of();

        if (options.getWindmillServiceEndpoint() != null) {
            Set<HostAndPort> parsedEndpoints = new HashSet<>();
            for (String spec : Splitter.on(',').split(options.getWindmillServiceEndpoint())) {
                HostAndPort hostAndPort = HostAndPort.fromString(spec);
                parsedEndpoints.add(hostAndPort.withDefaultPort(options.getWindmillServicePort()));
            }
            initializeWindmillService(parsedEndpoints);
            return;
        }

        if (streamingEngineEnabled() || options.getLocalWindmillHostport() == null) {
            return;
        }

        // The local hostport is split at its last ':' into a prefix and a numeric port.
        String localHostport = options.getLocalWindmillHostport();
        int separator = localHostport.lastIndexOf(':');
        String prefix = localHostport.substring(0, separator);
        assert ("grpc:localhost".equals(prefix));
        int port = Integer.parseInt(options.getLocalWindmillHostport().substring(separator + 1));
        this.endpoints = ImmutableSet.<HostAndPort>of(HostAndPort.fromParts("localhost", port));
        initializeLocalHost(port);
    }

    private GrpcWindmillServer(String name, boolean enableStreamingEngine) {
        this.options = PipelineOptionsFactory.create().as(StreamingDataflowWorkerOptions.class);
        this.streamingRpcBatchLimit = Integer.MAX_VALUE;
        options.setProject("project");
        options.setJobId("job");
        options.setWorkerId("worker");
        if (enableStreamingEngine) {
            List<String> experiments = this.options.getExperiments();
            if (experiments == null) {
                experiments = new ArrayList<>();
            }
            experiments.add(GcpOptions.STREAMING_ENGINE_EXPERIMENT);
            options.setExperiments(experiments);
        }
        this.stubList.add(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel(name)));
    }

    /** Reports whether the worker options have Streaming Engine enabled. */
    private boolean streamingEngineEnabled() {
        final boolean enabled = options.isEnableStreamingEngine();
        return enabled;
    }

    /**
     * Replaces the set of service endpoints, rebuilding all stubs unless the new set equals the
     * current one.
     *
     * @param endpoints the new endpoints; must not be null
     * @throws IOException if creating a channel fails
     */
    @Override
    public synchronized void setWindmillServiceEndpoints(Set<HostAndPort> endpoints) throws IOException {
        Preconditions.checkNotNull(endpoints);
        final ImmutableSet<HostAndPort> previous = this.endpoints;
        if (!endpoints.equals(previous)) {
            // Only a changed endpoint set warrants recreating the stubs.
            LOG.info("Creating a new windmill stub, endpoints: {}", endpoints);
            if (previous != null) {
                LOG.info("Previous windmill stub endpoints: {}", previous);
            }
            initializeWindmillService(endpoints);
        }
    }

    /** Returns true once at least one async service stub has been created. */
    @Override
    public synchronized boolean isReady() {
        final boolean noStubs = stubList.isEmpty();
        return !noStubs;
    }

    /**
     * Configures this client for a windmill on localhost at {@code port}: every stream failure is
     * logged, backoff is capped at 500ms and unary RPCs get a 10 second deadline. With streaming
     * engine enabled, service stubs are added; otherwise the appliance blocking stub is set.
     *
     * @param port localhost port to connect to
     */
    private synchronized void initializeLocalHost(int port) throws IOException {
        this.logEveryNStreamFailures = 1;
        this.maxBackoff = Duration.millis(500);
        // For local testing use a short deadline.
        this.unaryDeadlineSeconds = 10;

        final Channel localChannel = localhostChannel(port);
        if (!streamingEngineEnabled()) {
            this.syncApplianceStub = WindmillApplianceGrpc.newBlockingStub(localChannel);
            return;
        }
        this.stubList.add(CloudWindmillServiceV1Alpha1Grpc.newStub(localChannel));
        this.syncStubList.add(CloudWindmillServiceV1Alpha1Grpc.newBlockingStub(localChannel));
    }

    /**
     * Adapts a vendored {@code RequestMetadataCallback} to the non-vendored {@link
     * com.google.auth.RequestMetadataCallback} interface by forwarding each notification to the
     * wrapped vendored callback. Note that this class should override every method that is not
     * final and not static and call the delegate directly.
     *
     * <p>TODO: Replace this with an auto generated proxy which calls the underlying implementation
     * delegate to reduce maintenance burden.
     */
    private static class VendoredRequestMetadataCallbackAdapter implements com.google.auth.RequestMetadataCallback {
        /** The vendored callback that receives every forwarded notification. */
        private final org.apache.beam.vendor.grpc.v1p21p0.com.google.auth.RequestMetadataCallback vendoredCallback;

        private VendoredRequestMetadataCallbackAdapter(
                org.apache.beam.vendor.grpc.v1p21p0.com.google.auth.RequestMetadataCallback callback) {
            this.vendoredCallback = callback;
        }

        /** Forwards the successfully fetched metadata to the vendored callback. */
        @Override
        public void onSuccess(Map<String, List<String>> metadata) {
            this.vendoredCallback.onSuccess(metadata);
        }

        /** Forwards the failure to the vendored callback. */
        @Override
        public void onFailure(Throwable exception) {
            this.vendoredCallback.onFailure(exception);
        }
    }

    /**
     * Presents a non-vendored {@link com.google.auth.Credentials} instance as the vendored
     * {@code Credentials} type by delegating every call to the wrapped instance. Note that this
     * class should override every method that is not final and not static and call the delegate
     * directly.
     *
     * <p>TODO: Replace this with an auto generated proxy which calls the underlying implementation
     * delegate to reduce maintenance burden.
     */
    private static class VendoredCredentialsAdapter
            extends org.apache.beam.vendor.grpc.v1p21p0.com.google.auth.Credentials {
        /** The wrapped, non-vendored credentials all calls are delegated to. */
        private final com.google.auth.Credentials delegate;

        private VendoredCredentialsAdapter(com.google.auth.Credentials credentials) {
            this.delegate = credentials;
        }

        @Override
        public String getAuthenticationType() {
            return delegate.getAuthenticationType();
        }

        @Override
        public Map<String, List<String>> getRequestMetadata() throws IOException {
            return delegate.getRequestMetadata();
        }

        /** Asynchronous variant: the vendored callback is wrapped so the delegate can invoke it. */
        @Override
        public void getRequestMetadata(final URI uri, Executor executor,
                final org.apache.beam.vendor.grpc.v1p21p0.com.google.auth.RequestMetadataCallback callback) {
            VendoredRequestMetadataCallbackAdapter wrapped = new VendoredRequestMetadataCallbackAdapter(callback);
            delegate.getRequestMetadata(uri, executor, wrapped);
        }

        @Override
        public Map<String, List<String>> getRequestMetadata(URI uri) throws IOException {
            return delegate.getRequestMetadata(uri);
        }

        @Override
        public boolean hasRequestMetadata() {
            return delegate.hasRequestMetadata();
        }

        @Override
        public boolean hasRequestMetadataOnly() {
            return delegate.hasRequestMetadataOnly();
        }

        @Override
        public void refresh() throws IOException {
            delegate.refresh();
        }
    }

    /**
     * Discards all existing service stubs and creates new ones for {@code endpoints}.
     *
     * <p>Endpoints whose host is "localhost" are handed to {@link #initializeLocalHost(int)}.
     * Every other endpoint gets an async and a blocking stub, each over its own newly created
     * remote channel, with call credentials derived from the GCP credential in the options.
     *
     * @param endpoints endpoints to connect to; copied into an immutable set
     * @throws IOException if creating a channel fails
     */
    private synchronized void initializeWindmillService(Set<HostAndPort> endpoints) throws IOException {
        LOG.info("Initializing Streaming Engine GRPC client for endpoints: {}", endpoints);
        this.stubList.clear();
        this.syncStubList.clear();
        this.endpoints = ImmutableSet.<HostAndPort>copyOf(endpoints);

        for (HostAndPort target : this.endpoints) {
            if ("localhost".equals(target.getHost())) {
                initializeLocalHost(target.getPort());
                continue;
            }

            CallCredentials callCredentials = MoreCallCredentials
                    .from(new VendoredCredentialsAdapter(options.getGcpCredential()));

            Channel asyncChannel = remoteChannel(target);
            this.stubList.add(
                    CloudWindmillServiceV1Alpha1Grpc.newStub(asyncChannel).withCallCredentials(callCredentials));

            Channel blockingChannel = remoteChannel(target);
            this.syncStubList.add(CloudWindmillServiceV1Alpha1Grpc.newBlockingStub(blockingChannel)
                    .withCallCredentials(callCredentials));
        }
    }

    /**
     * Creates a server for tests that talks to the in-process gRPC server named {@code name}.
     *
     * @param name in-process server name
     * @param enableStreamingEngine whether the streaming engine experiment is enabled
     */
    @VisibleForTesting
    static GrpcWindmillServer newTestInstance(String name, boolean enableStreamingEngine) {
        final GrpcWindmillServer testServer = new GrpcWindmillServer(name, enableStreamingEngine);
        return testServer;
    }

    /** Builds an in-process channel to the server named {@code name}, using a direct executor. */
    private Channel inProcessChannel(String name) {
        final Channel channel = InProcessChannelBuilder.forName(name).directExecutor().build();
        return channel;
    }

    /**
     * Builds a plaintext channel to localhost at {@code port} with the inbound message size limit
     * raised to {@code Integer.MAX_VALUE}.
     */
    private Channel localhostChannel(int port) {
        NettyChannelBuilder builder = NettyChannelBuilder.forAddress("localhost", port);
        builder = builder.maxInboundMessageSize(Integer.MAX_VALUE);
        builder = builder.negotiationType(NegotiationType.PLAINTEXT);
        return builder.build();
    }

    /**
     * Builds a TLS channel to {@code endpoint} with the inbound message size limit raised to
     * {@code Integer.MAX_VALUE}.
     *
     * @throws IOException if the client SSL context cannot be built
     */
    private Channel remoteChannel(HostAndPort endpoint) throws IOException {
        NettyChannelBuilder builder = NettyChannelBuilder.forAddress(endpoint.getHost(), endpoint.getPort());
        builder = builder.maxInboundMessageSize(Integer.MAX_VALUE);
        builder = builder.negotiationType(NegotiationType.TLS);
        // Set ciphers(null) to not use GCM, which is disabled for Dataflow
        // due to it being horribly slow.
        builder = builder.sslContext(GrpcSslContexts.forClient().ciphers(null).build());
        return builder.build();
    }

    /**
     * Returns an async service stub: the only one if there is exactly one, otherwise one chosen at
     * random.
     *
     * @throws RuntimeException if no stubs have been created yet
     */
    private synchronized CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub() {
        final int available = stubList.size();
        if (available == 0) {
            throw new RuntimeException("windmillServiceEndpoint has not been set");
        }
        // The random source is only consulted when there is an actual choice to make.
        final int chosen = (available == 1) ? 0 : rand.nextInt(available);
        return stubList.get(chosen);
    }

    /**
     * Returns a blocking service stub: the only one if there is exactly one, otherwise one chosen
     * at random.
     *
     * @throws RuntimeException if no blocking stubs have been created yet
     */
    private synchronized CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1BlockingStub syncStub() {
        final int available = syncStubList.size();
        if (available == 0) {
            throw new RuntimeException("windmillServiceEndpoint has not been set");
        }
        // The random source is only consulted when there is an actual choice to make.
        final int chosen = (available == 1) ? 0 : rand.nextInt(available);
        return syncStubList.get(chosen);
    }

    /** Writes an HTML summary with one line per registered stream to {@code writer}. */
    @Override
    public void appendSummaryHtml(PrintWriter writer) {
        writer.write("Active Streams:<br>");
        streamRegistry.forEach(activeStream -> {
            activeStream.appendSummaryHtml(writer);
            writer.write("<br>");
        });
    }

    // Builds a backoff starting at MIN_BACKOFF and capped at the current maxBackoff; intended to
    // retry calls forever, with a maximum sane retry interval.
    private BackOff grpcBackoff() {
        final FluentBackoff configured = FluentBackoff.DEFAULT.withInitialBackoff(MIN_BACKOFF)
                .withMaxBackoff(maxBackoff);
        return configured.backoff();
    }

    /**
     * Invokes {@code function} until it returns without a {@link StatusRuntimeException}, sleeping
     * according to {@link #grpcBackoff()} between attempts. Every 20th failure is logged.
     *
     * @param function the RPC to attempt
     * @return the first successful result
     * @throws WindmillServerStub.RpcException wrapping the last gRPC failure if the backoff is
     *     exhausted, or if sleeping fails or is interrupted (the thread's interrupt flag is
     *     restored in the latter case)
     */
    private <ResponseT> ResponseT callWithBackoff(Supplier<ResponseT> function) {
        final BackOff retryBackoff = grpcBackoff();
        int failures = 0;
        for (;;) {
            try {
                return function.get();
            } catch (StatusRuntimeException rpcFailure) {
                try {
                    failures++;
                    if (failures % 20 == 0) {
                        LOG.warn("Many exceptions calling gRPC. Last exception: {} with status {}", rpcFailure,
                                rpcFailure.getStatus());
                    }
                    boolean keepTrying = BackOffUtils.next(Sleeper.DEFAULT, retryBackoff);
                    if (!keepTrying) {
                        throw new WindmillServerStub.RpcException(rpcFailure);
                    }
                } catch (InterruptedException interrupted) {
                    // Preserve the interrupt for callers before surfacing the RPC failure.
                    Thread.currentThread().interrupt();
                    WindmillServerStub.RpcException wrapped = new WindmillServerStub.RpcException(rpcFailure);
                    wrapped.addSuppressed(interrupted);
                    throw wrapped;
                } catch (IOException sleepFailure) {
                    WindmillServerStub.RpcException wrapped = new WindmillServerStub.RpcException(rpcFailure);
                    wrapped.addSuppressed(sleepFailure);
                    throw wrapped;
                }
            }
        }
    }

    /**
     * Issues a unary GetWork call with retries. Against the appliance stub the request is sent
     * unchanged; against the service the job, project and worker ids from the options are set on
     * a copy of the request first.
     */
    @Override
    public GetWorkResponse getWork(GetWorkRequest request) {
        if (syncApplianceStub != null) {
            return callWithBackoff(() -> syncApplianceStub.withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS)
                    .getWork(request));
        }
        return callWithBackoff(() -> {
            CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1BlockingStub service = syncStub()
                    .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS);
            GetWorkRequest identified = request.toBuilder().setJobId(options.getJobId())
                    .setProjectId(options.getProject()).setWorkerId(options.getWorkerId()).build();
            return service.getWork(identified);
        });
    }

    /**
     * Issues a unary GetData call with retries. Against the appliance stub the request is sent
     * unchanged; against the service the job and project ids from the options are set on a copy of
     * the request first.
     */
    @Override
    public GetDataResponse getData(GetDataRequest request) {
        if (syncApplianceStub != null) {
            return callWithBackoff(() -> syncApplianceStub.withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS)
                    .getData(request));
        }
        return callWithBackoff(() -> {
            CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1BlockingStub service = syncStub()
                    .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS);
            GetDataRequest identified = request.toBuilder().setJobId(options.getJobId())
                    .setProjectId(options.getProject()).build();
            return service.getData(identified);
        });
    }

    /**
     * Issues a unary CommitWork call with retries. Against the appliance stub the request is sent
     * unchanged; against the service the job and project ids from the options are set on a copy of
     * the request first.
     */
    @Override
    public CommitWorkResponse commitWork(CommitWorkRequest request) {
        if (syncApplianceStub != null) {
            return callWithBackoff(() -> syncApplianceStub.withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS)
                    .commitWork(request));
        }
        return callWithBackoff(() -> {
            CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1BlockingStub service = syncStub()
                    .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS);
            CommitWorkRequest identified = request.toBuilder().setJobId(options.getJobId())
                    .setProjectId(options.getProject()).build();
            return service.commitWork(identified);
        });
    }

    /**
     * Opens a streaming GetWork call. The job, project and worker ids from the options are set on
     * a copy of {@code request}, and received work is handed to {@code receiver}.
     */
    @Override
    public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver receiver) {
        final GetWorkRequest identified = GetWorkRequest.newBuilder(request).setJobId(options.getJobId())
                .setProjectId(options.getProject()).setWorkerId(options.getWorkerId()).build();
        return new GrpcGetWorkStream(identified, receiver);
    }

    /** Creates a new streaming GetData connection. */
    @Override
    public GetDataStream getDataStream() {
        final GetDataStream dataStream = new GrpcGetDataStream();
        return dataStream;
    }

    /** Creates a new streaming CommitWork connection. */
    @Override
    public CommitWorkStream commitWorkStream() {
        final CommitWorkStream commitStream = new GrpcCommitWorkStream();
        return commitStream;
    }

    /**
     * Fetches configuration from the appliance stub, with retries.
     *
     * @throws RpcException wrapping an {@link UnsupportedOperationException} when no appliance
     *     stub is configured
     */
    @Override
    public GetConfigResponse getConfig(GetConfigRequest request) {
        if (syncApplianceStub == null) {
            UnsupportedOperationException unsupported = new UnsupportedOperationException(
                    "GetConfig not supported with windmill service.");
            throw new RpcException(unsupported);
        }
        return callWithBackoff(() -> syncApplianceStub.withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS)
                .getConfig(request));
    }

    /**
     * Reports stats through the appliance stub, with retries.
     *
     * @throws RpcException wrapping an {@link UnsupportedOperationException} when no appliance
     *     stub is configured
     */
    @Override
    public ReportStatsResponse reportStats(ReportStatsRequest request) {
        if (syncApplianceStub == null) {
            UnsupportedOperationException unsupported = new UnsupportedOperationException(
                    "ReportStats not supported with windmill service.");
            throw new RpcException(unsupported);
        }
        return callWithBackoff(() -> syncApplianceStub.withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS)
                .reportStats(request));
    }

    /** Returns the sum of the GetWork, GetData and CommitWork throttle times, resetting each. */
    @Override
    public long getAndResetThrottleTime() {
        final long getWorkThrottle = getWorkThrottleTimer.getAndResetThrottleTime();
        final long getDataThrottle = getDataThrottleTimer.getAndResetThrottleTime();
        final long commitWorkThrottle = commitWorkThrottleTimer.getAndResetThrottleTime();
        return getWorkThrottle + getDataThrottle + commitWorkThrottle;
    }

    /** Builds a job header carrying the job, project and worker ids from the options. */
    private JobHeader makeHeader() {
        JobHeader.Builder header = JobHeader.newBuilder();
        header.setJobId(options.getJobId());
        header.setProjectId(options.getProject());
        header.setWorkerId(options.getWorkerId());
        return header.build();
    }

    /** Returns the next value of a process-wide counter, so each call yields a distinct long. */
    private static long uniqueId() {
        return nextId.addAndGet(1L);
    }

    /**
     * Base class for persistent streams connecting to Windmill.
     *
     * <p>This class handles the underlying gRPC StreamObservers, and automatically reconnects the
     * stream if it is broken. Subclasses are responsible for retrying requests that have been lost on
     * a broken stream.
     *
     * <p>Subclasses should override onResponse to handle responses from the server, and onNewStream
     * to perform any work that must be done when a new stream is created, such as sending headers or
     * retrying requests.
     *
     * <p>send and startStream should not be called from onResponse; use executor() instead.
     *
     * <p>Synchronization on this is used to synchronize the gRpc stream state and internal data
     * structures. Since grpc channel operations may block, synchronization on this stream may also
     * block. This is generally not a problem since streams are used in a single-threaded manner.
     * However some accessors used for status page and other debugging need to take care not to
     * require synchronizing on this.
     */
    private abstract class AbstractWindmillStream<RequestT, ResponseT> implements WindmillStream {
        private final StreamObserverFactory streamObserverFactory = StreamObserverFactory.direct();
        private final Function<StreamObserver<ResponseT>, StreamObserver<RequestT>> clientFactory;
        private final Executor executor = Executors.newSingleThreadExecutor();

        // The following should be protected by synchronizing on this, except for
        // the atomics which may be read atomically for status pages.
        private StreamObserver<RequestT> requestObserver;
        private final AtomicLong startTimeMs = new AtomicLong();
        private final AtomicInteger errorCount = new AtomicInteger();
        private final BackOff backoff = grpcBackoff();
        private final AtomicLong sleepUntil = new AtomicLong();
        protected final AtomicBoolean clientClosed = new AtomicBoolean();
        private final CountDownLatch finishLatch = new CountDownLatch(1);

        protected AbstractWindmillStream(
                Function<StreamObserver<ResponseT>, StreamObserver<RequestT>> clientFactory) {
            this.clientFactory = clientFactory;
        }

        /** Called on each response from the server */
        protected abstract void onResponse(ResponseT response);

        /** Called when a new underlying stream to the server has been opened. */
        protected abstract void onNewStream();

        /** Returns whether there are any pending requests that should be retried on a stream break. */
        protected abstract boolean hasPendingRequests();

        /**
         * Called when the stream is throttled due to resource exhausted errors. Will be called for each
         * resource exhausted error not just the first. onResponse() must stop throttling on reciept of
         * the first good message.
         */
        protected abstract void startThrottleTimer();

        /** Send a request to the server. */
        protected final synchronized void send(RequestT request) {
            requestObserver.onNext(request);
        }

        /** Starts the underlying stream. */
        protected final void startStream() {
            // Add the stream to the registry after it has been fully constructed.
            streamRegistry.add(this);
            BackOff backoff = grpcBackoff();
            while (true) {
                try {
                    synchronized (this) {
                        startTimeMs.set(Instant.now().getMillis());
                        requestObserver = streamObserverFactory.from(clientFactory, new ResponseObserver());
                        onNewStream();
                        if (clientClosed.get()) {
                            close();
                        }
                        return;
                    }
                } catch (Exception e) {
                    LOG.error("Failed to create new stream, retrying: ", e);
                    try {
                        long sleep = backoff.nextBackOffMillis();
                        sleepUntil.set(Instant.now().getMillis() + sleep);
                        Thread.sleep(sleep);
                    } catch (InterruptedException i) {
                        // Keep trying to create the stream.
                    } catch (IOException i) {
                        // Ignore.
                    }
                }
            }
        }

        protected final Executor executor() {
            return executor;
        }

        // Care is taken that synchronization on this is unnecessary for all status page information.
        // Blocking sends are made beneath this stream object's lock which could block status page
        // rendering.
        public final void appendSummaryHtml(PrintWriter writer) {
            appendSpecificHtml(writer);
            if (errorCount.get() > 0) {
                writer.format(", %d errors", errorCount.get());
            }
            if (clientClosed.get()) {
                writer.write(", client closed");
            }
            long sleepLeft = sleepUntil.get() - Instant.now().getMillis();
            if (sleepLeft > 0) {
                writer.format(", %dms backoff remaining", sleepLeft);
            }
            writer.format(", current stream is %dms old", Instant.now().getMillis() - startTimeMs.get());
        }

        // Don't require synchronization on stream, see the appendSummaryHtml comment.
        protected abstract void appendSpecificHtml(PrintWriter writer);

        private class ResponseObserver implements StreamObserver<ResponseT> {
            @Override
            public void onNext(ResponseT response) {
                try {
                    backoff.reset();
                } catch (IOException e) {
                    // Ignore.
                }
                onResponse(response);
            }

            @Override
            public void onError(Throwable t) {
                onStreamFinished(t);
            }

            @Override
            public void onCompleted() {
                onStreamFinished(null);
            }

            private void onStreamFinished(@Nullable Throwable t) {
                synchronized (this) {
                    if (clientClosed.get() && !hasPendingRequests()) {
                        streamRegistry.remove(AbstractWindmillStream.this);
                        finishLatch.countDown();
                        return;
                    }
                }
                if (t != null) {
                    Status status = null;
                    if (t instanceof StatusRuntimeException) {
                        status = ((StatusRuntimeException) t).getStatus();
                    }
                    if (errorCount.incrementAndGet() % logEveryNStreamFailures == 0) {
                        LOG.warn("{} streaming Windmill RPC errors for a stream, last was: {} with status {}",
                                errorCount.get(), t.toString(), status);
                    }
                    // If the stream was stopped due to a resource exhausted error then we are throttled.
                    if (status != null && status.getCode() == Status.Code.RESOURCE_EXHAUSTED) {
                        startThrottleTimer();
                    }

                    try {
                        long sleep = backoff.nextBackOffMillis();
                        sleepUntil.set(Instant.now().getMillis() + sleep);
                        Thread.sleep(sleep);
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    } catch (IOException e) {
                        // Ignore.
                    }
                }
                executor.execute(AbstractWindmillStream.this::startStream);
            }
        }

        @Override
        public final synchronized void close() {
            // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream.
            clientClosed.set(true);
            requestObserver.onCompleted();
        }

        /**
         * Blocks until the stream has fully terminated or the given timeout elapses.
         *
         * @param time maximum time to wait
         * @param unit unit of {@code time}
         * @return {@code true} if the finish latch was released within the timeout, {@code false} if
         *     the wait timed out
         * @throws InterruptedException if the calling thread is interrupted while waiting
         */
        @Override
        public final boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException {
            final boolean terminated = finishLatch.await(time, unit);
            return terminated;
        }

        /**
         * Waits up to {@code DEFAULT_STREAM_CLEAN_CLOSE_SECONDS} for the stream to terminate by itself
         * and, if it has not, half-closes it via {@link #close()}.
         *
         * @throws InterruptedException if the calling thread is interrupted while waiting
         */
        @Override
        public final void closeAfterDefaultTimeout() throws InterruptedException {
            final boolean finishedOnItsOwn = finishLatch.await(DEFAULT_STREAM_CLEAN_CLOSE_SECONDS,
                    TimeUnit.SECONDS);
            if (finishedOnItsOwn) {
                return;
            }
            // The stream did not terminate (e.g. because of an error) within the grace period, so
            // request a clean half-close instead.
            close();
        }

        /**
         * Returns an {@link Instant} built from the current value of {@code startTimeMs}.
         *
         * <p>NOTE(review): {@code startTimeMs} is written outside this view; presumably it holds the
         * epoch-millisecond time at which the current underlying stream was started - confirm.
         */
        @Override
        public final Instant startTime() {
            return new Instant(startTimeMs.get());
        }
    }

    /**
     * Streaming GetWork client.
     *
     * <p>Sends the initial {@link GetWorkRequest} whenever a new underlying stream is started,
     * reassembles chunked work items per stream id, hands completed work items to the
     * {@link WorkItemReceiver}, and tops up the item/byte budget with request extensions once
     * either budget has dropped below half of its maximum.
     */
    private class GrpcGetWorkStream
            extends AbstractWindmillStream<StreamingGetWorkRequest, StreamingGetWorkResponseChunk>
            implements GetWorkStream {
        private final GetWorkRequest request;
        private final WorkItemReceiver receiver;
        // Reassembly buffers keyed by the stream id carried on each response chunk.
        private final Map<Long, WorkItemBuffer> buffers = new ConcurrentHashMap<>();
        // Remaining item budget: reset to request.getMaxItems() on every new stream, decremented per
        // completed work item and topped up when an extension is issued.
        private final AtomicLong inflightMessages = new AtomicLong();
        // Remaining byte budget, maintained the same way as inflightMessages.
        private final AtomicLong inflightBytes = new AtomicLong();

        /**
         * Creates the stream and starts it immediately.
         *
         * @param request the GetWork request sent at the start of every underlying stream; its max
         *     items/bytes also define the budgets maintained by this class
         * @param receiver callback invoked with every fully reassembled work item
         */
        private GrpcGetWorkStream(GetWorkRequest request, WorkItemReceiver receiver) {
            super(responseObserver -> stub()
                    .withDeadlineAfter(DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS)
                    .getWorkStream(responseObserver));
            this.request = request;
            this.receiver = receiver;
            startStream();
        }

        /**
         * Resets per-stream state (partial buffers and budgets) and sends the initial request on the
         * newly started stream.
         */
        @Override
        protected synchronized void onNewStream() {
            buffers.clear();
            inflightMessages.set(request.getMaxItems());
            inflightBytes.set(request.getMaxBytes());
            send(StreamingGetWorkRequest.newBuilder().setRequest(request).build());
        }

        /** GetWork never has client requests awaiting a response, so this always returns false. */
        @Override
        protected boolean hasPendingRequests() {
            return false;
        }

        /**
         * Writes a one-line summary of this stream (buffer count and remaining budgets) to the status
         * page writer.
         */
        @Override
        public void appendSpecificHtml(PrintWriter writer) {
            // Number of buffers is same as distinct workers that sent work on this stream.
            // The budgets are longs; print them as such so that values beyond the int range (byte
            // budgets in particular) are not silently truncated.
            writer.format("GetWorkStream: %d buffers, %d inflight messages allowed, %d inflight bytes allowed",
                    buffers.size(), inflightMessages.get(), inflightBytes.get());
        }

        /**
         * Handles one response chunk: appends it to the buffer for its stream id and, when the chunk
         * completes a work item, delivers the item and updates the budgets.
         */
        @Override
        protected void onResponse(StreamingGetWorkResponseChunk chunk) {
            // Any response means the stream is making progress, so it is no longer throttled.
            getWorkThrottleTimer.stop();
            long id = chunk.getStreamId();

            WorkItemBuffer buffer = buffers.computeIfAbsent(id, (Long l) -> new WorkItemBuffer());
            buffer.append(chunk);

            // Zero remaining bytes marks the last chunk of a work item.
            if (chunk.getRemainingBytesForWorkItem() == 0) {
                // Capture the size before runAndReset() clears it.
                long size = buffer.bufferedSize();
                buffer.runAndReset();

                // Record the fact that there are now fewer outstanding messages and bytes on the stream.
                long numInflight = inflightMessages.decrementAndGet();
                long bytesInflight = inflightBytes.addAndGet(-size);

                // If the outstanding items or bytes limit has gotten too low, top both off with a
                // GetWorkExtension.  The goal is to keep the limits relatively close to their maximum
                // values without sending too many extension requests.
                if (numInflight < request.getMaxItems() / 2 || bytesInflight < request.getMaxBytes() / 2) {
                    long moreItems = request.getMaxItems() - numInflight;
                    long moreBytes = request.getMaxBytes() - bytesInflight;
                    inflightMessages.getAndAdd(moreItems);
                    inflightBytes.getAndAdd(moreBytes);
                    final StreamingGetWorkRequest extension = StreamingGetWorkRequest.newBuilder()
                            .setRequestExtension(StreamingGetWorkRequestExtension.newBuilder()
                                    .setMaxItems(moreItems).setMaxBytes(moreBytes))
                            .build();
                    // The extension is sent from the executor rather than from this response callback.
                    executor().execute(() -> {
                        try {
                            send(extension);
                        } catch (IllegalStateException e) {
                            // Stream was closed.
                        }
                    });
                }
            }
        }

        /** Starts the GetWork throttle timer. */
        @Override
        protected void startThrottleTimer() {
            getWorkThrottleTimer.start();
        }

        /**
         * Accumulates the chunks of a single work item together with the computation metadata that
         * accompanies it, until the item is complete and can be handed to the receiver.
         */
        private class WorkItemBuffer {
            private String computation;
            private Instant inputDataWatermark;
            private Instant synchronizedProcessingTime;
            private ByteString data = ByteString.EMPTY;
            private long bufferedSize = 0;

            /** Records the computation id and watermarks carried by a chunk's metadata. */
            private void setMetadata(Windmill.ComputationWorkItemMetadata metadata) {
                this.computation = metadata.getComputationId();
                this.inputDataWatermark = WindmillTimeUtils
                        .windmillToHarnessWatermark(metadata.getInputDataWatermark());
                this.synchronizedProcessingTime = WindmillTimeUtils
                        .windmillToHarnessWatermark(metadata.getDependentRealtimeInputWatermark());
            }

            /**
             * Appends the chunk's serialized bytes to the buffer, first updating the metadata if the
             * chunk carries any.
             */
            public void append(StreamingGetWorkResponseChunk chunk) {
                if (chunk.hasComputationMetadata()) {
                    setMetadata(chunk.getComputationMetadata());
                }

                this.data = data.concat(chunk.getSerializedWorkItem());
                this.bufferedSize += chunk.getSerializedWorkItem().size();
            }

            /** Returns the number of serialized bytes buffered since the last reset. */
            public long bufferedSize() {
                return bufferedSize;
            }

            /**
             * Parses the buffered bytes as a work item and passes it to the receiver, then clears the
             * buffered data. A parse failure is logged and the data is still discarded. The metadata
             * fields are intentionally left in place for subsequent items.
             */
            public void runAndReset() {
                try {
                    receiver.receiveWork(computation, inputDataWatermark, synchronizedProcessingTime,
                            Windmill.WorkItem.parseFrom(data.newInput()));
                } catch (IOException e) {
                    LOG.error("Failed to parse work item from stream: ", e);
                }
                data = ByteString.EMPTY;
                bufferedSize = 0;
            }
        }
    }

    /**
     * Streaming GetData client.
     *
     * <p>Blocking keyed/global data fetches issued by callers are grouped into batches, sent on the
     * stream, and matched to their responses by request id. When the underlying stream is replaced
     * while the client is still open, every outstanding response stream is cancelled so that the
     * blocked caller re-issues its request.
     */
    private class GrpcGetDataStream extends
            AbstractWindmillStream<StreamingGetDataRequest, StreamingGetDataResponse> implements GetDataStream {
        /**
         * A single data fetch waiting to be sent. Exactly one of {@code globalDataRequest} and
         * {@code dataRequest} is non-null, depending on which constructor was used.
         */
        private class QueuedRequest {
            public QueuedRequest(String computation, KeyedGetDataRequest request) {
                this.id = uniqueId();
                this.globalDataRequest = null;
                this.dataRequest = ComputationGetDataRequest.newBuilder().setComputationId(computation)
                        .addRequests(request).build();
                this.byteSize = this.dataRequest.getSerializedSize();
            }

            public QueuedRequest(GlobalDataRequest request) {
                this.id = uniqueId();
                this.globalDataRequest = request;
                this.dataRequest = null;
                this.byteSize = this.globalDataRequest.getSerializedSize();
            }

            final long id;
            final long byteSize;
            final GlobalDataRequest globalDataRequest;
            final ComputationGetDataRequest dataRequest;
            // Replaced with a fresh stream on every attempt in issueRequest().
            AppendableInputStream responseStream = null;
        }

        /**
         * A group of requests that is sent as one message. {@code requests}, {@code byteSize} and
         * {@code finalized} are accessed while holding the {@code batches} monitor; {@code sent} is
         * released once the batch has been handed to the stream.
         */
        private class QueuedBatch {
            public QueuedBatch() {
            }

            final List<QueuedRequest> requests = new ArrayList<>();
            long byteSize = 0;
            boolean finalized = false;
            final CountDownLatch sent = new CountDownLatch(1);
        }

        // Batches that have been created but not yet sent, oldest first.
        private final Deque<QueuedBatch> batches = new ConcurrentLinkedDeque<>();
        // Response streams of requests that have been sent, keyed by request id.
        private final Map<Long, AppendableInputStream> pending = new ConcurrentHashMap<>();

        /** Writes a one-line summary of pending and queued work to the status page writer. */
        @Override
        public void appendSpecificHtml(PrintWriter writer) {
            writer.format("GetDataStream: %d pending on-wire, %d queued batches", pending.size(), batches.size());
        }

        /** Creates the stream and starts it immediately. */
        GrpcGetDataStream() {
            super(responseObserver -> stub()
                    .withDeadlineAfter(DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS)
                    .getDataStream(responseObserver));
            startStream();
        }

        /**
         * Sends the header on the newly started stream and cancels the response streams of all
         * requests that were outstanding on the previous stream, which makes the blocked callers in
         * {@code issueRequest} retry.
         */
        @Override
        protected synchronized void onNewStream() {
            send(StreamingGetDataRequest.newBuilder().setHeader(makeHeader()).build());

            if (clientClosed.get()) {
                // We rely on close only occurring after all methods on the stream have returned.
                // Since the requestKeyedData and requestGlobalData methods are blocking this
                // means there should be no pending requests.
                Verify.verify(!hasPendingRequests());
            } else {
                for (AppendableInputStream responseStream : pending.values()) {
                    responseStream.cancel();
                }
            }
        }

        /** Returns true while any request is on the wire or any batch is still queued. */
        @Override
        protected boolean hasPendingRequests() {
            return !pending.isEmpty() || !batches.isEmpty();
        }

        /**
         * Routes each serialized response in the chunk to the response stream registered for its
         * request id, completing that stream when no bytes remain for the response.
         */
        @Override
        protected void onResponse(StreamingGetDataResponse chunk) {
            Preconditions.checkArgument(chunk.getRequestIdCount() == chunk.getSerializedResponseCount());
            // A partial (multi-chunk) response must be the only response in its message.
            Preconditions
                    .checkArgument(chunk.getRemainingBytesForResponse() == 0 || chunk.getRequestIdCount() == 1);
            getDataThrottleTimer.stop();

            for (int i = 0; i < chunk.getRequestIdCount(); ++i) {
                AppendableInputStream responseStream = pending.get(chunk.getRequestId(i));
                Verify.verify(responseStream != null, "No pending response stream");
                responseStream.append(chunk.getSerializedResponse(i).newInput());
                if (chunk.getRemainingBytesForResponse() == 0) {
                    responseStream.complete();
                }
            }
        }

        /** Starts the GetData throttle timer. */
        @Override
        protected void startThrottleTimer() {
            getDataThrottleTimer.start();
        }

        /** Fetches keyed state for the given computation, blocking until the response is parsed. */
        @Override
        public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataRequest request) {
            return issueRequest(new QueuedRequest(computation, request), KeyedGetDataResponse::parseFrom);
        }

        /** Fetches global data, blocking until the response is parsed. */
        @Override
        public GlobalData requestGlobalData(GlobalDataRequest request) {
            return issueRequest(new QueuedRequest(request), GlobalData::parseFrom);
        }

        /**
         * Sends the given active-work requests on the stream, split into messages that respect both
         * the chunk byte size and the streaming RPC batch limit. No request ids are attached and no
         * response streams are registered for these requests.
         *
         * @param active requests to send, keyed by computation id
         */
        @Override
        public void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active) {
            long builderBytes = 0;
            StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder();
            for (Map.Entry<String, List<KeyedGetDataRequest>> entry : active.entrySet()) {
                for (KeyedGetDataRequest request : entry.getValue()) {
                    // Calculate the bytes with some overhead for proto encoding.
                    long bytes = (long) entry.getKey().length() + request.getSerializedSize() + 10;
                    // Only state requests are ever added to this builder, so the batch limit must be
                    // checked against the state request count; the request id count is always zero
                    // here and would never trigger a flush.
                    if (builderBytes > 0 && (builderBytes + bytes > GET_DATA_STREAM_CHUNK_SIZE
                            || builder.getStateRequestCount() >= streamingRpcBatchLimit)) {
                        send(builder.build());
                        builderBytes = 0;
                        builder.clear();
                    }
                    builderBytes += bytes;
                    builder.addStateRequest(ComputationGetDataRequest.newBuilder().setComputationId(entry.getKey())
                            .addRequests(request));
                }
            }
            if (builderBytes > 0) {
                send(builder.build());
            }
        }

        /**
         * Sends the request and parses its response, retrying with a fresh response stream whenever
         * the response stream is cancelled or parsing fails.
         *
         * @throws RuntimeException wrapping the {@link InterruptedException} if the calling thread is
         *     interrupted while queueing; the interrupt flag is restored first
         */
        private <ResponseT> ResponseT issueRequest(QueuedRequest request, ParseFn<ResponseT> parseFn) {
            while (true) {
                request.responseStream = new AppendableInputStream();
                try {
                    queueRequestAndWait(request);
                    return parseFn.parse(request.responseStream);
                } catch (CancellationException e) {
                    // Retry issuing the request since the response stream was cancelled.
                    continue;
                } catch (IOException e) {
                    LOG.error("Parsing GetData response failed: ", e);
                    continue;
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                } finally {
                    // Always unregister this attempt; a retry registers the new response stream when
                    // its batch is sent.
                    pending.remove(request.id);
                }
            }
        }

        /**
         * Adds the request to the newest open batch, or opens a new batch if there is none or the
         * newest one is finalized or full, and blocks until the batch containing the request has
         * been sent. The caller that opens a batch is the one responsible for sending it.
         *
         * <p>NOTE(review): if the responsible caller is interrupted before sending, the batch stays
         * queued and its latch is never released - confirm whether callers can be interrupted here.
         */
        private void queueRequestAndWait(QueuedRequest request) throws InterruptedException {
            QueuedBatch batch;
            boolean responsibleForSend = false;
            CountDownLatch waitForSendLatch = null;
            synchronized (batches) {
                batch = batches.isEmpty() ? null : batches.getLast();
                if (batch == null || batch.finalized || batch.requests.size() >= streamingRpcBatchLimit
                        || batch.byteSize + request.byteSize > GET_DATA_STREAM_CHUNK_SIZE) {
                    if (batch != null) {
                        // Batches are sent in order, so the new batch must wait for the previous one.
                        waitForSendLatch = batch.sent;
                    }
                    batch = new QueuedBatch();
                    batches.addLast(batch);
                    responsibleForSend = true;
                }
                batch.requests.add(request);
                batch.byteSize += request.byteSize;
            }
            if (responsibleForSend) {
                if (waitForSendLatch == null) {
                    // If there was not a previous batch wait a little while to improve
                    // batching.
                    Thread.sleep(1);
                } else {
                    waitForSendLatch.await();
                }
                // Finalize the batch so that no additional requests will be added.  Leave the batch in the
                // queue so that a subsequent batch will wait for its completion.
                synchronized (batches) {
                    Verify.verify(batch == batches.peekFirst());
                    batch.finalized = true;
                }
                sendBatch(batch.requests);
                synchronized (batches) {
                    Verify.verify(batch == batches.pollFirst());
                }
                // Notify all waiters with requests in this batch as well as the sender
                // of the next batch (if one exists).
                batch.sent.countDown();
            } else {
                // Wait for this batch to be sent before parsing the response.
                batch.sent.await();
            }
        }

        /**
         * Registers the response stream of every request in the batch and sends the batch as a
         * single message, all while holding this stream's monitor.
         */
        private void sendBatch(List<QueuedRequest> requests) {
            StreamingGetDataRequest batchedRequest = flushToBatch(requests);
            synchronized (this) {
                // Synchronization of pending inserts is necessary with send to ensure duplicates are not
                // sent on stream reconnect.
                for (QueuedRequest request : requests) {
                    Verify.verify(pending.put(request.id, request.responseStream) == null);
                }
                try {
                    send(batchedRequest);
                } catch (IllegalStateException e) {
                    // The stream broke before this call went through; onNewStream will retry the fetch.
                }
            }
        }

        /**
         * Builds the message for a batch. The given list is sorted in place so that global data
         * requests come first; request ids are added in that same order.
         */
        private StreamingGetDataRequest flushToBatch(List<QueuedRequest> requests) {
            // Put all global data requests first because there is only a single repeated field for
            // request ids and the initial ids correspond to global data requests if they are present.
            requests.sort((QueuedRequest r1, QueuedRequest r2) -> {
                boolean r1gd = r1.globalDataRequest != null;
                boolean r2gd = r2.globalDataRequest != null;
                return r1gd == r2gd ? 0 : (r1gd ? -1 : 1);
            });
            StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder();
            for (QueuedRequest request : requests) {
                builder.addRequestId(request.id);
                if (request.globalDataRequest == null) {
                    builder.addStateRequest(request.dataRequest);
                } else {
                    builder.addGlobalDataRequest(request.globalDataRequest);
                }
            }
            return builder.build();
        }
    }

    /**
     * Streaming CommitWork client.
     *
     * <p>Commits are collected in a {@code Batcher} until flushed. A flush sends them either as one
     * batched message, as a single message, or - for a single commit larger than the chunk size - as
     * a sequence of chunks. Every sent commit is tracked in {@code pending} until the service
     * acknowledges it, and everything still pending is re-sent whenever a new underlying stream is
     * started.
     */
    private class GrpcCommitWorkStream
            extends AbstractWindmillStream<StreamingCommitWorkRequest, StreamingCommitResponse>
            implements CommitWorkStream {
        /** A commit, the computation it belongs to, and the callback to run with its final status. */
        private class PendingRequest {
            private final String computation;
            private final WorkItemCommitRequest request;
            private final Consumer<CommitStatus> onDone;

            PendingRequest(String computation, WorkItemCommitRequest request, Consumer<CommitStatus> onDone) {
                this.computation = computation;
                this.request = request;
                this.onDone = onDone;
            }

            /** Size used for batching: the serialized commit size plus the computation id length. */
            long getBytes() {
                final long commitBytes = request.getSerializedSize();
                return commitBytes + computation.length();
            }
        }

        // Commits that have been handed to the stream and not yet acknowledged, keyed by request id.
        private final Map<Long, PendingRequest> pending = new ConcurrentHashMap<>();

        /** Collects commits until they are flushed as a unit. */
        private class Batcher {
            long queuedBytes = 0;
            Map<Long, PendingRequest> queue = new HashMap<>();

            /**
             * An empty batch accepts anything; otherwise the batch must be below both the item limit
             * and the chunk byte size after adding the request.
             */
            boolean canAccept(PendingRequest request) {
                if (queue.isEmpty()) {
                    return true;
                }
                if (queue.size() >= streamingRpcBatchLimit) {
                    return false;
                }
                return (request.getBytes() + queuedBytes) < COMMIT_STREAM_CHUNK_SIZE;
            }

            /** Queues the request under the given id; the caller must have checked canAccept. */
            void add(long id, PendingRequest request) {
                assert (canAccept(request));
                queuedBytes += request.getBytes();
                queue.put(id, request);
            }

            /** Sends everything queued (flushInternal empties the queue) and resets the byte count. */
            void flush() {
                flushInternal(queue);
                queuedBytes = 0;
            }
        }

        private final Batcher batcher = new Batcher();

        /** Creates the stream and starts it immediately. */
        GrpcCommitWorkStream() {
            super(responseObserver -> stub()
                    .withDeadlineAfter(DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS)
                    .commitWorkStream(responseObserver));
            startStream();
        }

        /** Writes the number of unacknowledged commits to the status page writer. */
        @Override
        public void appendSpecificHtml(PrintWriter writer) {
            writer.format("CommitWorkStream: %d pending", pending.size());
        }

        /**
         * Sends the header on the newly started stream and then re-sends every commit that is still
         * pending, re-batched through a temporary batcher.
         */
        @Override
        protected synchronized void onNewStream() {
            send(StreamingCommitWorkRequest.newBuilder().setHeader(makeHeader()).build());
            final Batcher resender = new Batcher();
            for (Map.Entry<Long, PendingRequest> unacked : pending.entrySet()) {
                final PendingRequest commit = unacked.getValue();
                if (!resender.canAccept(commit)) {
                    resender.flush();
                }
                resender.add(unacked.getKey(), commit);
            }
            resender.flush();
        }

        /** Returns true while any sent commit is still unacknowledged. */
        @Override
        protected boolean hasPendingRequests() {
            return !pending.isEmpty();
        }

        /**
         * Completes every commit acknowledged by the response. A commit whose index has no explicit
         * status entry is reported as {@code CommitStatus.OK}.
         */
        @Override
        protected void onResponse(StreamingCommitResponse response) {
            commitWorkThrottleTimer.stop();

            for (int index = 0; index < response.getRequestIdCount(); ++index) {
                final long requestId = response.getRequestId(index);
                final PendingRequest finished = pending.remove(requestId);
                if (finished == null) {
                    LOG.error("Got unknown commit request ID: {}", requestId);
                    continue;
                }
                final boolean hasExplicitStatus = index < response.getStatusCount();
                finished.onDone.accept(hasExplicitStatus ? response.getStatus(index) : CommitStatus.OK);
            }
        }

        /** Starts the CommitWork throttle timer. */
        @Override
        protected void startThrottleTimer() {
            commitWorkThrottleTimer.start();
        }

        /**
         * Tries to add the commit to the current batch.
         *
         * @return {@code true} if the commit was queued; {@code false} if the current batch cannot
         *     accept it, in which case nothing is queued
         */
        @Override
        public boolean commitWorkItem(String computation, WorkItemCommitRequest commitRequest,
                Consumer<CommitStatus> onDone) {
            final PendingRequest candidate = new PendingRequest(computation, commitRequest, onDone);
            final boolean accepted = batcher.canAccept(candidate);
            if (accepted) {
                batcher.add(uniqueId(), candidate);
            }
            return accepted;
        }

        /** Sends all commits queued in the current batch. */
        @Override
        public void flush() {
            batcher.flush();
        }

        /**
         * Sends the given commits using the form that fits them - batched, single, or multi-chunk -
         * and then empties the map. Does nothing for an empty map.
         */
        private final void flushInternal(Map<Long, PendingRequest> requests) {
            if (requests.isEmpty()) {
                return;
            }
            if (requests.size() > 1) {
                issueBatchedRequest(requests);
            } else {
                final Map.Entry<Long, PendingRequest> only = requests.entrySet().iterator().next();
                final PendingRequest commit = only.getValue();
                final boolean needsChunking = commit.request.getSerializedSize() > COMMIT_STREAM_CHUNK_SIZE;
                if (needsChunking) {
                    issueMultiChunkRequest(only.getKey(), commit);
                } else {
                    issueSingleRequest(only.getKey(), commit);
                }
            }
            requests.clear();
        }

        /** Registers the commit as pending and sends it as one message containing a single chunk. */
        private void issueSingleRequest(final long id, PendingRequest pendingRequest) {
            final StreamingCommitWorkRequest.Builder messageBuilder = StreamingCommitWorkRequest.newBuilder();
            final StreamingCommitRequestChunk.Builder chunkBuilder = messageBuilder.addCommitChunkBuilder();
            chunkBuilder.setComputationId(pendingRequest.computation);
            chunkBuilder.setRequestId(id);
            chunkBuilder.setShardingKey(pendingRequest.request.getShardingKey());
            chunkBuilder.setSerializedWorkItemCommit(pendingRequest.request.toByteString());
            final StreamingCommitWorkRequest message = messageBuilder.build();
            try {
                synchronized (this) {
                    // Register before sending, under the stream's monitor.
                    pending.put(id, pendingRequest);
                    send(message);
                }
            } catch (IllegalStateException e) {
                // Stream was broken, request will be retried when stream is reopened.
            }
        }

        /**
         * Registers all commits as pending and sends them as one message with one chunk per commit.
         * The computation id is only set on a chunk when it differs from the previous chunk's.
         */
        private void issueBatchedRequest(Map<Long, PendingRequest> requests) {
            final StreamingCommitWorkRequest.Builder messageBuilder = StreamingCommitWorkRequest.newBuilder();
            String previousComputation = null;
            for (Map.Entry<Long, PendingRequest> queued : requests.entrySet()) {
                final PendingRequest commit = queued.getValue();
                final StreamingCommitRequestChunk.Builder chunkBuilder = messageBuilder.addCommitChunkBuilder();
                final boolean sameAsPrevious = previousComputation != null
                        && previousComputation.equals(commit.computation);
                if (!sameAsPrevious) {
                    chunkBuilder.setComputationId(commit.computation);
                    previousComputation = commit.computation;
                }
                chunkBuilder.setRequestId(queued.getKey());
                chunkBuilder.setShardingKey(commit.request.getShardingKey());
                chunkBuilder.setSerializedWorkItemCommit(commit.request.toByteString());
            }
            final StreamingCommitWorkRequest message = messageBuilder.build();
            try {
                synchronized (this) {
                    // Register before sending, under the stream's monitor.
                    pending.putAll(requests);
                    send(message);
                }
            } catch (IllegalStateException e) {
                // Stream was broken, request will be retried when stream is reopened.
            }
        }

        /**
         * Registers the commit as pending and sends its serialized form as consecutive chunks of at
         * most {@code COMMIT_STREAM_CHUNK_SIZE} bytes, one message per chunk. Every chunk except the
         * last carries the number of bytes still to follow. Sending stops at the first chunk that
         * fails because the stream is broken.
         */
        private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest) {
            Preconditions.checkNotNull(pendingRequest.computation);
            final ByteString serializedCommit = pendingRequest.request.toByteString();

            synchronized (this) {
                pending.put(id, pendingRequest);
                for (int start = 0; start < serializedCommit.size(); start += COMMIT_STREAM_CHUNK_SIZE) {
                    final int end = start + COMMIT_STREAM_CHUNK_SIZE;
                    final ByteString piece = serializedCommit.substring(start,
                            Math.min(end, serializedCommit.size()));

                    final StreamingCommitRequestChunk.Builder chunkBuilder = StreamingCommitRequestChunk
                            .newBuilder()
                            .setRequestId(id).setSerializedWorkItemCommit(piece)
                            .setComputationId(pendingRequest.computation)
                            .setShardingKey(pendingRequest.request.getShardingKey());
                    final int remaining = serializedCommit.size() - end;
                    if (remaining > 0) {
                        chunkBuilder.setRemainingBytesForWorkItem(remaining);
                    }

                    final StreamingCommitWorkRequest message = StreamingCommitWorkRequest.newBuilder()
                            .addCommitChunk(chunkBuilder).build();
                    try {
                        send(message);
                    } catch (IllegalStateException e) {
                        // Stream was broken, request will be retried when stream is reopened.
                        break;
                    }
                }
            }
        }
    }

    /**
     * Parses a response of type {@code ResponseT} from an input stream.
     *
     * @param <ResponseT> the type of the parsed response
     */
    @FunctionalInterface
    private interface ParseFn<ResponseT> {
        /**
         * Reads and parses one response from {@code input}.
         *
         * @param input the stream holding the serialized response
         * @return the parsed response
         * @throws IOException if reading or parsing fails
         */
        ResponseT parse(InputStream input) throws IOException;
    }

    /**
     * An {@link InputStream} whose content is supplied piecewise: producers append further streams
     * while a consumer reads, and the consumer blocks until the next piece, a completion signal or
     * a cancellation arrives.
     */
    @SuppressWarnings("JdkObsolete")
    private static class AppendableInputStream extends InputStream {
        // Marker enqueued by complete() and cancel(); recognized by identity, never read as data.
        private static final InputStream POISON_PILL = ByteString.EMPTY.newInput();
        private final AtomicBoolean cancelled = new AtomicBoolean(false);
        private final AtomicBoolean complete = new AtomicBoolean(false);
        // Bounded hand-off queue between producers and the reader.
        private final BlockingDeque<InputStream> queue = new LinkedBlockingDeque<>(10);
        private final InputStream stream = new SequenceInputStream(new QueuedStreams());

        /** Feeds the queued streams, in order, to the wrapping {@link SequenceInputStream}. */
        private class QueuedStreams implements Enumeration<InputStream> {
            // Element fetched but not yet handed out; starts as an empty stream so that the first
            // lookup does not need to consult the queue.
            InputStream current = ByteString.EMPTY.newInput();

            @Override
            public boolean hasMoreElements() {
                if (current == null) {
                    try {
                        current = queue.take();
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        throw new CancellationException();
                    }
                    if (current == POISON_PILL) {
                        if (cancelled.get()) {
                            throw new CancellationException();
                        }
                        if (complete.get()) {
                            return false;
                        }
                        throw new IllegalStateException("Got poison pill but stream is not done.");
                    }
                }
                return true;
            }

            @Override
            public InputStream nextElement() {
                if (!hasMoreElements()) {
                    throw new NoSuchElementException();
                }
                final InputStream handedOut = current;
                current = null;
                return handedOut;
            }
        }

        /** Adds another piece of content at the end, blocking while the queue is full. */
        public synchronized void append(InputStream chunk) {
            try {
                queue.put(chunk);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }

        /**
         * Cancels the stream; subsequent {@code read} and {@code available} calls throw
         * {@link CancellationException}.
         */
        public synchronized void cancel() {
            cancelled.set(true);
            try {
                // Drop whatever is queued and put the marker at the front so that a reader blocked
                // on the queue observes the cancellation as soon as possible.
                queue.clear();
                queue.putFirst(POISON_PILL);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }

        /** Marks the end of the content; nothing more will be appended after this call. */
        public synchronized void complete() {
            complete.set(true);
            try {
                queue.put(POISON_PILL);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }

        /** Throws {@link CancellationException} once {@link #cancel()} has been called. */
        private void throwIfCancelled() {
            if (cancelled.get()) {
                throw new CancellationException();
            }
        }

        @Override
        public int read() throws IOException {
            throwIfCancelled();
            return stream.read();
        }

        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            throwIfCancelled();
            return stream.read(b, off, len);
        }

        @Override
        public int available() throws IOException {
            throwIfCancelled();
            return stream.available();
        }

        @Override
        public void close() throws IOException {
            stream.close();
        }
    }

    /**
     * Stopwatch that accumulates the time spent throttled because of Resource Exhausted errors.
     *
     * <p>One timer is kept per rpc type rather than per stream: if GetWork and CommitWork are both
     * blocked for x the overall throttle time is 2x, whereas two GetWork streams blocked for x
     * contribute only x. All methods are thread safe.
     */
    private static class ThrottleTimer {

        // Value of startTime while no throttling interval is open.
        private static final long NOT_THROTTLED = -1;

        // Millisecond timestamp at which the current throttling interval began, or NOT_THROTTLED.
        private long startTime = NOT_THROTTLED;
        // Throttle time accumulated since the last report. Reporting is delta based, so this is
        // reset every time the value is handed out.
        private long totalTime = 0;

        /** Opens a throttling interval unless one is already open. */
        public synchronized void start() {
            if (throttled()) {
                return; // Already running; keep the original start time.
            }
            startTime = Instant.now().getMillis();
        }

        /** Closes the open throttling interval, if any, and adds its length to the total. */
        public synchronized void stop() {
            if (!throttled()) {
                return; // Nothing to stop.
            }
            totalTime += Instant.now().getMillis() - startTime;
            startTime = NOT_THROTTLED;
        }

        /** Returns whether a throttling interval is currently open. */
        public synchronized boolean throttled() {
            return startTime != NOT_THROTTLED;
        }

        /**
         * Returns the throttle time accumulated since the previous call and resets it to 0. An open
         * interval is closed and immediately reopened so that its elapsed part is included.
         */
        public synchronized long getAndResetThrottleTime() {
            if (throttled()) {
                stop();
                start();
            }
            final long accumulated = totalTime;
            totalTime = 0;
            return accumulated;
        }
    }
}