Java tutorial
/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.alibaba.wasp.ipc; import com.alibaba.wasp.FConstants; import com.alibaba.wasp.ipc.NettyTransportCodec.NettyDataPack; import com.alibaba.wasp.ipc.NettyTransportCodec.NettyFrameDecoder; import com.alibaba.wasp.ipc.NettyTransportCodec.NettyFrameEncoder; import com.alibaba.wasp.protobuf.generated.RPCProtos; import com.alibaba.wasp.protobuf.generated.RPCProtos.*; import com.alibaba.wasp.protobuf.generated.RPCProtos.RpcResponseHeader.Status; import com.alibaba.wasp.protobuf.generated.Tracing.RPCTInfo; import com.alibaba.wasp.util.ByteBufferInputStream; import com.alibaba.wasp.util.ByteBufferOutputStream; import com.google.protobuf.Message; import com.google.protobuf.Message.Builder; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.ipc.RemoteException; import org.cloudera.htrace.Span; import org.cloudera.htrace.Trace; import org.jboss.netty.bootstrap.ClientBootstrap; import org.jboss.netty.channel.*; import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory; import org.jboss.netty.handler.timeout.ReadTimeoutHandler; import org.jboss.netty.handler.timeout.WriteTimeoutHandler; import org.jboss.netty.util.HashedWheelTimer; import org.jboss.netty.util.Timer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.EOFException; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.ReentrantReadWriteLock; /** * A Netty-based {@link Transceiver} implementation. */ public class NettyTransceiver extends Transceiver { public static final String NETTY_CONNECT_TIMEOUT_OPTION = "connectTimeoutMillis"; public static final String NETTY_TCP_NODELAY_OPTION = "tcpNoDelay"; public static final boolean DEFAULT_TCP_NODELAY_VALUE = true; private static final Logger LOG = LoggerFactory.getLogger(NettyTransceiver.class.getName()); private final AtomicInteger serialGenerator = new AtomicInteger(0); private final Map<Integer, Callback<List<ByteBuffer>>> requests = new ConcurrentHashMap<Integer, Callback<List<ByteBuffer>>>(); private final ChannelFactory channelFactory; private final long connectTimeoutMillis; private final ClientBootstrap bootstrap; private final InetSocketAddress remoteAddr; private boolean connectionEstablished = false; private Object connected = new Object(); private Map<String, Object> nettyClientBootstrapOptions; private int refCount = 1; private Configuration conf; private Timer timer; /** * Read lock must be acquired whenever using non-final state. Write lock must * be acquired whenever modifying state. */ private final ReentrantReadWriteLock stateLock = new ReentrantReadWriteLock(); private Channel channel; // Synchronized on stateLock NettyTransceiver() { channelFactory = null; connectTimeoutMillis = 0L; bootstrap = null; remoteAddr = null; } /** * Creates a NettyTransceiver, and attempts to connect to the given address. * {@link #DEFAULT_CONNECTION_TIMEOUT_MILLIS} is used for the connection * timeout. * * @param addr * the address to connect to. * @throws java.io.IOException * if an error occurs connecting to the given address. */ public NettyTransceiver(InetSocketAddress addr) throws IOException { this(addr, FConstants.DEFAULT_CONNECTION_TIMEOUT_MILLIS); } /** * Creates a NettyTransceiver, and attempts to connect to the given address. * * @param addr * the address to connect to. * @param connectTimeoutMillis * maximum amount of time to wait for connection establishment in * milliseconds, or null to use * {@link #DEFAULT_CONNECTION_TIMEOUT_MILLIS}. * @throws java.io.IOException * if an error occurs connecting to the given address. */ public NettyTransceiver(InetSocketAddress addr, Long connectTimeoutMillis) throws IOException { this(addr, new NioClientSocketChannelFactory( Executors.newCachedThreadPool(new NettyTransceiverThreadFactory( "Wasp " + NettyTransceiver.class.getSimpleName() + " Boss")), Executors.newCachedThreadPool(new NettyTransceiverThreadFactory( "Wasp " + NettyTransceiver.class.getSimpleName() + " I/O Worker"))), connectTimeoutMillis); } /** * Creates a NettyTransceiver, and attempts to connect to the given address. * {@link #DEFAULT_CONNECTION_TIMEOUT_MILLIS} is used for the connection * timeout. * * @param addr * the address to connect to. * @param channelFactory * the factory to use to create a new Netty Channel. * @throws java.io.IOException * if an error occurs connecting to the given address. */ public NettyTransceiver(InetSocketAddress addr, ChannelFactory channelFactory) throws IOException { this(addr, channelFactory, buildDefaultBootstrapOptions(null)); } /** * Creates a NettyTransceiver, and attempts to connect to the given address. * * @param addr * the address to connect to. * @param channelFactory * the factory to use to create a new Netty Channel. * @param connectTimeoutMillis * maximum amount of time to wait for connection establishment in * milliseconds, or null to use * {@link #DEFAULT_CONNECTION_TIMEOUT_MILLIS}. * @throws java.io.IOException * if an error occurs connecting to the given address. */ public NettyTransceiver(InetSocketAddress addr, ChannelFactory channelFactory, Long connectTimeoutMillis) throws IOException { this(addr, channelFactory, buildDefaultBootstrapOptions(connectTimeoutMillis)); } /** * Creates a NettyTransceiver, and attempts to connect to the given address. * It is strongly recommended that the {@link #NETTY_CONNECT_TIMEOUT_OPTION} * option be set to a reasonable timeout value (a Long value in milliseconds) * to prevent connect/disconnect attempts from hanging indefinitely. It is * also recommended that the {@link #NETTY_TCP_NODELAY_OPTION} option be set * to true to minimize RPC latency. * * @param addr * the address to connect to. * @param channelFactory * the factory to use to create a new Netty Channel. * @param nettyClientBootstrapOptions * map of Netty ClientBootstrap options to use. * @throws java.io.IOException * if an error occurs connecting to the given address. */ public NettyTransceiver(InetSocketAddress addr, ChannelFactory channelFactory, Map<String, Object> nettyClientBootstrapOptions) throws IOException { if (channelFactory == null) { throw new NullPointerException("channelFactory is null"); } // Set up. this.channelFactory = channelFactory; this.connectTimeoutMillis = (Long) nettyClientBootstrapOptions.get(NETTY_CONNECT_TIMEOUT_OPTION); bootstrap = new ClientBootstrap(channelFactory); remoteAddr = addr; this.nettyClientBootstrapOptions = nettyClientBootstrapOptions; } public synchronized void connect() throws IOException { if (connectionEstablished) { return; } // Configure the event pipeline factory. bootstrap.setPipelineFactory(new ChannelPipelineFactory() { @Override public ChannelPipeline getPipeline() throws Exception { ChannelPipeline p = Channels.pipeline(); timer = new HashedWheelTimer(); p.addLast("frameDecoder", new NettyFrameDecoder()); p.addLast("frameEncoder", new NettyFrameEncoder()); p.addLast("handler", new NettyClientWaspHandler()); p.addLast("readTimeout", new ReadTimeoutHandler(timer, NettyTransceiver.this.conf.getInt( FConstants.CONNECTION_READ_TIMEOUT_SEC, FConstants.DEFAULT_CONNECTION_READ_TIMEOUT_SEC))); p.addLast("writeTimeout", new WriteTimeoutHandler(timer, NettyTransceiver.this.conf.getInt( FConstants.CONNECTION_WRITE_TIMEOUT_SEC, FConstants.DEFAULT_CONNECTION_WRITE_TIMEOUT_SEC))); return p; } }); if (nettyClientBootstrapOptions != null) { LOG.debug("Using Netty bootstrap options: " + nettyClientBootstrapOptions); bootstrap.setOptions(nettyClientBootstrapOptions); } // Make a new connection. stateLock.readLock().lock(); try { getChannel(); } finally { stateLock.readLock().unlock(); } connectionEstablished = true; } /** * Creates the default options map for the Netty ClientBootstrap. * * @param connectTimeoutMillis * connection timeout in milliseconds, or null if no timeout is * desired. * @return the map of Netty bootstrap options. */ private static Map<String, Object> buildDefaultBootstrapOptions(Long connectTimeoutMillis) { Map<String, Object> options = new HashMap<String, Object>(2); options.put(NETTY_TCP_NODELAY_OPTION, DEFAULT_TCP_NODELAY_VALUE); options.put(NETTY_CONNECT_TIMEOUT_OPTION, connectTimeoutMillis == null ? FConstants.DEFAULT_CONNECTION_TIMEOUT_MILLIS : connectTimeoutMillis); return options; } /** * Tests whether the given channel is ready for writing. * * @return true if the channel is open and ready; false otherwise. */ private static boolean isChannelReady(Channel channel) { return (channel != null) && channel.isOpen() && channel.isBound() && channel.isConnected(); } /** * Gets the Netty channel. If the channel is not connected, first attempts to * connect. NOTE: The stateLock read lock *must* be acquired before calling * this method. * * @return the Netty channel * @throws java.io.IOException * if an error occurs connecting the channel. */ private Channel getChannel() throws IOException { if (!isChannelReady(channel)) { // Need to reconnect // Upgrade to write lock stateLock.readLock().unlock(); stateLock.writeLock().lock(); try { if (!isChannelReady(channel)) { LOG.debug("Connecting to " + remoteAddr); ChannelFuture channelFuture = bootstrap.connect(remoteAddr); channelFuture.addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { synchronized (connected) { if (future.isSuccess()) { LOG.info("Successfully connected to bookie: " + remoteAddr); channel = future.getChannel(); } else { channel = null; throw new IOException("Error connecting to " + remoteAddr, future.getCause()); } connected.notify(); } } }); try { synchronized (connected) { connected.wait(connectTimeoutMillis); } if (channel == null) { throw new IOException("Error connecting to " + remoteAddr); } } catch (InterruptedException e) { throw new InternalError(); } } } finally { // Downgrade to read lock: stateLock.readLock().lock(); stateLock.writeLock().unlock(); } } return channel; } /** * Closes the connection to the remote peer if connected. * * @param awaitCompletion * if true, will block until the close has completed. * @param cancelPendingRequests * if true, will drain the requests map and send an IOException to * all Callbacks. * @param cause * if non-null and cancelPendingRequests is true, this Throwable will * be passed to all Callbacks. */ private void disconnect(boolean awaitCompletion, boolean cancelPendingRequests, Throwable cause) { Channel channelToClose = null; Map<Integer, Callback<List<ByteBuffer>>> requestsToCancel = null; boolean stateReadLockHeld = stateLock.getReadHoldCount() != 0; if (stateReadLockHeld) { stateLock.readLock().unlock(); } stateLock.writeLock().lock(); try { if (channel != null) { if (cause != null) { LOG.debug("Disconnecting from " + remoteAddr, cause); } else { LOG.debug("Disconnecting from " + remoteAddr); } channelToClose = channel; channel = null; if (cancelPendingRequests) { // Remove all pending requests (will be canceled after relinquishing // write lock). requestsToCancel = new ConcurrentHashMap<Integer, Callback<List<ByteBuffer>>>(requests); requests.clear(); } } } finally { if (stateReadLockHeld) { stateLock.readLock().lock(); } stateLock.writeLock().unlock(); } // Cancel any pending requests by sending errors to the callbacks: if ((requestsToCancel != null) && !requestsToCancel.isEmpty()) { LOG.debug("Removing " + requestsToCancel.size() + " pending request(s)."); for (Callback<List<ByteBuffer>> request : requestsToCancel.values()) { request.handleError( cause != null ? cause : new IOException(getClass().getSimpleName() + " closed")); } } // Close the channel: if (channelToClose != null) { ChannelFuture closeFuture = channelToClose.close(); timer.stop(); if (awaitCompletion && (closeFuture != null)) { closeFuture.awaitUninterruptibly(connectTimeoutMillis); } } this.connectionEstablished = false; } /** * Netty channels are thread-safe, so there is no need to acquire locks. This * method is a no-op. */ @Override public void lockChannel() { } /** * Netty channels are thread-safe, so there is no need to acquire locks. This * method is a no-op. */ @Override public void unlockChannel() { } public void close() { try { // Close the connection: disconnect(true, true, null); } finally { // Shut down all thread pools to exit. channelFactory.releaseExternalResources(); } } @Override public String getRemoteName() throws IOException { stateLock.readLock().lock(); try { return getChannel().getRemoteAddress().toString(); } finally { stateLock.readLock().unlock(); } } /** * Make a call, passing <code>param</code>, to the IPC server running at * <code>address</code> which is servicing the <code>protocol</code> protocol, * with the <code>ticket</code> credentials, returning the value. Throws * exceptions if there are network problems or if the remote code threw an * exception. */ public Message call(RpcRequestBody param, InetSocketAddress addr, Class<? extends VersionedProtocol> protocol, int rpcTimeout) throws InterruptedException, IOException { if (!connectionEstablished) { connect(); } ConnectionHeader.Builder builder = ConnectionHeader.newBuilder(); builder.setProtocol(protocol == null ? "" : protocol.getName()); ConnectionHeader connectionHeader = builder.build(); RpcRequestHeader.Builder headerBuilder = RPCProtos.RpcRequestHeader.newBuilder(); if (Trace.isTracing()) { Span s = Trace.currentTrace(); headerBuilder.setTinfo(RPCTInfo.newBuilder().setParentId(s.getSpanId()).setTraceId(s.getTraceId())); } RpcRequestHeader rpcHeader = headerBuilder.build(); ByteBufferOutputStream bbo = new ByteBufferOutputStream(); connectionHeader.writeDelimitedTo(bbo); rpcHeader.writeDelimitedTo(bbo); param.writeDelimitedTo(bbo); List<ByteBuffer> res = transceive(bbo.getBufferList()); return processResponse(res, protocol, param); } private Message processResponse(List<ByteBuffer> res, Class<? extends VersionedProtocol> protocol, RpcRequestBody param) throws IOException { ByteBufferInputStream in = new ByteBufferInputStream(res); try { // See NettyServer.prepareResponse for where we write out the response. // It writes the call.id (int), a boolean signifying any error (and if // so the exception name/trace), and the response bytes // Read the call id. RpcResponseHeader response = RpcResponseHeader.parseDelimitedFrom(in); if (response == null) { // When the stream is closed, protobuf doesn't raise an EOFException, // instead, it returns a null message object. throw new EOFException(); } Status status = response.getStatus(); if (status == Status.SUCCESS) { Message rpcResponseType; try { rpcResponseType = ProtobufRpcEngine.Invoker.getReturnProtoType( ProtobufRpcEngine.Server.getMethod(protocol, param.getMethodName())); } catch (Exception e) { throw new RuntimeException(e); // local exception } Builder builder = rpcResponseType.newBuilderForType(); builder.mergeDelimitedFrom(in); Message value = builder.build(); return value; } else if (status == Status.ERROR) { RpcException exceptionResponse = RpcException.parseDelimitedFrom(in); RemoteException remoteException = new RemoteException(exceptionResponse.getExceptionName(), exceptionResponse.getStackTrace()); throw remoteException.unwrapRemoteException(); } else if (status == Status.FATAL) { RpcException exceptionResponse = RpcException.parseDelimitedFrom(in); // Close the connection LOG.error("Fatal Exception.", exceptionResponse); RemoteException remoteException = new RemoteException(exceptionResponse.getExceptionName(), exceptionResponse.getStackTrace()); throw remoteException.unwrapRemoteException(); } else { throw new IOException("What happened?"); } } catch (Exception e) { if (e instanceof RemoteException) { ((RemoteException) e).unwrapRemoteException(); } if (e instanceof IOException) { throw (IOException) e; } else { throw new IOException(e); } } } /** * Override as non-synchronized method because the method is thread safe. */ @Override public List<ByteBuffer> transceive(List<ByteBuffer> request) throws IOException { try { CallFuture<List<ByteBuffer>> transceiverFuture = new CallFuture<List<ByteBuffer>>(); transceive(request, transceiverFuture); return transceiverFuture.get(); } catch (InterruptedException e) { LOG.info("failed to get the response", e); throw new IOException(e); } catch (ExecutionException e) { LOG.warn("failed to get the response", e); throw new IOException(e); } } @Override public void transceive(List<ByteBuffer> request, Callback<List<ByteBuffer>> callback) throws IOException { stateLock.readLock().lock(); try { int serial = serialGenerator.incrementAndGet(); NettyDataPack dataPack = new NettyDataPack(serial, request); requests.put(serial, callback); writeDataPack(dataPack); } finally { stateLock.readLock().unlock(); } } @Override public void writeBuffers(List<ByteBuffer> buffers) throws IOException { stateLock.readLock().lock(); try { writeDataPack(new NettyDataPack(serialGenerator.incrementAndGet(), buffers)); } finally { stateLock.readLock().unlock(); } } /** * Writes a NettyDataPack, reconnecting to the remote peer if necessary. NOTE: * The stateLock read lock *must* be acquired before calling this method. * * @param dataPack * the data pack to write. * @throws java.io.IOException * if an error occurs connecting to the remote peer. */ private void writeDataPack(NettyDataPack dataPack) throws IOException { getChannel().write(dataPack); } @Override public List<ByteBuffer> readBuffers() throws IOException { throw new UnsupportedOperationException(); } /** * Wasp client handler for the Netty transport */ class NettyClientWaspHandler extends SimpleChannelUpstreamHandler { @Override public void handleUpstream(ChannelHandlerContext ctx, ChannelEvent e) throws Exception { if (e instanceof ChannelStateEvent) { LOG.debug(e.toString()); ChannelStateEvent cse = (ChannelStateEvent) e; if ((cse.getState() == ChannelState.OPEN) && (Boolean.FALSE.equals(cse.getValue()))) { // Server closed connection; disconnect client side LOG.debug("Remote peer " + remoteAddr + " closed connection."); disconnect(false, true, null); } } super.handleUpstream(ctx, e); } @Override public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception { // channel = e.getChannel(); super.channelOpen(ctx, e); } @Override public void messageReceived(ChannelHandlerContext ctx, final MessageEvent e) { NettyDataPack dataPack = (NettyDataPack) e.getMessage(); Callback<List<ByteBuffer>> callback = requests.get(dataPack.getSerial()); if (callback == null) { throw new RuntimeException("Missing previous call info"); } try { callback.handleResult(dataPack.getDatas()); } finally { requests.remove(dataPack.getSerial()); } } @Override public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) { disconnect(false, true, e.getCause()); } } /** * Creates threads with unique names based on a specified name prefix. */ private static class NettyTransceiverThreadFactory implements ThreadFactory { private final AtomicInteger threadId = new AtomicInteger(0); private final String prefix; /** * Creates a NettyTransceiverThreadFactory that creates threads with the * specified name. * * @param prefix * the name prefix to use for all threads created by this * ThreadFactory. A unique ID will be appended to this prefix to * form the final thread name. */ public NettyTransceiverThreadFactory(String prefix) { this.prefix = prefix; } @Override public Thread newThread(Runnable r) { Thread thread = new Thread(r); thread.setName(prefix + " " + threadId.incrementAndGet()); return thread; } } /** * Increment this client's reference count * */ synchronized void incCount() { refCount++; } /** * Decrement this client's reference count * */ synchronized void decCount() { refCount--; } /** * Return if this client has no reference * * @return true if this client has no reference; false otherwise */ synchronized boolean isZeroReference() { return refCount == 0; } /** * @return the remoteAddr */ public InetSocketAddress getRemoteAddr() { return remoteAddr; } /** * @see org.apache.hadoop.conf.Configurable#getConf() */ @Override public Configuration getConf() { return this.conf; } /** * @see org.apache.hadoop.conf.Configurable#setConf(org.apache.hadoop.conf.Configuration) */ @Override public void setConf(Configuration conf) { this.conf = conf; } }