io.mandrel.transport.thrift.nifty.ThriftClientManager.java Source code

Java tutorial

Introduction

Here is the source code for io.mandrel.transport.thrift.nifty.ThriftClientManager.java

Source

/*
 * Licensed to Mandrel under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Mandrel licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package io.mandrel.transport.thrift.nifty;

import static com.facebook.nifty.duplex.TTransportPair.fromSeparateTransports;
import static com.facebook.swift.service.ThriftClientConfig.DEFAULT_CONNECT_TIMEOUT;
import static com.facebook.swift.service.ThriftClientConfig.DEFAULT_MAX_FRAME_SIZE;
import static com.facebook.swift.service.ThriftClientConfig.DEFAULT_READ_TIMEOUT;
import static com.facebook.swift.service.ThriftClientConfig.DEFAULT_RECEIVE_TIMEOUT;
import static com.facebook.swift.service.ThriftClientConfig.DEFAULT_WRITE_TIMEOUT;
import static com.google.common.base.Preconditions.checkNotNull;
import static org.apache.thrift.TApplicationException.UNKNOWN_METHOD;
import io.airlift.units.Duration;

import java.io.Closeable;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

import javax.annotation.Nullable;
import javax.annotation.PreDestroy;
import javax.annotation.concurrent.Immutable;
import javax.annotation.concurrent.ThreadSafe;
import javax.validation.constraints.NotNull;

import org.apache.thrift.TApplicationException;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolException;
import org.apache.thrift.transport.TTransportException;
import org.jboss.netty.channel.Channel;

import com.facebook.nifty.client.ClientRequestContext;
import com.facebook.nifty.client.NiftyClientChannel;
import com.facebook.nifty.client.NiftyClientConnector;
import com.facebook.nifty.client.NiftyClientRequestContext;
import com.facebook.nifty.client.RequestChannel;
import com.facebook.nifty.core.TChannelBufferInputTransport;
import com.facebook.nifty.core.TChannelBufferOutputTransport;
import com.facebook.nifty.duplex.TProtocolPair;
import com.facebook.nifty.duplex.TTransportPair;
import com.facebook.swift.codec.ThriftCodecManager;
import com.facebook.swift.service.ClientContextChain;
import com.facebook.swift.service.CustomClientContextChain;
import com.facebook.swift.service.RuntimeTApplicationException;
import com.facebook.swift.service.RuntimeTException;
import com.facebook.swift.service.RuntimeTProtocolException;
import com.facebook.swift.service.RuntimeTTransportException;
import com.facebook.swift.service.ThriftClientEventHandler;
import com.facebook.swift.service.ThriftMethodHandler;
import com.facebook.swift.service.metadata.ThriftMethodMetadata;
import com.facebook.swift.service.metadata.ThriftServiceMetadata;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.net.HostAndPort;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.inject.Inject;

/**
 * Factory for Swift/Thrift clients backed by Nifty channels.
 *
 * <p>A client is a {@link Proxy} that implements the requested service interface plus {@link Closeable}; service
 * method calls on it are dispatched to {@link ThriftMethodHandler}s built from the interface's Thrift metadata. That
 * metadata is computed once per (interface, client name) pair and cached. {@link #close()} closes the underlying
 * {@link NiftyClient}.
 */
@ThreadSafe
public class ThriftClientManager implements Closeable {
    /** Client name used when the caller does not supply one. */
    public static final String DEFAULT_NAME = "default";

    private final ThriftCodecManager codecManager;
    private final NiftyClient niftyClient;

    /**
     * Lazily builds {@link ThriftClientMetadata} per (client type, client name). The loader reads
     * {@link #codecManager} only when an entry is first requested, which happens after construction completes.
     */
    private final LoadingCache<TypeAndName, ThriftClientMetadata> clientMetadataCache = CacheBuilder.newBuilder()
            .build(new CacheLoader<TypeAndName, ThriftClientMetadata>() {
                @Override
                public ThriftClientMetadata load(TypeAndName typeAndName) throws Exception {
                    return new ThriftClientMetadata(typeAndName.getType(), typeAndName.getName(), codecManager);
                }
            });

    /** Event handlers attached to every client created here; they precede any per-client handlers. */
    private final Set<ThriftClientEventHandler> globalEventHandlers;

    /** Creates a manager with a new {@link ThriftCodecManager}, a new {@link NiftyClient} and no global handlers. */
    public ThriftClientManager() {
        this(new ThriftCodecManager());
    }

    /**
     * Creates a manager whose {@link ThriftCodecManager} is constructed with the given class loader.
     *
     * @param parent class loader passed to the {@link ThriftCodecManager} constructor
     */
    public ThriftClientManager(ClassLoader parent) {
        this(new ThriftCodecManager(parent));
    }

    /**
     * Creates a manager using the given codec manager, a new {@link NiftyClient} and no global event handlers.
     *
     * @param codecManager codec manager used to build client metadata; must not be null
     */
    public ThriftClientManager(ThriftCodecManager codecManager) {
        this(codecManager, new NiftyClient(), ImmutableSet.<ThriftClientEventHandler>of());
    }

    /**
     * Fully specified constructor, also used for injection.
     *
     * @param codecManager codec manager used to build client metadata; must not be null
     * @param niftyClient client used to open channels; closed by {@link #close()}; must not be null
     * @param globalEventHandlers event handlers attached to every client created by this manager; must not be null
     * @throws NullPointerException if any argument is null
     */
    @Inject
    public ThriftClientManager(ThriftCodecManager codecManager, NiftyClient niftyClient,
            Set<ThriftClientEventHandler> globalEventHandlers) {
        this.codecManager = checkNotNull(codecManager, "codecManager is null");
        this.niftyClient = checkNotNull(niftyClient, "niftyClient is null");
        this.globalEventHandlers = checkNotNull(globalEventHandlers, "globalEventHandlers is null");
    }

    /**
     * Opens a channel using the default timeouts, the default maximum frame size and the default SOCKS proxy.
     *
     * @param connector describes how and where to connect; must not be null
     * @return a future that completes with the connected channel
     */
    public <C extends NiftyClientChannel> ListenableFuture<C> createChannel(NiftyClientConnector<C> connector) {
        return createChannel(connector, DEFAULT_CONNECT_TIMEOUT, DEFAULT_RECEIVE_TIMEOUT, DEFAULT_READ_TIMEOUT,
                DEFAULT_WRITE_TIMEOUT, DEFAULT_MAX_FRAME_SIZE, getDefaultSocksProxy());
    }

    /**
     * Opens a channel asynchronously through the underlying {@link NiftyClient}. All settings are passed through
     * unchanged.
     *
     * @param connector describes how and where to connect; must not be null
     * @param connectTimeout connect timeout; may be {@code null}
     * @param receiveTimeout receive timeout; may be {@code null}
     * @param readTimeout read timeout; may be {@code null}
     * @param writeTimeout write timeout; may be {@code null}
     * @param maxFrameSize maximum frame size
     * @param socksProxy SOCKS proxy to connect through; may be {@code null}
     * @return a future that completes with the connected channel
     * @throws NullPointerException if {@code connector} is null
     */
    public <C extends NiftyClientChannel> ListenableFuture<C> createChannel(final NiftyClientConnector<C> connector,
            @Nullable final Duration connectTimeout, @Nullable final Duration receiveTimeout,
            @Nullable final Duration readTimeout, @Nullable final Duration writeTimeout, final int maxFrameSize,
            @Nullable HostAndPort socksProxy) {
        checkNotNull(connector, "connector is null");
        return niftyClient.connectAsync(connector, connectTimeout, receiveTimeout, readTimeout, writeTimeout,
                maxFrameSize, socksProxy);
    }

    /**
     * Connects and creates a client with default timeouts, the default frame size, {@link #DEFAULT_NAME}, no
     * per-client event handlers and the default SOCKS proxy.
     *
     * @param connector describes how and where to connect; must not be null
     * @param type Thrift service interface the client implements; must not be null
     * @return a future that completes with the client proxy
     */
    public <T, C extends NiftyClientChannel> ListenableFuture<T> createClient(NiftyClientConnector<C> connector,
            Class<T> type) {
        return createClient(connector, type, DEFAULT_CONNECT_TIMEOUT, DEFAULT_RECEIVE_TIMEOUT, DEFAULT_READ_TIMEOUT,
                DEFAULT_WRITE_TIMEOUT, DEFAULT_MAX_FRAME_SIZE, DEFAULT_NAME,
                ImmutableList.<ThriftClientEventHandler>of(), getDefaultSocksProxy());
    }

    /**
     * Same as the non-deprecated overload, with {@code readTimeout} also used as the receive timeout.
     *
     * @deprecated Use
     *             {@link ThriftClientManager#createClient(NiftyClientConnector, Class, Duration, Duration, Duration, Duration, int, String, List, HostAndPort)}
     *             .
     */
    @Deprecated
    public <T, C extends NiftyClientChannel> ListenableFuture<T> createClient(
            final NiftyClientConnector<C> connector, final Class<T> type, @Nullable final Duration connectTimeout,
            @Nullable final Duration readTimeout, @Nullable final Duration writeTimeout, final int maxFrameSize,
            @Nullable final String clientName, final List<? extends ThriftClientEventHandler> eventHandlers,
            @Nullable HostAndPort socksProxy) {
        return createClient(connector, type, connectTimeout, readTimeout, readTimeout, writeTimeout, maxFrameSize,
                clientName, eventHandlers, socksProxy);
    }

    /**
     * Opens a channel asynchronously and, once it is connected, wraps it in a client proxy.
     *
     * <p>If the channel connects but building the client fails, the channel is closed and the original failure is
     * propagated through the returned future.
     *
     * @param connector describes how and where to connect; must not be null
     * @param type Thrift service interface the client implements; must not be null
     * @param connectTimeout connect timeout; may be {@code null}
     * @param receiveTimeout receive timeout; may be {@code null}
     * @param readTimeout read timeout; may be {@code null}
     * @param writeTimeout write timeout; may be {@code null}
     * @param maxFrameSize maximum frame size
     * @param clientName client name; when {@code null} or empty, {@code connector.toString()} is used instead
     * @param eventHandlers per-client event handlers, appended after the global handlers; must not be null
     * @param socksProxy SOCKS proxy to connect through; may be {@code null}
     * @return a future that completes with the client proxy
     * @throws NullPointerException if {@code connector}, {@code type} or {@code eventHandlers} is null
     */
    public <T, C extends NiftyClientChannel> ListenableFuture<T> createClient(
            final NiftyClientConnector<C> connector, final Class<T> type, @Nullable final Duration connectTimeout,
            @Nullable final Duration receiveTimeout, @Nullable final Duration readTimeout,
            @Nullable final Duration writeTimeout, final int maxFrameSize, @Nullable final String clientName,
            final List<? extends ThriftClientEventHandler> eventHandlers, @Nullable HostAndPort socksProxy) {
        checkNotNull(connector, "connector is null");
        checkNotNull(type, "type is null");
        checkNotNull(eventHandlers, "eventHandlers is null");

        final ListenableFuture<C> connectFuture = createChannel(connector, connectTimeout, receiveTimeout,
                readTimeout, writeTimeout, maxFrameSize, socksProxy);

        ListenableFuture<T> clientFuture = Futures.transform(connectFuture, new Function<C, T>() {
            @Nullable
            @Override
            public T apply(@NotNull C channel) {
                String name = Strings.isNullOrEmpty(clientName) ? connector.toString() : clientName;

                try {
                    return createClient(channel, type, name, eventHandlers);
                } catch (Throwable t) {
                    // The channel was created successfully, but client creation failed, so the channel must be
                    // closed now. A failure while closing is recorded as suppressed so it cannot mask the
                    // original problem.
                    try {
                        channel.close();
                    } catch (Throwable closeFailure) {
                        t.addSuppressed(closeFailure);
                    }
                    throw t;
                }
            }
        });

        return clientFuture;
    }

    /**
     * Wraps an existing channel in a client named {@link #DEFAULT_NAME} with no per-client event handlers.
     *
     * @param channel connected channel; must not be null
     * @param type Thrift service interface the client implements; must not be null
     * @return the client proxy
     */
    public <T> T createClient(NiftyClientChannel channel, Class<T> type) {
        return createClient(channel, type, DEFAULT_NAME, ImmutableList.<ThriftClientEventHandler>of());
    }

    /**
     * Wraps an existing channel in a client named {@link #DEFAULT_NAME}.
     *
     * @param channel connected channel; must not be null
     * @param type Thrift service interface the client implements; must not be null
     * @param eventHandlers per-client event handlers, appended after the global handlers; must not be null
     * @return the client proxy
     */
    public <T> T createClient(NiftyClientChannel channel, Class<T> type,
            List<? extends ThriftClientEventHandler> eventHandlers) {
        return createClient(channel, type, DEFAULT_NAME, eventHandlers);
    }

    /**
     * Wraps an existing request channel in a dynamic proxy implementing {@code type} and {@link Closeable}.
     * Closing the proxy closes {@code channel}.
     *
     * @param channel request channel that calls are sent over; must not be null
     * @param type Thrift service interface the client implements; must not be null
     * @param name client name, used as part of the metadata cache key; must not be null
     * @param eventHandlers per-client event handlers, appended after the global handlers; must not be null
     * @return the client proxy
     * @throws NullPointerException if any argument is null
     */
    public <T> T createClient(RequestChannel channel, Class<T> type, String name,
            List<? extends ThriftClientEventHandler> eventHandlers) {
        checkNotNull(channel, "channel is null");
        checkNotNull(type, "type is null");
        checkNotNull(name, "name is null");
        checkNotNull(eventHandlers, "eventHandlers is null");

        ThriftClientMetadata clientMetadata = clientMetadataCache.getUnchecked(new TypeAndName(type, name));

        // Returned by the proxy's toString().
        String clientDescription = clientMetadata.getName() + " " + channel.toString();

        ThriftInvocationHandler handler = new ThriftInvocationHandler(clientDescription, channel,
                clientMetadata.getMethodHandlers(), ImmutableList.<ThriftClientEventHandler>builder()
                        .addAll(globalEventHandlers).addAll(eventHandlers).build());

        return type.cast(
                Proxy.newProxyInstance(type.getClassLoader(), new Class<?>[] { type, Closeable.class }, handler));
    }

    /**
     * Returns the cached metadata for the given client interface and name, building it on first use.
     *
     * @param type Thrift service interface; must not be null
     * @param name client name; must not be null
     * @return the client metadata
     */
    public ThriftClientMetadata getClientMetadata(Class<?> type, String name) {
        return clientMetadataCache.getUnchecked(new TypeAndName(type, name));
    }

    /** Closes the underlying {@link NiftyClient}. */
    @PreDestroy
    public void close() {
        niftyClient.close();
    }

    /**
     * Returns the default SOCKS proxy reported by the underlying {@link NiftyClient}.
     *
     * @return the default SOCKS proxy address, as returned by {@code NiftyClient.getDefaultSocksProxyAddress()}
     */
    public HostAndPort getDefaultSocksProxy() {
        return niftyClient.getDefaultSocksProxyAddress();
    }

    /**
     * Returns the {@link RequestChannel} backing a Swift client
     *
     * @throws IllegalArgumentException
     *             if the client is not a Swift client
     */
    public RequestChannel getRequestChannel(Object client) {
        return getThriftInvocationHandler(client).getChannel();
    }

    /**
     * Returns the {@link NiftyClientChannel} backing a Swift client
     *
     * @throws IllegalArgumentException
     *             if the client is not using a {@link com.facebook.nifty.client.NiftyClientChannel}
     *
     * @deprecated Use {@link #getRequestChannel} instead, and cast the result to a {@link NiftyClientChannel} if necessary
     */
    @Deprecated
    public NiftyClientChannel getNiftyChannel(Object client) {
        RequestChannel requestChannel = getRequestChannel(client);
        if (!(requestChannel instanceof NiftyClientChannel)) {
            throw new IllegalArgumentException("The swift client uses a channel that is not a NiftyClientChannel");
        }
        return (NiftyClientChannel) requestChannel;
    }

    /**
     * Returns the remote address that a Swift client is connected to
     *
     * @throws IllegalArgumentException
     *             if the client is not a Swift client or is not connected
     *             through an internet socket
     */
    public HostAndPort getRemoteAddress(Object client) {
        NiftyClientChannel niftyChannel = getNiftyChannel(client);

        Channel nettyChannel = niftyChannel.getNettyChannel();
        SocketAddress address = (nettyChannel == null) ? null : nettyChannel.getRemoteAddress();
        // A null address, or a non-internet one (e.g. a local channel), cannot be expressed as host and port.
        if (!(address instanceof InetSocketAddress)) {
            throw new IllegalArgumentException("Invalid swift client object");
        }
        InetSocketAddress inetAddress = (InetSocketAddress) address;
        return HostAndPort.fromParts(inetAddress.getHostString(), inetAddress.getPort());
    }

    /**
     * Returns the output protocol used by a Swift client.
     *
     * @throws IllegalArgumentException
     *             if the client is not a Swift client
     */
    public TProtocol getOutputProtocol(Object client) {
        return getThriftInvocationHandler(client).getOutputProtocol();
    }

    /**
     * Returns the input protocol used by a Swift client.
     *
     * @throws IllegalArgumentException
     *             if the client is not a Swift client
     */
    public TProtocol getInputProtocol(Object client) {
        return getThriftInvocationHandler(client).getInputProtocol();
    }

    /**
     * Extracts the {@link ThriftInvocationHandler} behind a client proxy created by this class.
     *
     * @throws IllegalArgumentException if {@code client} is not a proxy, or its handler is not a
     *             {@link ThriftInvocationHandler}
     * @throws NullPointerException if {@code client} is null (propagated from {@link Proxy#getInvocationHandler})
     */
    private static ThriftInvocationHandler getThriftInvocationHandler(Object client) {
        InvocationHandler genericHandler;
        try {
            genericHandler = Proxy.getInvocationHandler(client);
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("Invalid swift client object", e);
        }
        if (!(genericHandler instanceof ThriftInvocationHandler)) {
            throw new IllegalArgumentException("Invalid swift client object");
        }
        return (ThriftInvocationHandler) genericHandler;
    }

    /**
     * Metadata for one (service interface, client name) pair: the interface's Thrift service metadata and a
     * {@link ThriftMethodHandler} per Thrift method.
     */
    @Immutable
    public static class ThriftClientMetadata {
        private final String clientType;
        private final String clientName;
        private final ThriftServiceMetadata thriftServiceMetadata;
        private final Map<Method, ThriftMethodHandler> methodHandlers;

        private ThriftClientMetadata(Class<?> clientType, String clientName, ThriftCodecManager codecManager) {
            Preconditions.checkNotNull(clientType, "clientType is null");
            Preconditions.checkNotNull(clientName, "clientName is null");
            Preconditions.checkNotNull(codecManager, "codecManager is null");

            this.clientName = clientName;
            thriftServiceMetadata = new ThriftServiceMetadata(clientType, codecManager.getCatalog());
            // Note: this is the Thrift service name from the metadata, not the Java class name.
            this.clientType = thriftServiceMetadata.getName();
            ImmutableMap.Builder<Method, ThriftMethodHandler> methods = ImmutableMap.builder();
            for (ThriftMethodMetadata methodMetadata : thriftServiceMetadata.getMethods().values()) {
                ThriftMethodHandler methodHandler = new ThriftMethodHandler(methodMetadata, codecManager);
                methods.put(methodMetadata.getMethod(), methodHandler);
            }
            methodHandlers = methods.build();
        }

        /** Returns the Thrift service name (same value as {@link #getName()}). */
        public String getClientType() {
            return clientType;
        }

        /** Returns the client name this metadata was built for. */
        public String getClientName() {
            return clientName;
        }

        /** Returns the Thrift service name from the service metadata. */
        public String getName() {
            return thriftServiceMetadata.getName();
        }

        /** Returns an immutable map from interface method to its handler. */
        public Map<Method, ThriftMethodHandler> getMethodHandlers() {
            return methodHandlers;
        }
    }

    /**
     * Invocation handler behind every client proxy. It holds the channel, the per-method handlers, the event
     * handlers, a sequence-id counter and one input/output transport and protocol pair created from the channel's
     * protocol factory.
     */
    private static class ThriftInvocationHandler implements InvocationHandler {
        private static final Object[] NO_ARGS = new Object[0];
        private final String clientDescription;

        private final RequestChannel channel;

        private final Map<Method, ThriftMethodHandler> methods;
        private final AtomicInteger sequenceId = new AtomicInteger(1);
        private final List<? extends ThriftClientEventHandler> eventHandlers;
        private final TChannelBufferInputTransport inputTransport;
        private final TChannelBufferOutputTransport outputTransport;
        private final TProtocol inputProtocol;
        private final TProtocol outputProtocol;

        private ThriftInvocationHandler(String clientDescription, RequestChannel channel,
                Map<Method, ThriftMethodHandler> methods, List<? extends ThriftClientEventHandler> eventHandlers) {
            this.clientDescription = clientDescription;
            this.channel = channel;
            this.methods = methods;
            this.eventHandlers = eventHandlers;

            this.inputTransport = new TChannelBufferInputTransport();
            this.outputTransport = new TChannelBufferOutputTransport();

            TTransportPair transportPair = fromSeparateTransports(this.inputTransport, this.outputTransport);
            TProtocolPair protocolPair = channel.getProtocolFactory().getProtocolPair(transportPair);
            this.inputProtocol = protocolPair.getInputProtocol();
            this.outputProtocol = protocolPair.getOutputProtocol();
        }

        public RequestChannel getChannel() {
            return channel;
        }

        public TProtocol getOutputProtocol() {
            return outputProtocol;
        }

        public TProtocol getInputProtocol() {
            return inputProtocol;
        }

        /**
         * Dispatches a call made on the client proxy.
         *
         * <ul>
         * <li>{@code toString}, {@code equals} and {@code hashCode} are answered locally.</li>
         * <li>A no-argument {@code close()} closes the channel.</li>
         * <li>Any other method is executed by its {@link ThriftMethodHandler} over the channel. A method without a
         * handler fails with {@link TApplicationException#UNKNOWN_METHOD}, and a channel already in error fails with a
         * {@link TTransportException}.</li>
         * </ul>
         *
         * <p>A {@link TException} whose type is declared by {@code method} is rethrown unchanged. Undeclared ones are
         * wrapped in the matching Swift runtime exception, so the proxy does not surface them as
         * {@link java.lang.reflect.UndeclaredThrowableException}. If the calling thread is interrupted, its interrupt
         * flag is restored and a {@link RuntimeTException} is thrown.
         */
        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            if (method.getDeclaringClass() == Object.class) {
                switch (method.getName()) {
                case "toString":
                    return clientDescription;
                case "equals":
                    return isSameClient(args[0]);
                case "hashCode":
                    return hashCode();
                default:
                    throw new UnsupportedOperationException();
                }
            }

            if (args == null) {
                args = NO_ARGS;
            }

            if (args.length == 0 && "close".equals(method.getName())) {
                channel.close();
                return null;
            }

            ThriftMethodHandler methodHandler = methods.get(method);

            try {
                if (methodHandler == null) {
                    throw new TApplicationException(UNKNOWN_METHOD, "Unknown method : '" + method + "'");
                }

                if (channel.hasError()) {
                    throw new TTransportException(channel.getError());
                }

                SocketAddress remoteAddress = null;
                // Can only get remote address if this is a nifty channel, plain RequestChannel does
                // not support it
                if (channel instanceof NiftyClientChannel) {
                    NiftyClientChannel niftyClientChannel = (NiftyClientChannel) channel;
                    remoteAddress = niftyClientChannel.getNettyChannel().getRemoteAddress();
                }

                ClientRequestContext requestContext = new NiftyClientRequestContext(getInputProtocol(),
                        getOutputProtocol(), channel, remoteAddress);
                ClientContextChain context = new CustomClientContextChain(eventHandlers,
                        methodHandler.getQualifiedName(), requestContext);
                return methodHandler.invoke(channel, inputTransport, outputTransport, inputProtocol, outputProtocol,
                        sequenceId.getAndIncrement(), context, args);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeTException("Thread interrupted", new TException(e));
            } catch (TException e) {
                Class<? extends TException> thrownType = e.getClass();

                // Exceptions the interface method declares can pass through the proxy unchanged.
                for (Class<?> exceptionType : method.getExceptionTypes()) {
                    if (exceptionType.isAssignableFrom(thrownType)) {
                        throw e;
                    }
                }

                //noinspection InstanceofCatchParameter
                if (e instanceof TApplicationException) {
                    throw new RuntimeTApplicationException(e.getMessage(), (TApplicationException) e);
                }
                //noinspection InstanceofCatchParameter
                if (e instanceof TProtocolException) {
                    throw new RuntimeTProtocolException(e.getMessage(), (TProtocolException) e);
                }
                //noinspection InstanceofCatchParameter
                if (e instanceof TTransportException) {
                    throw new RuntimeTTransportException(e.getMessage(), (TTransportException) e);
                }
                throw new RuntimeTException(e.getMessage(), e);
            }
        }

        /**
         * Implements {@code equals} for the proxy. Two client proxies are equal only when they share this handler
         * instance; this class does not override {@code equals}, so the comparison is by identity. Returns
         * {@code false}, instead of throwing, for {@code null} and for objects that are not proxies.
         */
        private boolean isSameClient(@Nullable Object other) {
            return other != null && Proxy.isProxyClass(other.getClass())
                    && equals(Proxy.getInvocationHandler(other));
        }
    }

    /** Cache key combining a client interface with a client name. */
    @Immutable
    private static class TypeAndName {
        private final Class<?> type;
        private final String name;

        public TypeAndName(Class<?> type, String name) {
            Preconditions.checkNotNull(type, "type is null");
            Preconditions.checkNotNull(name, "name is null");
            this.type = type;
            this.name = name;
        }

        public Class<?> getType() {
            return type;
        }

        public String getName() {
            return name;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }

            TypeAndName that = (TypeAndName) o;

            if (!name.equals(that.name)) {
                return false;
            }
            if (!type.equals(that.type)) {
                return false;
            }

            return true;
        }

        @Override
        public int hashCode() {
            int result = type.hashCode();
            result = 31 * result + name.hashCode();
            return result;
        }

        @Override
        public String toString() {
            final StringBuilder sb = new StringBuilder();
            sb.append("TypeAndName");
            sb.append("{type=").append(type);
            sb.append(", name='").append(name).append('\'');
            sb.append('}');
            return sb.toString();
        }
    }
}