// Java tutorial
/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.tez.runtime.library.common.shuffle.server; import static org.jboss.netty.buffer.ChannelBuffers.wrappedBuffer; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.CONTENT_TYPE; import static org.jboss.netty.handler.codec.http.HttpMethod.GET; import static org.jboss.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST; import static org.jboss.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN; import static org.jboss.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR; import static org.jboss.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED; import static org.jboss.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND; import static org.jboss.netty.handler.codec.http.HttpResponseStatus.OK; import static org.jboss.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED; import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1; import java.io.IOException; import java.net.InetSocketAddress; import java.net.URL; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import 
java.util.concurrent.Executors; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import javax.crypto.SecretKey; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.DataInputByteBuffer; import org.apache.hadoop.io.DataOutputBuffer; import org.apache.hadoop.metrics2.MetricsSystem; import org.apache.hadoop.metrics2.annotation.Metric; import org.apache.hadoop.metrics2.annotation.Metrics; import org.apache.hadoop.metrics2.lib.DefaultMetricsSystem; import org.apache.hadoop.metrics2.lib.MutableCounterInt; import org.apache.hadoop.metrics2.lib.MutableCounterLong; import org.apache.hadoop.metrics2.lib.MutableGaugeInt; import org.apache.hadoop.security.ssl.SSLFactory; import org.apache.hadoop.security.token.Token; import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext; import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext; import org.apache.hadoop.yarn.server.api.AuxiliaryService; import org.apache.tez.common.TezJobConfig; import org.apache.tez.common.security.JobTokenIdentifier; import org.apache.tez.runtime.api.TezOutputContext; import org.apache.tez.runtime.library.common.security.JobTokenSecretManager; import org.apache.tez.runtime.library.common.security.SecureShuffleUtils; import org.apache.tez.runtime.library.common.shuffle.impl.ShuffleHeader; import org.apache.tez.runtime.library.common.sort.impl.ExternalSorter; import org.apache.tez.runtime.library.shuffle.common.ShuffleUtils; import org.jboss.netty.bootstrap.ServerBootstrap; import org.jboss.netty.buffer.ChannelBuffers; import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.ChannelFactory; import org.jboss.netty.channel.ChannelFuture; import org.jboss.netty.channel.ChannelFutureListener; import org.jboss.netty.channel.ChannelHandlerContext; import 
org.jboss.netty.channel.ChannelPipeline; import org.jboss.netty.channel.ChannelPipelineFactory; import org.jboss.netty.channel.Channels; import org.jboss.netty.channel.ExceptionEvent; import org.jboss.netty.channel.MessageEvent; import org.jboss.netty.channel.SimpleChannelUpstreamHandler; import org.jboss.netty.channel.group.ChannelGroup; import org.jboss.netty.channel.group.DefaultChannelGroup; import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory; import org.jboss.netty.handler.codec.frame.TooLongFrameException; import org.jboss.netty.handler.codec.http.DefaultHttpResponse; import org.jboss.netty.handler.codec.http.HttpChunkAggregator; import org.jboss.netty.handler.codec.http.HttpRequest; import org.jboss.netty.handler.codec.http.HttpRequestDecoder; import org.jboss.netty.handler.codec.http.HttpResponse; import org.jboss.netty.handler.codec.http.HttpResponseEncoder; import org.jboss.netty.handler.codec.http.HttpResponseStatus; import org.jboss.netty.handler.codec.http.QueryStringDecoder; import org.jboss.netty.handler.ssl.SslHandler; import org.jboss.netty.handler.stream.ChunkedStream; import org.jboss.netty.handler.stream.ChunkedWriteHandler; import org.jboss.netty.util.CharsetUtil; import com.google.common.util.concurrent.ThreadFactoryBuilder; public class ShuffleHandler extends AuxiliaryService { private static final Log LOG = LogFactory.getLog(ShuffleHandler.class); public static final String SHUFFLE_MANAGE_OS_CACHE = "mapreduce.shuffle.manage.os.cache"; public static final boolean DEFAULT_SHUFFLE_MANAGE_OS_CACHE = true; public static final String SHUFFLE_READAHEAD_BYTES = "mapreduce.shuffle.readahead.bytes"; public static final int DEFAULT_SHUFFLE_READAHEAD_BYTES = 4 * 1024 * 1024; private int port; private ChannelFactory selector; private final ChannelGroup accepted = new DefaultChannelGroup(); private HttpPipelineFactory pipelineFact; private int sslFileBufferSize; public static final String MAPREDUCE_SHUFFLE_SERVICEID = 
"mapreduce_shuffle"; private static final Map<String, String> userRsrc = new ConcurrentHashMap<String, String>(); private static final JobTokenSecretManager secretManager = new JobTokenSecretManager(); private SecretKey tokenSecret; public static final String SHUFFLE_PORT_CONFIG_KEY = "mapreduce.shuffle.port"; public static final int DEFAULT_SHUFFLE_PORT = 8080; public static final String SUFFLE_SSL_FILE_BUFFER_SIZE_KEY = "mapreduce.shuffle.ssl.file.buffer.size"; public static final int DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE = 60 * 1024; private ExternalSorter sorter; @Metrics(about = "Shuffle output metrics", context = "mapred") static class ShuffleMetrics implements ChannelFutureListener { @Metric("Shuffle output in bytes") MutableCounterLong shuffleOutputBytes; @Metric("# of failed shuffle outputs") MutableCounterInt shuffleOutputsFailed; @Metric("# of succeeeded shuffle outputs") MutableCounterInt shuffleOutputsOK; @Metric("# of current shuffle connections") MutableGaugeInt shuffleConnections; @Override public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { shuffleOutputsOK.incr(); } else { shuffleOutputsFailed.incr(); } shuffleConnections.decr(); } } final ShuffleMetrics metrics; ShuffleHandler(MetricsSystem ms) { super("httpshuffle"); metrics = ms.register(new ShuffleMetrics()); } public ShuffleHandler(ExternalSorter sorter) { this(DefaultMetricsSystem.instance()); this.sorter = sorter; } /** * Serialize the shuffle port into a ByteBuffer for use later on. * @param port the port to be sent to the ApplciationMaster * @return the serialized form of the port. */ public static ByteBuffer serializeMetaData(int port) throws IOException { //TODO these bytes should be versioned DataOutputBuffer port_dob = new DataOutputBuffer(); port_dob.writeInt(port); return ByteBuffer.wrap(port_dob.getData(), 0, port_dob.getLength()); } /** * A helper function to deserialize the metadata returned by ShuffleHandler. 
* @param meta the metadata returned by the ShuffleHandler * @return the port the Shuffle Handler is listening on to serve shuffle data. */ public static int deserializeMetaData(ByteBuffer meta) throws IOException { //TODO this should be returning a class not just an int DataInputByteBuffer in = new DataInputByteBuffer(); in.reset(meta); int port = in.readInt(); return port; } /** * A helper function to serialize the JobTokenIdentifier to be sent to the * ShuffleHandler as ServiceData. * @param jobToken the job token to be used for authentication of * shuffle data requests. * @return the serialized version of the jobToken. */ public static ByteBuffer serializeServiceData(Token<JobTokenIdentifier> jobToken) throws IOException { //TODO these bytes should be versioned DataOutputBuffer jobToken_dob = new DataOutputBuffer(); jobToken.write(jobToken_dob); return ByteBuffer.wrap(jobToken_dob.getData(), 0, jobToken_dob.getLength()); } static Token<JobTokenIdentifier> deserializeServiceData(ByteBuffer secret) throws IOException { DataInputByteBuffer in = new DataInputByteBuffer(); in.reset(secret); Token<JobTokenIdentifier> jt = new Token<JobTokenIdentifier>(); jt.readFields(in); return jt; } @Override public void initializeApplication(ApplicationInitializationContext initAppContext) { // TODO these bytes should be versioned try { String user = initAppContext.getUser(); ApplicationId appId = initAppContext.getApplicationId(); ByteBuffer secret = initAppContext.getApplicationDataForService(); Token<JobTokenIdentifier> jt = deserializeServiceData(secret); // TODO: Once SHuffle is out of NM, this can use MR APIs userRsrc.put(appId.toString(), user); LOG.info("Added token for " + appId.toString()); secretManager.addTokenForJob(appId.toString(), jt); } catch (IOException e) { LOG.error("Error during initApp", e); // TODO add API to AuxiliaryServices to report failures } } @Override public void stopApplication(ApplicationTerminationContext context) { ApplicationId appId = 
context.getApplicationId(); secretManager.removeTokenForJob(appId.toString()); userRsrc.remove(appId.toString()); } public void initialize(TezOutputContext outputContext, Configuration conf) throws IOException { this.init(new Configuration(conf)); tokenSecret = ShuffleUtils.getJobTokenSecretFromTokenBytes( outputContext.getServiceConsumerMetaData(MAPREDUCE_SHUFFLE_SERVICEID)); } @Override public synchronized void serviceInit(Configuration conf) { ThreadFactory bossFactory = new ThreadFactoryBuilder().setNameFormat("ShuffleHandler Netty Boss #%d") .build(); ThreadFactory workerFactory = new ThreadFactoryBuilder().setNameFormat("ShuffleHandler Netty Worker #%d") .build(); selector = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(bossFactory), Executors.newCachedThreadPool(workerFactory)); } // TODO change AbstractService to throw InterruptedException @Override public synchronized void serviceStart() { Configuration conf = getConfig(); ServerBootstrap bootstrap = new ServerBootstrap(selector); try { pipelineFact = new HttpPipelineFactory(conf); } catch (Exception ex) { throw new RuntimeException(ex); } bootstrap.setPipelineFactory(pipelineFact); // Let OS pick the port Channel ch = bootstrap.bind(new InetSocketAddress(0)); accepted.add(ch); port = ((InetSocketAddress) ch.getLocalAddress()).getPort(); conf.set(SHUFFLE_PORT_CONFIG_KEY, Integer.toString(port)); pipelineFact.SHUFFLE.setPort(port); LOG.info(getName() + " listening on port " + port); sslFileBufferSize = conf.getInt(SUFFLE_SSL_FILE_BUFFER_SIZE_KEY, DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE); } @Override public synchronized void serviceStop() { accepted.close().awaitUninterruptibly(10, TimeUnit.SECONDS); ServerBootstrap bootstrap = new ServerBootstrap(selector); bootstrap.releaseExternalResources(); pipelineFact.destroy(); } @Override public synchronized ByteBuffer getMetaData() { try { return serializeMetaData(port); } catch (IOException e) { LOG.error("Error during getMeta", e); // TODO add API 
to AuxiliaryServices to report failures return null; } } class HttpPipelineFactory implements ChannelPipelineFactory { final Shuffle SHUFFLE; private SSLFactory sslFactory; public HttpPipelineFactory(Configuration conf) throws Exception { SHUFFLE = new Shuffle(conf); if (conf.getBoolean(TezJobConfig.TEZ_RUNTIME_SHUFFLE_ENABLE_SSL, TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_ENABLE_SSL)) { sslFactory = new SSLFactory(SSLFactory.Mode.SERVER, conf); sslFactory.init(); } } public void destroy() { if (sslFactory != null) { sslFactory.destroy(); } } @Override public ChannelPipeline getPipeline() throws Exception { ChannelPipeline pipeline = Channels.pipeline(); if (sslFactory != null) { pipeline.addLast("ssl", new SslHandler(sslFactory.createSSLEngine())); } pipeline.addLast("decoder", new HttpRequestDecoder()); pipeline.addLast("aggregator", new HttpChunkAggregator(1 << 16)); pipeline.addLast("encoder", new HttpResponseEncoder()); pipeline.addLast("chunking", new ChunkedWriteHandler()); pipeline.addLast("shuffle", SHUFFLE); return pipeline; // TODO factor security manager into pipeline // TODO factor out encode/decode to permit binary shuffle // TODO factor out decode of index to permit alt. 
models } } class Shuffle extends SimpleChannelUpstreamHandler { private final Configuration conf; private int port; public Shuffle(Configuration conf) { this.conf = conf; this.port = conf.getInt(SHUFFLE_PORT_CONFIG_KEY, DEFAULT_SHUFFLE_PORT); } public void setPort(int port) { this.port = port; } private List<String> splitMaps(List<String> mapq) { if (null == mapq) { return null; } final List<String> ret = new ArrayList<String>(); for (String s : mapq) { Collections.addAll(ret, s.split(",")); } return ret; } @Override public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt) throws Exception { HttpRequest request = (HttpRequest) evt.getMessage(); if (request.getMethod() != GET) { sendError(ctx, METHOD_NOT_ALLOWED); return; } // Check whether the shuffle version is compatible if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals(request.getHeader(ShuffleHeader.HTTP_HEADER_NAME)) || !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION .equals(request.getHeader(ShuffleHeader.HTTP_HEADER_VERSION))) { sendError(ctx, "Incompatible shuffle request version", BAD_REQUEST); } final Map<String, List<String>> q = new QueryStringDecoder(request.getUri()).getParameters(); final List<String> mapIds = splitMaps(q.get("map")); final List<String> reduceQ = q.get("reduce"); final List<String> jobQ = q.get("job"); if (LOG.isDebugEnabled()) { LOG.debug("RECV: " + request.getUri() + "\n mapId: " + mapIds + "\n reduceId: " + reduceQ + "\n jobId: " + jobQ); } if (mapIds == null || reduceQ == null || jobQ == null) { sendError(ctx, "Required param job, map and reduce", BAD_REQUEST); return; } if (reduceQ.size() != 1 || jobQ.size() != 1) { sendError(ctx, "Too many job/reduce parameters", BAD_REQUEST); return; } int reduceId; String jobId; try { reduceId = Integer.parseInt(reduceQ.get(0)); jobId = jobQ.get(0); } catch (NumberFormatException e) { sendError(ctx, "Bad reduce parameter", BAD_REQUEST); return; } catch (IllegalArgumentException e) { sendError(ctx, "Bad job parameter", 
BAD_REQUEST); return; } final String reqUri = request.getUri(); if (null == reqUri) { // TODO? add upstream? sendError(ctx, FORBIDDEN); return; } HttpResponse response = new DefaultHttpResponse(HTTP_1_1, OK); try { verifyRequest(jobId, ctx, request, response, new URL("http", "", this.port, reqUri)); } catch (IOException e) { LOG.warn("Shuffle failure ", e); sendError(ctx, e.getMessage(), UNAUTHORIZED); return; } Channel ch = evt.getChannel(); ch.write(response); // TODO refactor the following into the pipeline ChannelFuture lastMap = null; for (String mapId : mapIds) { try { // TODO: Error handling - validate mapId via TezTaskAttemptId.forName // TODO NEWTEZ Fix this. TaskAttemptId is no longer valid. mapId validation will not work anymore. // if (!mapId.equals(sorter.getTaskAttemptId().toString())) { // String errorMessage = // "Illegal shuffle request mapId: " + mapId // + " while actual mapId is " + sorter.getTaskAttemptId(); // LOG.warn(errorMessage); // sendError(ctx, errorMessage, BAD_REQUEST); // return; // } lastMap = sendMapOutput(ctx, ch, userRsrc.get(jobId), jobId, mapId, reduceId); if (null == lastMap) { sendError(ctx, NOT_FOUND); return; } } catch (IOException e) { LOG.error("Shuffle error ", e); sendError(ctx, e.getMessage(), INTERNAL_SERVER_ERROR); return; } } lastMap.addListener(metrics); lastMap.addListener(ChannelFutureListener.CLOSE); } private void verifyRequest(String appid, ChannelHandlerContext ctx, HttpRequest request, HttpResponse response, URL requestUri) throws IOException { if (null == tokenSecret) { LOG.info("Request for unknown token " + appid); throw new IOException("could not find jobid"); } // string to encrypt String enc_str = SecureShuffleUtils.buildMsgFrom(requestUri); // hash from the fetcher String urlHashStr = request.getHeader(SecureShuffleUtils.HTTP_HEADER_URL_HASH); if (urlHashStr == null) { LOG.info("Missing header hash for " + appid); throw new IOException("fetcher cannot be authenticated"); } if (LOG.isDebugEnabled()) { 
int len = urlHashStr.length(); LOG.debug("verifying request. enc_str=" + enc_str + "; hash=..." + urlHashStr.substring(len - len / 2, len - 1)); } // verify - throws exception SecureShuffleUtils.verifyReply(urlHashStr, enc_str, tokenSecret); // verification passed - encode the reply String reply = SecureShuffleUtils.generateHash(urlHashStr.getBytes(), tokenSecret); response.setHeader(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH, reply); addVersionToHeader(response); if (LOG.isDebugEnabled()) { int len = reply.length(); LOG.debug("Fetcher request verfied. enc_str=" + enc_str + ";reply=" + reply.substring(len - len / 2, len - 1)); } } protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, Channel ch, String user, String jobId, String mapId, int reduce) throws IOException { final ShuffleHeader header = sorter.getShuffleHeader(reduce); final DataOutputBuffer dob = new DataOutputBuffer(); header.write(dob); ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength())); ChannelFuture writeFuture = ch .write(new ChunkedStream(sorter.getSortedStream(reduce), sslFileBufferSize)); metrics.shuffleConnections.incr(); metrics.shuffleOutputBytes.incr(header.getCompressedLength()); // optimistic return writeFuture; } private void sendError(ChannelHandlerContext ctx, HttpResponseStatus status) { sendError(ctx, "", status); } private void sendError(ChannelHandlerContext ctx, String message, HttpResponseStatus status) { HttpResponse response = new DefaultHttpResponse(HTTP_1_1, status); response.setHeader(CONTENT_TYPE, "text/plain; charset=UTF-8"); addVersionToHeader(response); response.setContent(ChannelBuffers.copiedBuffer(message, CharsetUtil.UTF_8)); // Close the connection as soon as the error message is sent. 
ctx.getChannel().write(response).addListener(ChannelFutureListener.CLOSE); } private void addVersionToHeader(HttpResponse response) { // Put shuffle version into http header response.setHeader(ShuffleHeader.HTTP_HEADER_NAME, ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); response.setHeader(ShuffleHeader.HTTP_HEADER_VERSION, ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); } @Override public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception { Channel ch = e.getChannel(); Throwable cause = e.getCause(); if (cause instanceof TooLongFrameException) { sendError(ctx, BAD_REQUEST); return; } LOG.error("Shuffle error: ", cause); if (ch.isConnected()) { LOG.error("Shuffle error " + e); sendError(ctx, INTERNAL_SERVER_ERROR); } } } public int getPort() { return port; } }