org.apache.tez.runtime.library.common.shuffle.server.ShuffleHandler.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.tez.runtime.library.common.shuffle.server.ShuffleHandler.java

Source

/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.tez.runtime.library.common.shuffle.server;

import static org.jboss.netty.buffer.ChannelBuffers.wrappedBuffer;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.CONTENT_TYPE;
import static org.jboss.netty.handler.codec.http.HttpMethod.GET;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.OK;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED;
import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URL;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;

import javax.crypto.SecretKey;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.DataInputByteBuffer;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.metrics2.MetricsSystem;
import org.apache.hadoop.metrics2.annotation.Metric;
import org.apache.hadoop.metrics2.annotation.Metrics;
import org.apache.hadoop.metrics2.lib.DefaultMetricsSystem;
import org.apache.hadoop.metrics2.lib.MutableCounterInt;
import org.apache.hadoop.metrics2.lib.MutableCounterLong;
import org.apache.hadoop.metrics2.lib.MutableGaugeInt;
import org.apache.hadoop.security.ssl.SSLFactory;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext;
import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext;
import org.apache.hadoop.yarn.server.api.AuxiliaryService;
import org.apache.tez.common.TezJobConfig;
import org.apache.tez.common.security.JobTokenIdentifier;
import org.apache.tez.runtime.api.TezOutputContext;
import org.apache.tez.runtime.library.common.security.JobTokenSecretManager;
import org.apache.tez.runtime.library.common.security.SecureShuffleUtils;
import org.apache.tez.runtime.library.common.shuffle.impl.ShuffleHeader;
import org.apache.tez.runtime.library.common.sort.impl.ExternalSorter;
import org.apache.tez.runtime.library.shuffle.common.ShuffleUtils;
import org.jboss.netty.bootstrap.ServerBootstrap;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFactory;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelFutureListener;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.channel.group.ChannelGroup;
import org.jboss.netty.channel.group.DefaultChannelGroup;
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory;
import org.jboss.netty.handler.codec.frame.TooLongFrameException;
import org.jboss.netty.handler.codec.http.DefaultHttpResponse;
import org.jboss.netty.handler.codec.http.HttpChunkAggregator;
import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.HttpRequestDecoder;
import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.codec.http.HttpResponseEncoder;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.jboss.netty.handler.codec.http.QueryStringDecoder;
import org.jboss.netty.handler.ssl.SslHandler;
import org.jboss.netty.handler.stream.ChunkedStream;
import org.jboss.netty.handler.stream.ChunkedWriteHandler;
import org.jboss.netty.util.CharsetUtil;

import com.google.common.util.concurrent.ThreadFactoryBuilder;

/**
 * HTTP shuffle server that streams sorted output produced by an
 * {@link ExternalSorter} to fetchers over Netty.
 *
 * <p>{@link #initialize(TezOutputContext, Configuration)} initializes the service with a
 * copy of the given configuration and loads the job-token secret used to authenticate
 * fetchers. {@link #serviceStart()} binds to a port chosen by the OS; the bound port is
 * available from {@link #getPort()} and, serialized, from {@link #getMetaData()}.
 *
 * <p>Each request must be a GET carrying the expected shuffle name/version headers and a
 * URL hash signed with the job-token secret. Any failure produces an error response after
 * which the connection is closed.
 */
public class ShuffleHandler extends AuxiliaryService {

    private static final Log LOG = LogFactory.getLog(ShuffleHandler.class);

    public static final String SHUFFLE_MANAGE_OS_CACHE = "mapreduce.shuffle.manage.os.cache";
    public static final boolean DEFAULT_SHUFFLE_MANAGE_OS_CACHE = true;

    public static final String SHUFFLE_READAHEAD_BYTES = "mapreduce.shuffle.readahead.bytes";
    public static final int DEFAULT_SHUFFLE_READAHEAD_BYTES = 4 * 1024 * 1024;

    /** Port the server is bound to; assigned in serviceStart(). */
    private int port;
    /** Netty channel factory created in serviceInit(); null until then. */
    private ChannelFactory selector;
    /** Server channel(s) bound by this handler; closed in serviceStop(). */
    private final ChannelGroup accepted = new DefaultChannelGroup();
    /** Pipeline factory created in serviceStart(); null until then. */
    private HttpPipelineFactory pipelineFact;
    /** Chunk size used when streaming sorted output; read from config in serviceStart(). */
    private int sslFileBufferSize;

    public static final String MAPREDUCE_SHUFFLE_SERVICEID = "mapreduce_shuffle";

    /** Application id -> user, maintained by initializeApplication/stopApplication. */
    private static final Map<String, String> userRsrc = new ConcurrentHashMap<String, String>();
    // NOTE(review): tokens registered here are never consulted when verifying requests;
    // verifyRequest() authenticates against tokenSecret instead.
    private static final JobTokenSecretManager secretManager = new JobTokenSecretManager();
    /** Job-token secret used to verify fetcher URL hashes; set by initialize(). */
    private SecretKey tokenSecret;

    public static final String SHUFFLE_PORT_CONFIG_KEY = "mapreduce.shuffle.port";
    public static final int DEFAULT_SHUFFLE_PORT = 8080;

    // The "SUFFLE" spelling is kept because these are public constants.
    public static final String SUFFLE_SSL_FILE_BUFFER_SIZE_KEY = "mapreduce.shuffle.ssl.file.buffer.size";

    public static final int DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE = 60 * 1024;

    /** Source of the shuffle header and sorted stream served for each reduce partition. */
    private ExternalSorter sorter;

    /**
     * Metrics for served shuffle outputs. Also acts as a listener on the final write of a
     * request: a successful write counts as OK, a failed one as failed, and either way the
     * connection gauge is decremented.
     */
    @Metrics(about = "Shuffle output metrics", context = "mapred")
    static class ShuffleMetrics implements ChannelFutureListener {
        @Metric("Shuffle output in bytes")
        MutableCounterLong shuffleOutputBytes;
        @Metric("# of failed shuffle outputs")
        MutableCounterInt shuffleOutputsFailed;
        @Metric("# of succeeeded shuffle outputs")
        MutableCounterInt shuffleOutputsOK;
        @Metric("# of current shuffle connections")
        MutableGaugeInt shuffleConnections;

        @Override
        public void operationComplete(ChannelFuture future) throws Exception {
            if (future.isSuccess()) {
                shuffleOutputsOK.incr();
            } else {
                shuffleOutputsFailed.incr();
            }
            shuffleConnections.decr();
        }
    }

    final ShuffleMetrics metrics;

    /**
     * Creates the handler and registers its metrics with the given metrics system.
     *
     * @param ms metrics system to register {@link ShuffleMetrics} with
     */
    ShuffleHandler(MetricsSystem ms) {
        super("httpshuffle");
        metrics = ms.register(new ShuffleMetrics());
    }

    /**
     * Creates a handler that serves the output of the given sorter, registering metrics
     * with {@link DefaultMetricsSystem#instance()}.
     *
     * @param sorter sorter whose per-reduce header and stream are served
     */
    public ShuffleHandler(ExternalSorter sorter) {
        this(DefaultMetricsSystem.instance());
        this.sorter = sorter;
    }

    /**
     * Serialize the shuffle port into a ByteBuffer for use later on.
     * @param port the port to be sent to the ApplicationMaster
     * @return the serialized form of the port.
     */
    public static ByteBuffer serializeMetaData(int port) throws IOException {
        //TODO these bytes should be versioned
        DataOutputBuffer port_dob = new DataOutputBuffer();
        port_dob.writeInt(port);
        return ByteBuffer.wrap(port_dob.getData(), 0, port_dob.getLength());
    }

    /**
     * A helper function to deserialize the metadata returned by ShuffleHandler.
     * @param meta the metadata returned by the ShuffleHandler
     * @return the port the Shuffle Handler is listening on to serve shuffle data.
     */
    public static int deserializeMetaData(ByteBuffer meta) throws IOException {
        //TODO this should be returning a class not just an int
        DataInputByteBuffer in = new DataInputByteBuffer();
        in.reset(meta);
        int port = in.readInt();
        return port;
    }

    /**
     * A helper function to serialize the JobTokenIdentifier to be sent to the
     * ShuffleHandler as ServiceData.
     * @param jobToken the job token to be used for authentication of
     * shuffle data requests.
     * @return the serialized version of the jobToken.
     */
    public static ByteBuffer serializeServiceData(Token<JobTokenIdentifier> jobToken) throws IOException {
        //TODO these bytes should be versioned
        DataOutputBuffer jobToken_dob = new DataOutputBuffer();
        jobToken.write(jobToken_dob);
        return ByteBuffer.wrap(jobToken_dob.getData(), 0, jobToken_dob.getLength());
    }

    /**
     * Inverse of {@link #serializeServiceData(Token)}.
     *
     * @param secret buffer produced by serializeServiceData
     * @return the deserialized job token
     * @throws IOException if the token cannot be read from the buffer
     */
    static Token<JobTokenIdentifier> deserializeServiceData(ByteBuffer secret) throws IOException {
        DataInputByteBuffer in = new DataInputByteBuffer();
        in.reset(secret);
        Token<JobTokenIdentifier> jt = new Token<JobTokenIdentifier>();
        jt.readFields(in);
        return jt;
    }

    /**
     * Records the application's user and registers its job token. An IOException while
     * deserializing the token is logged and otherwise ignored, because this callback
     * has no way to report failure.
     */
    @Override
    public void initializeApplication(ApplicationInitializationContext initAppContext) {
        // TODO these bytes should be versioned
        try {
            String user = initAppContext.getUser();
            ApplicationId appId = initAppContext.getApplicationId();
            ByteBuffer secret = initAppContext.getApplicationDataForService();
            Token<JobTokenIdentifier> jt = deserializeServiceData(secret);
            // TODO: Once SHuffle is out of NM, this can use MR APIs
            userRsrc.put(appId.toString(), user);
            LOG.info("Added token for " + appId.toString());
            secretManager.addTokenForJob(appId.toString(), jt);
        } catch (IOException e) {
            LOG.error("Error during initApp", e);
            // TODO add API to AuxiliaryServices to report failures
        }
    }

    /** Removes the job token and user mapping registered for the application. */
    @Override
    public void stopApplication(ApplicationTerminationContext context) {
        ApplicationId appId = context.getApplicationId();
        secretManager.removeTokenForJob(appId.toString());
        userRsrc.remove(appId.toString());
    }

    /**
     * Initializes the service with a copy of {@code conf} and loads the job-token secret
     * from the service-consumer metadata registered under {@link #MAPREDUCE_SHUFFLE_SERVICEID}.
     *
     * @param outputContext output context supplying the token bytes
     * @param conf configuration to copy for this service
     * @throws IOException if the token secret cannot be decoded
     */
    public void initialize(TezOutputContext outputContext, Configuration conf) throws IOException {
        this.init(new Configuration(conf));
        tokenSecret = ShuffleUtils.getJobTokenSecretFromTokenBytes(
                outputContext.getServiceConsumerMetaData(MAPREDUCE_SHUFFLE_SERVICEID));
    }

    /** Creates the NIO channel factory backed by named cached boss and worker thread pools. */
    @Override
    public synchronized void serviceInit(Configuration conf) {
        ThreadFactory bossFactory = new ThreadFactoryBuilder().setNameFormat("ShuffleHandler Netty Boss #%d")
                .build();
        ThreadFactory workerFactory = new ThreadFactoryBuilder().setNameFormat("ShuffleHandler Netty Worker #%d")
                .build();

        selector = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(bossFactory),
                Executors.newCachedThreadPool(workerFactory));
    }

    /**
     * Builds the HTTP pipeline, binds to an OS-chosen port, and records that port in the
     * configuration and in the request handler.
     *
     * @throws RuntimeException if the pipeline factory (including SSL setup) cannot be created
     */
    // TODO change AbstractService to throw InterruptedException
    @Override
    public synchronized void serviceStart() {
        Configuration conf = getConfig();
        // Read before binding: once bound, requests may be served immediately, and a zero
        // chunk size would be rejected when streaming output.
        sslFileBufferSize = conf.getInt(SUFFLE_SSL_FILE_BUFFER_SIZE_KEY, DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE);
        ServerBootstrap bootstrap = new ServerBootstrap(selector);
        try {
            pipelineFact = new HttpPipelineFactory(conf);
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
        bootstrap.setPipelineFactory(pipelineFact);
        // Let OS pick the port
        Channel ch = bootstrap.bind(new InetSocketAddress(0));
        accepted.add(ch);
        port = ((InetSocketAddress) ch.getLocalAddress()).getPort();
        conf.set(SHUFFLE_PORT_CONFIG_KEY, Integer.toString(port));
        pipelineFact.SHUFFLE.setPort(port);
        LOG.info(getName() + " listening on port " + port);
    }

    /**
     * Closes bound channels (waiting up to 10 seconds), releases the channel factory's
     * resources and tears down SSL state. Tolerates a service that was never fully
     * initialized or started.
     */
    @Override
    public synchronized void serviceStop() {
        accepted.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
        if (selector != null) {
            ServerBootstrap bootstrap = new ServerBootstrap(selector);
            bootstrap.releaseExternalResources();
        }
        if (pipelineFact != null) {
            pipelineFact.destroy();
        }
    }

    /**
     * Returns the listening port serialized with {@link #serializeMetaData(int)}, or
     * {@code null} if serialization fails.
     */
    @Override
    public synchronized ByteBuffer getMetaData() {
        try {
            return serializeMetaData(port);
        } catch (IOException e) {
            LOG.error("Error during getMeta", e);
            // TODO add API to AuxiliaryServices to report failures
            return null;
        }
    }

    /**
     * Creates the per-connection pipeline: an optional SSL handler, then HTTP decode,
     * aggregation, encode, chunked writes, and the shared {@link Shuffle} handler.
     */
    class HttpPipelineFactory implements ChannelPipelineFactory {

        final Shuffle SHUFFLE;
        private SSLFactory sslFactory;

        public HttpPipelineFactory(Configuration conf) throws Exception {
            SHUFFLE = new Shuffle(conf);
            if (conf.getBoolean(TezJobConfig.TEZ_RUNTIME_SHUFFLE_ENABLE_SSL,
                    TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_ENABLE_SSL)) {
                sslFactory = new SSLFactory(SSLFactory.Mode.SERVER, conf);
                sslFactory.init();
            }
        }

        /** Releases SSL resources, if SSL was enabled. */
        public void destroy() {
            if (sslFactory != null) {
                sslFactory.destroy();
            }
        }

        @Override
        public ChannelPipeline getPipeline() throws Exception {
            ChannelPipeline pipeline = Channels.pipeline();
            if (sslFactory != null) {
                pipeline.addLast("ssl", new SslHandler(sslFactory.createSSLEngine()));
            }
            pipeline.addLast("decoder", new HttpRequestDecoder());
            pipeline.addLast("aggregator", new HttpChunkAggregator(1 << 16));
            pipeline.addLast("encoder", new HttpResponseEncoder());
            pipeline.addLast("chunking", new ChunkedWriteHandler());
            pipeline.addLast("shuffle", SHUFFLE);
            return pipeline;
            // TODO factor security manager into pipeline
            // TODO factor out encode/decode to permit binary shuffle
            // TODO factor out decode of index to permit alt. models
        }

    }

    /** Validates, authenticates and serves shuffle GET requests. */
    class Shuffle extends SimpleChannelUpstreamHandler {

        private final Configuration conf;
        // Written by serviceStart() after bind and read on Netty worker threads.
        private volatile int port;

        public Shuffle(Configuration conf) {
            this.conf = conf;
            this.port = conf.getInt(SHUFFLE_PORT_CONFIG_KEY, DEFAULT_SHUFFLE_PORT);
        }

        public void setPort(int port) {
            this.port = port;
        }

        /**
         * Flattens repeated, comma-separated {@code map} parameters into one list.
         *
         * @return the split ids, or {@code null} if {@code mapq} is null
         */
        private List<String> splitMaps(List<String> mapq) {
            if (null == mapq) {
                return null;
            }
            final List<String> ret = new ArrayList<String>();
            for (String s : mapq) {
                Collections.addAll(ret, s.split(","));
            }
            return ret;
        }

        /**
         * Handles one shuffle request. Requires a GET with a compatible shuffle version,
         * exactly one {@code job} and one integer {@code reduce} parameter, at least one
         * {@code map} id, and a valid URL hash. On success it writes an OK response
         * followed by the sorter's output for the requested partition, then closes the
         * connection. On any failure it sends an error response and closes the connection.
         */
        @Override
        public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt) throws Exception {
            HttpRequest request = (HttpRequest) evt.getMessage();
            if (!GET.equals(request.getMethod())) {
                sendError(ctx, METHOD_NOT_ALLOWED);
                return;
            }
            // Check whether the shuffle version is compatible
            if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals(request.getHeader(ShuffleHeader.HTTP_HEADER_NAME))
                    || !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION
                            .equals(request.getHeader(ShuffleHeader.HTTP_HEADER_VERSION))) {
                sendError(ctx, "Incompatible shuffle request version", BAD_REQUEST);
                return;
            }
            // Must be checked before the URI is parsed below.
            final String reqUri = request.getUri();
            if (null == reqUri) {
                // TODO? add upstream?
                sendError(ctx, FORBIDDEN);
                return;
            }
            final Map<String, List<String>> q = new QueryStringDecoder(reqUri).getParameters();
            final List<String> mapIds = splitMaps(q.get("map"));
            final List<String> reduceQ = q.get("reduce");
            final List<String> jobQ = q.get("job");
            if (LOG.isDebugEnabled()) {
                LOG.debug("RECV: " + reqUri + "\n  mapId: " + mapIds + "\n  reduceId: " + reduceQ
                        + "\n  jobId: " + jobQ);
            }

            // An empty map list (e.g. "map=,") would leave nothing to send.
            if (mapIds == null || mapIds.isEmpty() || reduceQ == null || jobQ == null) {
                sendError(ctx, "Required param job, map and reduce", BAD_REQUEST);
                return;
            }
            if (reduceQ.size() != 1 || jobQ.size() != 1) {
                sendError(ctx, "Too many job/reduce parameters", BAD_REQUEST);
                return;
            }
            final String jobId = jobQ.get(0);
            final int reduceId;
            try {
                reduceId = Integer.parseInt(reduceQ.get(0));
            } catch (NumberFormatException e) {
                sendError(ctx, "Bad reduce parameter", BAD_REQUEST);
                return;
            }

            HttpResponse response = new DefaultHttpResponse(HTTP_1_1, OK);
            try {
                verifyRequest(jobId, ctx, request, response, new URL("http", "", this.port, reqUri));
            } catch (IOException e) {
                LOG.warn("Shuffle failure ", e);
                sendError(ctx, e.getMessage(), UNAUTHORIZED);
                return;
            }

            Channel ch = evt.getChannel();
            ch.write(response);
            // TODO refactor the following into the pipeline
            // NOTE(review): errors below are reported with a second HTTP response on a channel
            // that has already been sent the OK response above.
            ChannelFuture lastMap = null;
            for (String mapId : mapIds) {
                try {
                    // TODO: Error handling - validate mapId via TezTaskAttemptId.forName

                    // TODO NEWTEZ Fix this. TaskAttemptId is no longer valid. mapId validation will not work anymore.
                    //          if (!mapId.equals(sorter.getTaskAttemptId().toString())) {
                    //            String errorMessage =
                    //                "Illegal shuffle request mapId: " + mapId
                    //                    + " while actual mapId is " + sorter.getTaskAttemptId(); 
                    //            LOG.warn(errorMessage);
                    //            sendError(ctx, errorMessage, BAD_REQUEST);
                    //            return;
                    //          }

                    lastMap = sendMapOutput(ctx, ch, userRsrc.get(jobId), jobId, mapId, reduceId);
                    if (null == lastMap) {
                        sendError(ctx, NOT_FOUND);
                        return;
                    }
                } catch (IOException e) {
                    LOG.error("Shuffle error ", e);
                    sendError(ctx, e.getMessage(), INTERNAL_SERVER_ERROR);
                    return;
                }
            }
            // Non-null: mapIds is non-empty and a null result returns early above.
            // NOTE(review): sendMapOutput increments shuffleConnections once per map id, but the
            // metrics listener decrements it only once per request (and not at all on the error
            // paths above), so multi-map requests inflate the gauge.
            lastMap.addListener(metrics);
            lastMap.addListener(ChannelFutureListener.CLOSE);
        }

        /**
         * Authenticates the request against {@link #tokenSecret} using the URL hash header.
         * On success, adds the signed reply hash and the shuffle version headers to
         * {@code response}.
         *
         * @throws IOException if no secret is loaded, the hash header is missing, or the
         *         hash does not verify
         */
        private void verifyRequest(String appid, ChannelHandlerContext ctx, HttpRequest request,
                HttpResponse response, URL requestUri) throws IOException {
            if (null == tokenSecret) {
                LOG.info("Request for unknown token " + appid);
                throw new IOException("could not find jobid");
            }
            // string to encrypt
            String enc_str = SecureShuffleUtils.buildMsgFrom(requestUri);
            // hash from the fetcher
            String urlHashStr = request.getHeader(SecureShuffleUtils.HTTP_HEADER_URL_HASH);
            if (urlHashStr == null) {
                LOG.info("Missing header hash for " + appid);
                throw new IOException("fetcher cannot be authenticated");
            }
            if (LOG.isDebugEnabled()) {
                LOG.debug("verifying request. enc_str=" + enc_str + "; hash=..." + hashTail(urlHashStr));
            }
            // verify - throws exception
            SecureShuffleUtils.verifyReply(urlHashStr, enc_str, tokenSecret);
            // verification passed - encode the reply
            String reply = SecureShuffleUtils.generateHash(urlHashStr.getBytes(CharsetUtil.UTF_8), tokenSecret);
            response.setHeader(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH, reply);
            addVersionToHeader(response);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Fetcher request verfied. enc_str=" + enc_str + ";reply=" + hashTail(reply));
            }
        }

        /**
         * Returns part of the tail of a hash for debug logging. For inputs of length 2 or
         * more this is {@code substring(len - len / 2, len - 1)}; shorter inputs yield ""
         * instead of throwing StringIndexOutOfBoundsException.
         */
        private String hashTail(String hash) {
            int len = hash.length();
            if (len < 2) {
                return "";
            }
            return hash.substring(len - len / 2, len - 1);
        }

        /**
         * Writes the serialized shuffle header for {@code reduce}, then streams the sorter's
         * sorted output for that partition in chunks of {@code sslFileBufferSize} bytes.
         * Note that {@code mapId}, {@code user} and {@code jobId} are not used to select data.
         *
         * @return future of the stream write
         */
        protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, Channel ch, String user, String jobId,
                String mapId, int reduce) throws IOException {
            final ShuffleHeader header = sorter.getShuffleHeader(reduce);
            final DataOutputBuffer dob = new DataOutputBuffer();
            header.write(dob);
            ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));

            ChannelFuture writeFuture = ch
                    .write(new ChunkedStream(sorter.getSortedStream(reduce), sslFileBufferSize));
            metrics.shuffleConnections.incr();
            metrics.shuffleOutputBytes.incr(header.getCompressedLength()); // optimistic
            return writeFuture;
        }

        private void sendError(ChannelHandlerContext ctx, HttpResponseStatus status) {
            sendError(ctx, "", status);
        }

        /** Sends a plain-text error response with version headers, then closes the channel. */
        private void sendError(ChannelHandlerContext ctx, String message, HttpResponseStatus status) {
            HttpResponse response = new DefaultHttpResponse(HTTP_1_1, status);
            response.setHeader(CONTENT_TYPE, "text/plain; charset=UTF-8");
            addVersionToHeader(response);
            response.setContent(ChannelBuffers.copiedBuffer(message, CharsetUtil.UTF_8));
            // Close the connection as soon as the error message is sent.
            ctx.getChannel().write(response).addListener(ChannelFutureListener.CLOSE);
        }

        private void addVersionToHeader(HttpResponse response) {
            // Put shuffle version into http header
            response.setHeader(ShuffleHeader.HTTP_HEADER_NAME, ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
            response.setHeader(ShuffleHeader.HTTP_HEADER_VERSION, ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
        }

        /**
         * Maps over-long frames to BAD_REQUEST; logs any other error and, if the channel
         * is still connected, replies with INTERNAL_SERVER_ERROR.
         */
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception {
            Channel ch = e.getChannel();
            Throwable cause = e.getCause();
            if (cause instanceof TooLongFrameException) {
                sendError(ctx, BAD_REQUEST);
                return;
            }

            LOG.error("Shuffle error: ", cause);
            if (ch.isConnected()) {
                sendError(ctx, INTERNAL_SERVER_ERROR);
            }
        }

    }

    /** Returns the port bound in serviceStart(), or 0 before the service has started. */
    public int getPort() {
        return port;
    }
}