io.prestosql.proxy.ProxyResource.java Source code

Java tutorial

Introduction

Here is the source code for io.prestosql.proxy.ProxyResource.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.prestosql.proxy;

import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;
import com.google.common.hash.HashCode;
import com.google.common.hash.HashFunction;
import com.google.common.util.concurrent.FluentFuture;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.Request;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.prestosql.proxy.ProxyResponseHandler.ProxyResponse;

import javax.annotation.PreDestroy;
import javax.inject.Inject;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.DELETE;
import javax.ws.rs.GET;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.container.AsyncResponse;
import javax.ws.rs.container.Suspended;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.ResponseBuilder;
import javax.ws.rs.core.Response.Status;
import javax.ws.rs.core.UriInfo;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.security.cert.X509Certificate;
import java.util.Base64;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;

import static com.fasterxml.jackson.core.JsonFactory.Feature.CANONICALIZE_FIELD_NAMES;
import static com.fasterxml.jackson.core.JsonToken.END_OBJECT;
import static com.fasterxml.jackson.core.JsonToken.FIELD_NAME;
import static com.fasterxml.jackson.core.JsonToken.START_OBJECT;
import static com.fasterxml.jackson.core.JsonToken.VALUE_STRING;
import static com.google.common.hash.Hashing.hmacSha256;
import static com.google.common.net.HttpHeaders.AUTHORIZATION;
import static com.google.common.net.HttpHeaders.COOKIE;
import static com.google.common.net.HttpHeaders.SET_COOKIE;
import static com.google.common.net.HttpHeaders.USER_AGENT;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
import static io.airlift.http.client.Request.Builder.prepareDelete;
import static io.airlift.http.client.Request.Builder.prepareGet;
import static io.airlift.http.client.Request.Builder.preparePost;
import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator;
import static io.airlift.http.server.AsyncResponseHandler.bindAsyncResponse;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.nio.file.Files.readAllBytes;
import static java.util.Collections.list;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.TimeUnit.MINUTES;
import static javax.ws.rs.core.MediaType.APPLICATION_JSON;
import static javax.ws.rs.core.MediaType.TEXT_PLAIN_TYPE;
import static javax.ws.rs.core.Response.Status.BAD_GATEWAY;
import static javax.ws.rs.core.Response.Status.FORBIDDEN;
import static javax.ws.rs.core.Response.noContent;

/**
 * JAX-RS resource that relays requests to the remote server configured via
 * {@link ProxyConfig#getUri()}.
 *
 * <p>Responses from {@code /v1/statement} and {@code /v1/proxy} are rewritten so that
 * the {@code nextUri} and {@code partialCancelUri} fields point back at this proxy's
 * {@code /v1/proxy} endpoint. Each rewritten link carries an HMAC-SHA256 of the
 * original URI, computed with the shared secret, and the proxy refuses to follow
 * any URI whose HMAC does not validate.
 */
@Path("/")
public class ProxyResource {
    private static final Logger log = Logger.get(ProxyResource.class);

    private static final String X509_ATTRIBUTE = "javax.servlet.request.X509Certificate";
    private static final Duration ASYNC_TIMEOUT = new Duration(2, MINUTES);
    private static final JsonFactory JSON_FACTORY = new JsonFactory().disable(CANONICALIZE_FIELD_NAMES);

    private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed("proxy-%s"));
    private final HttpClient httpClient;
    private final JsonWebTokenHandler jwtHandler;
    private final URI remoteUri;
    private final HashFunction hmac;

    /**
     * Creates the resource.
     *
     * @param httpClient client used for all outbound requests
     * @param jwtHandler source of bearer tokens for outbound requests
     * @param config supplies the remote URI and the shared secret file used to key the HMAC
     */
    @Inject
    public ProxyResource(@ForProxy HttpClient httpClient, JsonWebTokenHandler jwtHandler, ProxyConfig config) {
        this.httpClient = requireNonNull(httpClient, "httpClient is null");
        this.jwtHandler = requireNonNull(jwtHandler, "jwtHandler is null");
        this.remoteUri = requireNonNull(config.getUri(), "uri is null");
        this.hmac = hmacSha256(loadSharedSecret(config.getSharedSecretFile()));
    }

    /**
     * Stops the executor used for response transformation and async binding.
     */
    @PreDestroy
    public void shutdown() {
        executor.shutdownNow();
    }

    /**
     * Relays {@code GET /v1/info} to the remote server and returns its body unchanged.
     */
    @GET
    @Path("/v1/info")
    @Produces(APPLICATION_JSON)
    public void getInfo(@Context HttpServletRequest servletRequest, @Suspended AsyncResponse asyncResponse) {
        Request.Builder request = prepareGet().setUri(uriBuilderFrom(remoteUri).replacePath("/v1/info").build());

        performRequest(servletRequest, asyncResponse, request,
                response -> responseWithHeaders(Response.ok(response.getBody()), response));
    }

    /**
     * Relays {@code POST /v1/statement} to the remote server and rewrites the URIs
     * in the JSON response so follow-up requests come back through this proxy.
     *
     * @param statement request body, forwarded as UTF-8
     */
    @POST
    @Path("/v1/statement")
    @Produces(APPLICATION_JSON)
    public void postStatement(String statement, @Context HttpServletRequest servletRequest,
            @Context UriInfo uriInfo, @Suspended AsyncResponse asyncResponse) {
        Request.Builder request = preparePost()
                .setUri(uriBuilderFrom(remoteUri).replacePath("/v1/statement").build())
                .setBodyGenerator(createStaticBodyGenerator(statement, UTF_8));

        performRequest(servletRequest, asyncResponse, request, response -> buildResponse(uriInfo, response));
    }

    /**
     * Issues a GET to a URI previously handed out by this proxy and rewrites the
     * URIs in the JSON response.
     *
     * @param uri target URI, as embedded in a rewritten link
     * @param hash hex-encoded HMAC of {@code uri}
     * @throws WebApplicationException with status 403 if the HMAC is missing,
     *         malformed, or does not match
     */
    @GET
    @Path("/v1/proxy")
    @Produces(APPLICATION_JSON)
    public void getNext(@QueryParam("uri") String uri, @QueryParam("hmac") String hash,
            @Context HttpServletRequest servletRequest, @Context UriInfo uriInfo,
            @Suspended AsyncResponse asyncResponse) {
        validateHmac(uri, hash);

        Request.Builder request = prepareGet().setUri(URI.create(uri));

        performRequest(servletRequest, asyncResponse, request, response -> buildResponse(uriInfo, response));
    }

    /**
     * Issues a DELETE to a URI previously handed out by this proxy and answers
     * with an empty (204) response carrying the forwarded headers.
     *
     * @param uri target URI, as embedded in a rewritten link
     * @param hash hex-encoded HMAC of {@code uri}
     * @throws WebApplicationException with status 403 if the HMAC is missing,
     *         malformed, or does not match
     */
    @DELETE
    @Path("/v1/proxy")
    @Produces(APPLICATION_JSON)
    public void cancelQuery(@QueryParam("uri") String uri, @QueryParam("hmac") String hash,
            @Context HttpServletRequest servletRequest, @Suspended AsyncResponse asyncResponse) {
        validateHmac(uri, hash);

        Request.Builder request = prepareDelete().setUri(URI.create(uri));

        performRequest(servletRequest, asyncResponse, request,
                response -> responseWithHeaders(noContent(), response));
    }

    /**
     * Verifies that {@code hash} is the HMAC of {@code uri} under the shared secret.
     *
     * <p>Both values arrive as untrusted query parameters, so absent parameters and
     * hashes that are not valid hex are rejected with the same 403 as a mismatch
     * instead of escaping as an unchecked exception.
     */
    private void validateHmac(String uri, String hash) {
        if ((uri == null) || (hash == null)) {
            throw badRequest(FORBIDDEN, "Failed to validate HMAC of URI");
        }

        HashCode provided;
        try {
            provided = HashCode.fromString(hash);
        } catch (IllegalArgumentException e) {
            // fromString rejects strings that are not well-formed hex
            throw badRequest(FORBIDDEN, "Failed to validate HMAC of URI");
        }

        if (!hmac.hashString(uri, UTF_8).equals(provided)) {
            throw badRequest(FORBIDDEN, "Failed to validate HMAC of URI");
        }
    }

    /**
     * Completes the outbound request from the inbound one, executes it
     * asynchronously, and binds the eventual result to {@code asyncResponse}.
     *
     * <p>Header handling: {@code x-presto-*} and {@code Cookie} headers are copied
     * verbatim; {@code User-Agent} values are prefixed with {@code "[Presto Proxy] "};
     * all other inbound headers are not copied.
     *
     * @param responseBuilder converts the remote response into the JAX-RS response
     */
    private void performRequest(HttpServletRequest servletRequest, AsyncResponse asyncResponse,
            Request.Builder requestBuilder, Function<ProxyResponse, Response> responseBuilder) {
        setupBearerToken(servletRequest, requestBuilder);

        for (String name : list(servletRequest.getHeaderNames())) {
            if (isPrestoHeader(name) || name.equalsIgnoreCase(COOKIE)) {
                for (String value : list(servletRequest.getHeaders(name))) {
                    requestBuilder.addHeader(name, value);
                }
            } else if (name.equalsIgnoreCase(USER_AGENT)) {
                for (String value : list(servletRequest.getHeaders(name))) {
                    requestBuilder.addHeader(name, "[Presto Proxy] " + value);
                }
            }
        }

        Request request = requestBuilder.setPreserveAuthorizationOnRedirect(true).build();

        ListenableFuture<Response> future = executeHttp(request).transform(responseBuilder::apply, executor)
                .catching(ProxyException.class, e -> handleProxyException(request, e), directExecutor());

        setupAsyncResponse(asyncResponse, future);
    }

    /**
     * Builds a 200 response whose JSON body has its follow-up URIs rewritten to
     * point at this proxy.
     */
    private Response buildResponse(UriInfo uriInfo, ProxyResponse response) {
        byte[] body = rewriteResponse(response.getBody(), uri -> rewriteUri(uriInfo, uri));
        return responseWithHeaders(Response.ok(body), response);
    }

    /**
     * Wraps {@code uri} in a {@code /v1/proxy} link on this server, attaching the
     * HMAC that {@link #validateHmac} later checks.
     */
    private String rewriteUri(UriInfo uriInfo, String uri) {
        return uriInfo.getAbsolutePathBuilder().replacePath("/v1/proxy").queryParam("uri", uri)
                .queryParam("hmac", hmac.hashString(uri, UTF_8)).build().toString();
    }

    /**
     * Binds {@code future} to {@code asyncResponse}, answering with 502 and a
     * plain-text message if no result arrives within {@link #ASYNC_TIMEOUT}.
     */
    private void setupAsyncResponse(AsyncResponse asyncResponse, ListenableFuture<Response> future) {
        bindAsyncResponse(asyncResponse, future, executor).withTimeout(ASYNC_TIMEOUT,
                () -> Response.status(BAD_GATEWAY).type(TEXT_PLAIN_TYPE)
                        .entity("Request to remote Presto server timed out after " + ASYNC_TIMEOUT).build());
    }

    /**
     * Executes {@code request} asynchronously using a fresh {@link ProxyResponseHandler}.
     */
    private FluentFuture<ProxyResponse> executeHttp(Request request) {
        return FluentFuture.from(httpClient.executeAsync(request, new ProxyResponseHandler()));
    }

    /**
     * Adds an {@code Authorization: Bearer} header when the JWT handler is configured.
     *
     * <p>The token is requested for the subject name of the first certificate found
     * in the servlet request's X.509 attribute.
     *
     * @throws WebApplicationException with status 403 if the handler is configured
     *         but the request carries no certificate
     */
    private void setupBearerToken(HttpServletRequest servletRequest, Request.Builder requestBuilder) {
        if (!jwtHandler.isConfigured()) {
            return;
        }

        X509Certificate[] certs = (X509Certificate[]) servletRequest.getAttribute(X509_ATTRIBUTE);
        if ((certs == null) || (certs.length == 0)) {
            throw badRequest(FORBIDDEN, "No TLS certificate present for request");
        }
        String principal = certs[0].getSubjectX500Principal().getName();

        String accessToken = jwtHandler.getBearerToken(principal);
        requestBuilder.addHeader(AUTHORIZATION, "Bearer " + accessToken);
    }

    /**
     * Logs a failed proxy request and converts it into a 502 response.
     * Always throws; the generic return type only lets it be used as a
     * {@code catching} fallback.
     */
    private static <T> T handleProxyException(Request request, ProxyException e) {
        log.warn(e, "Proxy request failed: %s %s", request.getMethod(), request.getUri());
        throw badRequest(BAD_GATEWAY, e.getMessage());
    }

    /**
     * Creates (but does not throw) a {@link WebApplicationException} carrying a
     * plain-text response with the given status and message; callers throw it.
     */
    private static WebApplicationException badRequest(Status status, String message) {
        return new WebApplicationException(Response.status(status).type(TEXT_PLAIN_TYPE).entity(message).build());
    }

    /**
     * Returns whether {@code name} starts with {@code x-presto-}, ignoring case.
     */
    private static boolean isPrestoHeader(String name) {
        return name.toLowerCase(ENGLISH).startsWith("x-presto-");
    }

    /**
     * Copies the {@code x-presto-*} and {@code Set-Cookie} headers of the remote
     * response onto {@code builder} and builds the response.
     */
    private static Response responseWithHeaders(ResponseBuilder builder, ProxyResponse response) {
        response.getHeaders().asMap().forEach((headerName, value) -> {
            String name = headerName.toString();
            if (isPrestoHeader(name) || name.equalsIgnoreCase(SET_COOKIE)) {
                builder.header(name, value);
            }
        });
        return builder.build();
    }

    /**
     * Streams a JSON object from {@code input} to a new byte array, replacing the
     * string values of the top-level {@code nextUri} and {@code partialCancelUri}
     * fields with {@code uriRewriter.apply(value)}. All other fields are copied as-is.
     *
     * @throws ProxyException if the input is not a single JSON object, if one of
     *         the rewritten fields is not a string, or on any other I/O failure
     */
    private static byte[] rewriteResponse(byte[] input, Function<String, String> uriRewriter) {
        ByteArrayOutputStream out = new ByteArrayOutputStream(input.length * 2);

        // try-with-resources closes the parser and generator on every path;
        // closing the generator also flushes it before the bytes are read below
        try (JsonParser parser = JSON_FACTORY.createParser(input);
                JsonGenerator generator = JSON_FACTORY.createGenerator(out)) {
            JsonToken token = parser.nextToken();
            if (token != START_OBJECT) {
                throw invalidJson("bad start token: " + token);
            }
            generator.copyCurrentEvent(parser);

            while (true) {
                token = parser.nextToken();
                if (token == null) {
                    throw invalidJson("unexpected end of stream");
                }

                if (token == END_OBJECT) {
                    generator.copyCurrentEvent(parser);
                    break;
                }

                if (token == FIELD_NAME) {
                    String name = parser.getValueAsString();
                    if (!"nextUri".equals(name) && !"partialCancelUri".equals(name)) {
                        // copies the field name together with its entire value
                        generator.copyCurrentStructure(parser);
                        continue;
                    }

                    token = parser.nextToken();
                    if (token != VALUE_STRING) {
                        throw invalidJson(format("bad %s token: %s", name, token));
                    }
                    String value = parser.getValueAsString();

                    value = uriRewriter.apply(value);
                    generator.writeStringField(name, value);
                    continue;
                }

                throw invalidJson("unexpected token: " + token);
            }

            token = parser.nextToken();
            if (token != null) {
                throw invalidJson("unexpected token after object close: " + token);
            }
        } catch (IOException e) {
            throw new ProxyException(e);
        }

        return out.toByteArray();
    }

    /**
     * Creates the exception used to report a malformed remote response.
     */
    private static IOException invalidJson(String message) {
        return new IOException("Invalid JSON response from remote Presto server: " + message);
    }

    /**
     * Reads {@code file} and decodes its contents with the MIME Base64 decoder.
     *
     * @throws RuntimeException if the file cannot be read or is not valid Base64
     */
    private static byte[] loadSharedSecret(File file) {
        requireNonNull(file, "shared secret file is null");
        try {
            return Base64.getMimeDecoder().decode(readAllBytes(file.toPath()));
        } catch (IOException | IllegalArgumentException e) {
            throw new RuntimeException("Failed to load shared secret file: " + file, e);
        }
    }
}