org.apache.flume.source.ThriftSource.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.flume.source.ThriftSource.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.flume.source;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.flume.ChannelException;
import org.apache.flume.Context;
import org.apache.flume.Event;
import org.apache.flume.EventDrivenSource;
import org.apache.flume.FlumeException;
import org.apache.flume.auth.FlumeAuthenticationUtil;
import org.apache.flume.auth.FlumeAuthenticator;
import org.apache.flume.conf.Configurable;
import org.apache.flume.event.EventBuilder;
import org.apache.flume.instrumentation.SourceCounter;
import org.apache.flume.thrift.Status;
import org.apache.flume.thrift.ThriftSourceProtocol;
import org.apache.flume.thrift.ThriftFlumeEvent;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TCompactProtocol;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.server.TNonblockingServer;
import org.apache.thrift.server.TServer;
import org.apache.thrift.server.TThreadPoolServer;
import org.apache.thrift.transport.TFastFramedTransport;
import org.apache.thrift.transport.TNonblockingServerSocket;
import org.apache.thrift.transport.TNonblockingServerTransport;
import org.apache.thrift.transport.TServerSocket;
import org.apache.thrift.transport.TServerTransport;
import org.apache.thrift.transport.TSSLTransportFactory;
import org.apache.thrift.transport.TTransportFactory;
import org.apache.thrift.transport.TSaslServerTransport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLServerSocket;
import javax.security.sasl.Sasl;
import java.io.FileInputStream;
import java.io.IOException;
import java.lang.reflect.Method;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.security.KeyStore;
import java.security.Security;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.security.PrivilegedAction;

public class ThriftSource extends AbstractSource implements Configurable, EventDrivenSource {

    public static final Logger logger = LoggerFactory.getLogger(ThriftSource.class);
    /**
     * Config param for the maximum number of threads this source should use to
     * handle incoming data.
     */
    public static final String CONFIG_THREADS = "threads";
    /**
     * Config param for the hostname to listen on.
     */
    public static final String CONFIG_BIND = "bind";
    /**
     * Config param for the port to listen on.
     */
    public static final String CONFIG_PORT = "port";
    /**
     * Config param for the thrift protocol to use.
     */
    public static final String CONFIG_PROTOCOL = "protocol";
    public static final String BINARY_PROTOCOL = "binary";
    public static final String COMPACT_PROTOCOL = "compact";

    private static final String SSL_KEY = "ssl";
    private static final String KEYSTORE_KEY = "keystore";
    private static final String KEYSTORE_PASSWORD_KEY = "keystore-password";
    private static final String KEYSTORE_TYPE_KEY = "keystore-type";
    private static final String EXCLUDE_PROTOCOLS = "exclude-protocols";

    private static final String KERBEROS_KEY = "kerberos";
    private static final String AGENT_PRINCIPAL = "agent-principal";
    private static final String AGENT_KEYTAB = "agent-keytab";

    private Integer port;
    private String bindAddress;
    private int maxThreads = 0;
    private SourceCounter sourceCounter;
    private TServer server;
    private ExecutorService servingExecutor;
    private String protocol;
    private String keystore;
    private String keystorePassword;
    private String keystoreType;
    private final List<String> excludeProtocols = new LinkedList<String>();
    private boolean enableSsl = false;
    private boolean enableKerberos = false;
    private String principal;
    private FlumeAuthenticator flumeAuth;

    @Override
    public void configure(Context context) {
        logger.info("Configuring thrift source.");
        port = context.getInteger(CONFIG_PORT);
        Preconditions.checkNotNull(port, "Port must be specified for Thrift " + "Source.");
        bindAddress = context.getString(CONFIG_BIND);
        Preconditions.checkNotNull(bindAddress, "Bind address must be specified " + "for Thrift Source.");

        try {
            maxThreads = context.getInteger(CONFIG_THREADS, 0);
            maxThreads = (maxThreads <= 0) ? Integer.MAX_VALUE : maxThreads;
        } catch (NumberFormatException e) {
            logger.warn("Thrift source\'s \"threads\" property must specify an " + "integer value: "
                    + context.getString(CONFIG_THREADS));
        }

        if (sourceCounter == null) {
            sourceCounter = new SourceCounter(getName());
        }

        protocol = context.getString(CONFIG_PROTOCOL);
        if (protocol == null) {
            // default is to use the compact protocol.
            protocol = COMPACT_PROTOCOL;
        }
        Preconditions.checkArgument(
                (protocol.equalsIgnoreCase(BINARY_PROTOCOL) || protocol.equalsIgnoreCase(COMPACT_PROTOCOL)),
                "binary or compact are the only valid Thrift protocol types to " + "choose from.");

        enableSsl = context.getBoolean(SSL_KEY, false);
        if (enableSsl) {
            keystore = context.getString(KEYSTORE_KEY);
            keystorePassword = context.getString(KEYSTORE_PASSWORD_KEY);
            keystoreType = context.getString(KEYSTORE_TYPE_KEY, "JKS");
            String excludeProtocolsStr = context.getString(EXCLUDE_PROTOCOLS);
            if (excludeProtocolsStr == null) {
                excludeProtocols.add("SSLv3");
            } else {
                excludeProtocols.addAll(Arrays.asList(excludeProtocolsStr.split(" ")));
                if (!excludeProtocols.contains("SSLv3")) {
                    excludeProtocols.add("SSLv3");
                }
            }
            Preconditions.checkNotNull(keystore, KEYSTORE_KEY + " must be specified when SSL is enabled");
            Preconditions.checkNotNull(keystorePassword,
                    KEYSTORE_PASSWORD_KEY + " must be specified when SSL is enabled");
            try {
                KeyStore ks = KeyStore.getInstance(keystoreType);
                ks.load(new FileInputStream(keystore), keystorePassword.toCharArray());
            } catch (Exception ex) {
                throw new FlumeException("Thrift source configured with invalid keystore: " + keystore, ex);
            }
        }

        principal = context.getString(AGENT_PRINCIPAL);
        String keytab = context.getString(AGENT_KEYTAB);
        enableKerberos = context.getBoolean(KERBEROS_KEY, false);
        this.flumeAuth = FlumeAuthenticationUtil.getAuthenticator(principal, keytab);
        if (enableKerberos) {
            if (!flumeAuth.isAuthenticated()) {
                throw new FlumeException("Authentication failed in Kerberos mode for " + "principal " + principal
                        + " keytab " + keytab);
            }
            flumeAuth.startCredentialRefresher();
        }
    }

    @Override
    public void start() {
        logger.info("Starting thrift source");

        // create the server
        server = getTThreadedSelectorServer();

        // if in ssl mode or if SelectorServer is unavailable
        if (server == null) {
            server = getTThreadPoolServer();
        }

        servingExecutor = Executors.newSingleThreadExecutor(
                new ThreadFactoryBuilder().setNameFormat("Flume Thrift Source I/O Boss").build());

        /**
         * Start serving.
         */
        servingExecutor.submit(new Runnable() {
            @Override
            public void run() {
                flumeAuth.execute(new PrivilegedAction<Object>() {
                    @Override
                    public Object run() {
                        server.serve();
                        return null;
                    }
                });
            }
        });

        long timeAfterStart = System.currentTimeMillis();
        while (!server.isServing()) {
            try {
                if (System.currentTimeMillis() - timeAfterStart >= 10000) {
                    throw new FlumeException("Thrift server failed to start!");
                }
                TimeUnit.MILLISECONDS.sleep(1000);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new FlumeException("Interrupted while waiting for Thrift server" + " to start.", e);
            }
        }
        sourceCounter.start();
        logger.info("Started Thrift source.");
        super.start();
    }

    private String getkeyManagerAlgorithm() {
        String algorithm = Security.getProperty("ssl.KeyManagerFactory.algorithm");
        return (algorithm != null) ? algorithm : KeyManagerFactory.getDefaultAlgorithm();
    }

    private TServerTransport getSSLServerTransport() {
        try {
            TServerTransport transport;
            TSSLTransportFactory.TSSLTransportParameters params = new TSSLTransportFactory.TSSLTransportParameters();

            params.setKeyStore(keystore, keystorePassword, getkeyManagerAlgorithm(), keystoreType);
            transport = TSSLTransportFactory.getServerSocket(port, 120000, InetAddress.getByName(bindAddress),
                    params);

            ServerSocket serverSock = ((TServerSocket) transport).getServerSocket();
            if (serverSock instanceof SSLServerSocket) {
                SSLServerSocket sslServerSock = (SSLServerSocket) serverSock;
                List<String> enabledProtocols = new ArrayList<String>();
                for (String protocol : sslServerSock.getEnabledProtocols()) {
                    if (!excludeProtocols.contains(protocol)) {
                        enabledProtocols.add(protocol);
                    }
                }
                sslServerSock.setEnabledProtocols(enabledProtocols.toArray(new String[0]));
            }
            return transport;
        } catch (Throwable throwable) {
            throw new FlumeException("Cannot start Thrift source.", throwable);
        }
    }

    private TServerTransport getTServerTransport() {
        try {
            return new TServerSocket(new InetSocketAddress(bindAddress, port));
        } catch (Throwable throwable) {
            throw new FlumeException("Cannot start Thrift source.", throwable);
        }
    }

    private TProtocolFactory getProtocolFactory() {
        if (protocol.equals(BINARY_PROTOCOL)) {
            logger.info("Using TBinaryProtocol");
            return new TBinaryProtocol.Factory();
        } else {
            logger.info("Using TCompactProtocol");
            return new TCompactProtocol.Factory();
        }
    }

    private TServer getTThreadedSelectorServer() {
        if (enableSsl || enableKerberos) {
            return null;
        }
        Class<?> serverClass;
        Class<?> argsClass;
        TServer.AbstractServerArgs args;
        try {
            serverClass = Class.forName("org.apache.thrift" + ".server.TThreadedSelectorServer");
            argsClass = Class.forName("org.apache.thrift" + ".server.TThreadedSelectorServer$Args");

            TServerTransport serverTransport = new TNonblockingServerSocket(
                    new InetSocketAddress(bindAddress, port));

            ExecutorService sourceService;
            ThreadFactory threadFactory = new ThreadFactoryBuilder().setNameFormat("Flume Thrift IPC Thread %d")
                    .build();
            if (maxThreads == 0) {
                sourceService = Executors.newCachedThreadPool(threadFactory);
            } else {
                sourceService = Executors.newFixedThreadPool(maxThreads, threadFactory);
            }
            args = (TNonblockingServer.AbstractNonblockingServerArgs) argsClass
                    .getConstructor(TNonblockingServerTransport.class).newInstance(serverTransport);
            Method m = argsClass.getDeclaredMethod("executorService", ExecutorService.class);
            m.invoke(args, sourceService);

            populateServerParams(args);

            /*
             * Both THsHaServer and TThreadedSelectorServer allows us to pass in
             * the executor service to use - unfortunately the "executorService"
             * method does not exist in the parent abstract Args class,
             * so use reflection to pass the executor in.
             *
             */
            server = (TServer) serverClass.getConstructor(argsClass).newInstance(args);
        } catch (ClassNotFoundException e) {
            return null;
        } catch (Throwable ex) {
            throw new FlumeException("Cannot start Thrift Source.", ex);
        }
        return server;
    }

    private TServer getTThreadPoolServer() {
        TServerTransport serverTransport;
        if (enableSsl) {
            serverTransport = getSSLServerTransport();
        } else {
            serverTransport = getTServerTransport();
        }
        TThreadPoolServer.Args serverArgs = new TThreadPoolServer.Args(serverTransport);
        serverArgs.maxWorkerThreads(maxThreads);
        populateServerParams(serverArgs);
        return new TThreadPoolServer(serverArgs);
    }

    private void populateServerParams(TServer.AbstractServerArgs args) {
        //populate the ProtocolFactory
        args.protocolFactory(getProtocolFactory());

        //populate the transportFactory
        if (enableKerberos) {
            args.transportFactory(getSASLTransportFactory());
        } else {
            args.transportFactory(new TFastFramedTransport.Factory());
        }

        // populate the  Processor
        args.processor(new ThriftSourceProtocol.Processor<ThriftSourceHandler>(new ThriftSourceHandler()));
    }

    private TTransportFactory getSASLTransportFactory() {
        String[] names;
        try {
            names = FlumeAuthenticationUtil.splitKerberosName(principal);
        } catch (IOException e) {
            throw new FlumeException("Error while trying to resolve Principal name - " + principal, e);
        }
        Map<String, String> saslProperties = new HashMap<String, String>();
        saslProperties.put(Sasl.QOP, "auth");
        TSaslServerTransport.Factory saslTransportFactory = new TSaslServerTransport.Factory();
        saslTransportFactory.addServerDefinition("GSSAPI", names[0], names[1], saslProperties,
                FlumeAuthenticationUtil.getSaslGssCallbackHandler());
        return saslTransportFactory;
    }

    @Override
    public void stop() {
        if (server != null && server.isServing()) {
            server.stop();
        }
        if (servingExecutor != null) {
            servingExecutor.shutdown();
            try {
                if (!servingExecutor.awaitTermination(5, TimeUnit.SECONDS)) {
                    servingExecutor.shutdownNow();
                }
            } catch (InterruptedException e) {
                throw new FlumeException("Interrupted while waiting for server to be " + "shutdown.");
            }
        }
        sourceCounter.stop();
        super.stop();
    }

    private class ThriftSourceHandler implements ThriftSourceProtocol.Iface {

        @Override
        public Status append(ThriftFlumeEvent event) throws TException {
            Event flumeEvent = EventBuilder.withBody(event.getBody(), event.getHeaders());

            sourceCounter.incrementAppendReceivedCount();
            sourceCounter.incrementEventReceivedCount();

            try {
                getChannelProcessor().processEvent(flumeEvent);
            } catch (ChannelException ex) {
                logger.warn("Thrift source " + getName() + " could not append events " + "to the channel.", ex);
                return Status.FAILED;
            }
            sourceCounter.incrementAppendAcceptedCount();
            sourceCounter.incrementEventAcceptedCount();
            return Status.OK;
        }

        @Override
        public Status appendBatch(List<ThriftFlumeEvent> events) throws TException {
            sourceCounter.incrementAppendBatchReceivedCount();
            sourceCounter.addToEventReceivedCount(events.size());

            List<Event> flumeEvents = Lists.newArrayList();
            for (ThriftFlumeEvent event : events) {
                flumeEvents.add(EventBuilder.withBody(event.getBody(), event.getHeaders()));
            }

            try {
                getChannelProcessor().processEventBatch(flumeEvents);
            } catch (ChannelException ex) {
                logger.warn("Thrift source %s could not append events to the " + "channel.", getName());
                return Status.FAILED;
            }

            sourceCounter.incrementAppendBatchAcceptedCount();
            sourceCounter.addToEventAcceptedCount(events.size());
            return Status.OK;
        }
    }
}