org.apache.hadoop.hive.llap.AsyncPbRpcProxy.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.hadoop.hive.llap.AsyncPbRpcProxy.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hive.llap;

import java.io.IOException;
import java.security.PrivilegedAction;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import javax.net.SocketFactory;

import org.apache.hadoop.conf.Configuration;
// TODO: LlapNodeId is just a host+port pair; we could make this class more generic.
import org.apache.hadoop.io.retry.RetryPolicies;
import org.apache.hadoop.io.retry.RetryPolicy;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.service.AbstractService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.RemovalListener;
import com.google.common.cache.RemovalNotification;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.protobuf.Message;

public abstract class AsyncPbRpcProxy<ProtocolType, TokenType extends TokenIdentifier> extends AbstractService {
    private static final Logger LOG = LoggerFactory.getLogger(AsyncPbRpcProxy.class);

    private final Cache<String, ProtocolType> hostProxies;
    private final RequestManager requestManager;
    private final RetryPolicy retryPolicy;
    private final SocketFactory socketFactory;
    private final ListeningExecutorService requestManagerExecutor;
    private volatile ListenableFuture<Void> requestManagerFuture;
    private final Token<TokenType> token;
    private final String tokenUser;

    @VisibleForTesting
    public static class RequestManager implements Callable<Void> {
        private final Lock lock = new ReentrantLock();
        private final AtomicBoolean isShutdown = new AtomicBoolean(false);
        private final Condition queueCondition = lock.newCondition();
        private final AtomicBoolean shouldRun = new AtomicBoolean(false);

        private final int maxConcurrentRequestsPerNode;
        private final ListeningExecutorService executor;

        // Tracks new additions via add, while the loop is processing existing ones.
        private final LinkedList<CallableRequest<?, ?>> newRequestList = new LinkedList<>();

        // Tracks existing requests which are cycled through.
        private final LinkedList<CallableRequest<?, ?>> pendingRequests = new LinkedList<>();

        // Tracks requests executing per node
        private final ConcurrentMap<LlapNodeId, AtomicInteger> runningRequests = new ConcurrentHashMap<>();

        // Tracks completed requests pre node
        private final LinkedList<LlapNodeId> completedNodes = new LinkedList<>();

        public RequestManager(int numThreads, int maxPerNode) {
            ExecutorService localExecutor = Executors.newFixedThreadPool(numThreads,
                    new ThreadFactoryBuilder().setNameFormat("TaskCommunicator #%2d").build());
            maxConcurrentRequestsPerNode = maxPerNode;
            executor = MoreExecutors.listeningDecorator(localExecutor);
        }

        @VisibleForTesting
        Set<LlapNodeId> currentLoopDisabledNodes = new HashSet<>();
        @VisibleForTesting
        List<CallableRequest<?, ?>> currentLoopSkippedRequests = new LinkedList<>();

        @Override
        public Void call() {
            // Caches disabled nodes for quicker lookups and ensures a request on a node which was skipped
            // does not go out of order.
            while (!isShutdown.get()) {
                lock.lock();
                try {
                    while (!shouldRun.get()) {
                        queueCondition.await();
                        break; // Break out and try executing.
                    }
                    boolean shouldBreak = process();
                    if (shouldBreak) {
                        break;
                    }
                } catch (InterruptedException e) {
                    handleInterrupt(e);
                    break;
                } finally {
                    lock.unlock();
                }
            }
            LOG.info("CallScheduler loop exiting");
            return null;
        }

        private void handleInterrupt(InterruptedException e) {
            if (isShutdown.get())
                return;
            LOG.warn("RunLoop interrupted without being shutdown first");
            throw new RuntimeException(e);
        }

        /* Add a new request to be executed */
        public void queueRequest(CallableRequest<?, ?> request) {
            synchronized (newRequestList) {
                newRequestList.add(request);
                shouldRun.set(true);
            }
            notifyRunLoop();
        }

        /* Indicates a request has completed on a node */
        public void requestFinished(LlapNodeId nodeId) {
            synchronized (completedNodes) {
                completedNodes.add(nodeId);
                shouldRun.set(true);
            }
            notifyRunLoop();
        }

        public void shutdown() {
            if (!isShutdown.getAndSet(true)) {
                executor.shutdownNow();
                notifyRunLoop();
            }
        }

        @VisibleForTesting
        <T extends Message, U extends Message> void submitToExecutor(CallableRequest<T, U> request,
                LlapNodeId nodeId) {
            ListenableFuture<U> future = executor.submit(request);
            Futures.addCallback(future, new ResponseCallback<U>(request.getCallback(), nodeId, this));
        }

        @VisibleForTesting
        boolean process() throws InterruptedException {
            if (isShutdown.get()) {
                return true;
            }
            currentLoopDisabledNodes.clear();
            currentLoopSkippedRequests.clear();

            // Set to false to block the next loop. This must be called before draining the lists,
            // otherwise an add/completion after draining the lists but before setting it to false,
            // will not trigger a run. May cause one unnecessary run if an add comes in before drain.
            // drain list. add request (setTrue). setFalse needs to be avoided.
            // TODO: why CAS if the result is not checked?
            shouldRun.compareAndSet(true, false);
            // Drain any calls which may have come in during the last execution of the loop.
            drainNewRequestList(); // Locks newRequestList
            drainCompletedNodes(); // Locks completedNodes

            Iterator<CallableRequest<?, ?>> iterator = pendingRequests.iterator();
            while (iterator.hasNext()) {
                CallableRequest<?, ?> request = iterator.next();
                iterator.remove();
                LlapNodeId nodeId;
                try {
                    nodeId = request.getNodeId();
                } catch (InterruptedException e) {
                    throw e;
                } catch (Exception e) {
                    request.getCallback().indicateError(e);
                    continue;
                }

                if (canRunForNode(nodeId, currentLoopDisabledNodes)) {
                    submitToExecutor(request, nodeId);
                } else {
                    currentLoopDisabledNodes.add(nodeId);
                    currentLoopSkippedRequests.add(request);
                }
            }
            // Tried scheduling everything that could be scheduled in this loop.
            pendingRequests.addAll(0, currentLoopSkippedRequests);
            return false;
        }

        private void drainNewRequestList() {
            synchronized (newRequestList) {
                if (!newRequestList.isEmpty()) {
                    pendingRequests.addAll(newRequestList);
                    newRequestList.clear();
                }
            }
        }

        private void drainCompletedNodes() {
            synchronized (completedNodes) {
                if (!completedNodes.isEmpty()) {
                    for (LlapNodeId nodeId : completedNodes) {
                        runningRequests.get(nodeId).decrementAndGet();
                    }
                }
                completedNodes.clear();
            }
        }

        private boolean canRunForNode(LlapNodeId nodeId, Set<LlapNodeId> currentRunDisabledNodes) {
            if (currentRunDisabledNodes.contains(nodeId)) {
                return false;
            } else {
                AtomicInteger count = runningRequests.get(nodeId);
                if (count == null) {
                    count = new AtomicInteger(0);
                    AtomicInteger old = runningRequests.putIfAbsent(nodeId, count);
                    count = old != null ? old : count;
                }
                if (count.incrementAndGet() <= maxConcurrentRequestsPerNode) {
                    return true;
                } else {
                    count.decrementAndGet();
                    return false;
                }
            }
        }

        private void notifyRunLoop() {
            lock.lock();
            try {
                queueCondition.signal();
            } finally {
                lock.unlock();
            }
        }
    }

    @Override
    public void serviceStart() {
        requestManagerFuture = requestManagerExecutor.submit(requestManager);
        Futures.addCallback(requestManagerFuture, new FutureCallback<Void>() {
            @Override
            public void onSuccess(Void result) {
                LOG.info("RequestManager shutdown");
            }

            @Override
            public void onFailure(Throwable t) {
                if (!(t instanceof CancellationException)) {
                    LOG.warn("RequestManager shutdown with error", t);
                }
            }
        });
    }

    @Override
    public void serviceStop() {
        if (requestManagerFuture != null) {
            requestManager.shutdown();
            requestManagerFuture.cancel(true);
        }
        requestManagerExecutor.shutdown();
    }

    protected final void queueRequest(CallableRequest<?, ?> request) {
        requestManager.queueRequest(request);
    }

    private static final class ResponseCallback<TYPE extends Message> implements FutureCallback<TYPE> {

        private final ExecuteRequestCallback<TYPE> callback;
        private final LlapNodeId nodeId;
        private final RequestManager requestManager;

        public ResponseCallback(ExecuteRequestCallback<TYPE> callback, LlapNodeId nodeId,
                RequestManager requestManager) {
            this.callback = callback;
            this.nodeId = nodeId;
            this.requestManager = requestManager;
        }

        @Override
        public void onSuccess(TYPE result) {
            try {
                callback.setResponse(result);
            } finally {
                requestManager.requestFinished(nodeId);
            }
        }

        @Override
        public void onFailure(Throwable t) {
            try {
                callback.indicateError(t);
            } finally {
                requestManager.requestFinished(nodeId);
            }
        }
    }

    @VisibleForTesting
    protected static abstract class CallableRequest<REQUEST extends Message, RESPONSE extends Message>
            implements Callable<RESPONSE> {
        protected final ExecuteRequestCallback<RESPONSE> callback;
        protected final REQUEST request;

        protected CallableRequest(REQUEST request, ExecuteRequestCallback<RESPONSE> callback) {
            this.request = request;
            this.callback = callback;
        }

        public abstract LlapNodeId getNodeId() throws Exception;

        public ExecuteRequestCallback<RESPONSE> getCallback() {
            return callback;
        }

        public abstract RESPONSE call() throws Exception;
    }

    @VisibleForTesting
    protected static abstract class NodeCallableRequest<REQUEST extends Message, RESPONSE extends Message>
            extends CallableRequest<REQUEST, RESPONSE> {
        protected final LlapNodeId nodeId;

        protected NodeCallableRequest(LlapNodeId nodeId, REQUEST request,
                ExecuteRequestCallback<RESPONSE> callback) {
            super(request, callback);
            this.nodeId = nodeId;
        }

        @Override
        public LlapNodeId getNodeId() {
            return nodeId;
        }
    }

    public interface ExecuteRequestCallback<T extends Message> {
        void setResponse(T response);

        void indicateError(Throwable t);
    }

    public AsyncPbRpcProxy(String name, int numThreads, Configuration conf, Token<TokenType> token,
            long connectionTimeoutMs, long retrySleepMs, int expectedNodes, int maxPerNode) {
        super(name);
        // Note: we may make size/etc. configurable later.
        CacheBuilder<String, ProtocolType> cb = CacheBuilder.newBuilder().expireAfterAccess(1, TimeUnit.HOURS)
                .removalListener(new RemovalListener<String, ProtocolType>() {
                    @Override
                    public void onRemoval(RemovalNotification<String, ProtocolType> arg) {
                        if (arg == null)
                            return;
                        shutdownProtocolImpl(arg.getValue());
                    }
                });
        if (expectedNodes > 0) {
            cb.maximumSize(expectedNodes * 2);
        }
        this.hostProxies = cb.build();
        this.socketFactory = NetUtils.getDefaultSocketFactory(conf);
        this.token = token;
        if (token != null) {
            String tokenUser = getTokenUser(token);
            if (tokenUser == null) {
                try {
                    tokenUser = UserGroupInformation.getCurrentUser().getShortUserName();
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
                LOG.warn("Cannot determine token user from the token; using {}", tokenUser);
            }
            this.tokenUser = tokenUser;
        } else {
            this.tokenUser = null;
        }

        this.retryPolicy = RetryPolicies.retryUpToMaximumTimeWithFixedSleep(connectionTimeoutMs, retrySleepMs,
                TimeUnit.MILLISECONDS);

        this.requestManager = new RequestManager(numThreads, maxPerNode);
        ExecutorService localExecutor = Executors.newFixedThreadPool(1,
                new ThreadFactoryBuilder().setNameFormat("RequestManagerExecutor").build());
        this.requestManagerExecutor = MoreExecutors.listeningDecorator(localExecutor);

        LOG.info("Setting up AsyncPbRpcProxy with" + "numThreads=" + numThreads + "retryTime(millis)="
                + connectionTimeoutMs + "retrySleep(millis)=" + retrySleepMs);
    }

    /**
     * @param nodeId Hostname + post.
     * @param nodeToken A custom node token. If not specified, the default token is used.
     * @return the protocol client implementation for the node.
     */
    protected final ProtocolType getProxy(final LlapNodeId nodeId, final Token<TokenType> nodeToken) {
        String hostId = getHostIdentifier(nodeId.getHostname(), nodeId.getPort());
        try {
            return hostProxies.get(hostId, new Callable<ProtocolType>() {
                @Override
                public ProtocolType call() throws Exception {
                    return createProxy(nodeId, nodeToken);
                }
            });
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

    private ProtocolType createProxy(final LlapNodeId nodeId, Token<TokenType> nodeToken) throws IOException {
        if (nodeToken == null && this.token == null) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Creating a client without a token for " + nodeId);
            }
            return createProtocolImpl(getConfig(), nodeId.getHostname(), nodeId.getPort(), null, retryPolicy,
                    socketFactory);
        }
        if (this.token != null && this.tokenUser == null) {
            throw new AssertionError("Invalid internal state from " + this.token);
        }
        // Either the token should be passed in here, or in ctor.
        String tokenUser = this.tokenUser == null ? getTokenUser(nodeToken) : this.tokenUser;
        if (tokenUser == null) {
            tokenUser = UserGroupInformation.getCurrentUser().getShortUserName();
            LOG.warn("Cannot determine token user for UGI; using {}", tokenUser);
        }
        final UserGroupInformation ugi = UserGroupInformation.createRemoteUser(tokenUser);
        // Clone the token as we'd need to set the service to the one we are talking to.
        if (nodeToken == null) {
            nodeToken = new Token<TokenType>(token);
        }
        SecurityUtil.setTokenService(nodeToken,
                NetUtils.createSocketAddrForHost(nodeId.getHostname(), nodeId.getPort()));
        ugi.addToken(nodeToken);
        if (LOG.isDebugEnabled()) {
            LOG.debug("Creating a client for " + nodeId + "; the token is " + nodeToken);
        }
        return ugi.doAs(new PrivilegedAction<ProtocolType>() {
            @Override
            public ProtocolType run() {
                return createProtocolImpl(getConfig(), nodeId.getHostname(), nodeId.getPort(), ugi, retryPolicy,
                        socketFactory);
            }
        });
    }

    private String getHostIdentifier(String hostname, int port) {
        return hostname + ":" + port;
    }

    protected abstract ProtocolType createProtocolImpl(Configuration config, String hostname, int port,
            UserGroupInformation ugi, RetryPolicy retryPolicy, SocketFactory socketFactory);

    protected abstract void shutdownProtocolImpl(ProtocolType proxy);

    protected abstract String getTokenUser(Token<TokenType> token);
}