io.reactivex.netty.contexts.http.ContextPropagationTest.java Source code

Java tutorial

Introduction

Here is the source code for io.reactivex.netty.contexts.http.ContextPropagationTest.java:

Source

/*
 * Copyright 2014 Netflix, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.reactivex.netty.contexts.http;

import com.netflix.server.context.ContextSerializationException;
import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.logging.LogLevel;
import io.reactivex.netty.RxNetty;
import io.reactivex.netty.contexts.ContextKeySupplier;
import io.reactivex.netty.contexts.ContextsContainer;
import io.reactivex.netty.contexts.ContextsContainerImpl;
import io.reactivex.netty.contexts.MapBackedKeySupplier;
import io.reactivex.netty.contexts.RxContexts;
import io.reactivex.netty.contexts.TestContext;
import io.reactivex.netty.contexts.TestContextSerializer;
import io.reactivex.netty.protocol.http.client.HttpClient;
import io.reactivex.netty.protocol.http.client.HttpClientRequest;
import io.reactivex.netty.protocol.http.client.HttpClientResponse;
import io.reactivex.netty.protocol.http.server.HttpServer;
import io.reactivex.netty.protocol.http.server.HttpServerBuilder;
import io.reactivex.netty.protocol.http.server.HttpServerRequest;
import io.reactivex.netty.protocol.http.server.HttpServerResponse;
import io.reactivex.netty.protocol.http.server.RequestHandler;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import rx.Observable;
import rx.functions.Action0;
import rx.functions.Action1;
import rx.functions.Func1;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

import static io.reactivex.netty.contexts.ThreadLocalRequestCorrelator.getCurrentContextContainer;
import static io.reactivex.netty.contexts.ThreadLocalRequestCorrelator.getCurrentRequestId;

/**
 * Tests that request contexts managed through {@link RxContexts} are propagated over HTTP.
 *
 * <p>{@link #setUp()} starts a plain RxNetty "mock" backend. It echoes the request id header back on the
 * response and replies {@code 200 OK} only when the request carries the {@code ctx1} and {@code ctx2}
 * contexts with the expected values. Otherwise it replies {@code 400 Bad Request}, or errors if the request
 * id header is missing. Each test registers contexts with {@link RxContexts#DEFAULT_CORRELATOR} and checks
 * whether they reach that backend under a different threading or connection-pooling scenario.
 *
 * @author Nitesh Kant
 */
public class ContextPropagationTest {

    public static final String CTX_3_FOUND_HEADER = "CTX_3_FOUND";

    private HttpServer<ByteBuf, ByteBuf> mockServer;
    private static final String REQUEST_ID_HEADER_NAME = "request_id";
    private static final String CTX_1_NAME = "ctx1";
    private static final String CTX_1_VAL = "ctx1_val";
    private static final String CTX_2_NAME = "ctx2";
    private static final TestContext CTX_2_VAL = new TestContext(CTX_2_NAME);

    /**
     * Starts the mock backend on port {@code 0}; tests read the bound port via {@code getServerPort()}.
     */
    @Before
    public void setUp() throws Exception {
        mockServer = RxNetty.newHttpServerBuilder(0, new RequestHandler<ByteBuf, ByteBuf>() {
            @Override
            public Observable<Void> handle(final HttpServerRequest<ByteBuf> request,
                    HttpServerResponse<ByteBuf> response) {
                final String requestId = request.getHeaders().get(REQUEST_ID_HEADER_NAME);
                if (null == requestId) {
                    System.err.println("Request Id not found.");
                    return Observable.error(new AssertionError("Request Id not found in mock server."));
                }
                // Echo the request id so the client side can verify the correlation.
                response.getHeaders().add(REQUEST_ID_HEADER_NAME, requestId);
                // Contexts are read straight from the incoming request headers.
                ContextKeySupplier supplier = new ContextKeySupplier() {
                    @Override
                    public String getContextValue(String key) {
                        return request.getHeaders().get(key);
                    }
                };
                ContextsContainer container = new ContextsContainerImpl(supplier);
                try {
                    String ctx1 = container.getContext(CTX_1_NAME);
                    TestContext ctx2 = container.getContext(CTX_2_NAME);
                    if (null != ctx1 && null != ctx2 && ctx1.equals(CTX_1_VAL) && ctx2.equals(CTX_2_VAL)) {
                        return response.writeStringAndFlush("Welcome!");
                    } else {
                        response.setStatus(HttpResponseStatus.BAD_REQUEST);
                        return response.writeStringAndFlush("Contexts not found or have wrong values.");
                    }
                } catch (ContextSerializationException e) {
                    return Observable.error(e);
                }
            }
        }).enableWireLogging(LogLevel.DEBUG).build();
        mockServer.start();
    }

    @After
    public void tearDown() throws Exception {
        mockServer.shutdown();
        mockServer.waitTillShutdown(1, TimeUnit.MINUTES);
    }

    /**
     * Client -> context-aware test server -> mock backend, with the backend call made on the request thread.
     */
    @Test
    public void testEndToEnd() throws Exception {
        HttpServer<ByteBuf, ByteBuf> server = newTestServerBuilder(
                new Func1<HttpClient<ByteBuf, ByteBuf>, Observable<HttpClientResponse<ByteBuf>>>() {
                    @Override
                    public Observable<HttpClientResponse<ByteBuf>> call(HttpClient<ByteBuf, ByteBuf> client) {
                        return client.submit(HttpClientRequest.createGet("/"));
                    }
                }).enableWireLogging(LogLevel.ERROR).build().start();
        try {
            HttpClient<ByteBuf, ByteBuf> testClient = RxNetty
                    .<ByteBuf, ByteBuf>newHttpClientBuilder("localhost", server.getServerPort())
                    .enableWireLogging(LogLevel.DEBUG).build();

            String reqId = "testE2E";
            sendTestRequest(testClient, reqId);
        } finally {
            // The test server is local to this test; release its port.
            server.shutdown();
        }
    }

    /**
     * The backend call is deferred through {@code Observable.timer}, i.e. made from a different thread without
     * any explicit context hand-off. The backend request is expected to fail. Presumably this is because the
     * thread-local context is not visible on the timer thread.
     */
    @Test(expected = MockBackendRequestFailedException.class)
    public void testWithThreadSwitchNegative() throws Exception {
        HttpServer<ByteBuf, ByteBuf> server = newTestServerBuilder(
                new Func1<HttpClient<ByteBuf, ByteBuf>, Observable<HttpClientResponse<ByteBuf>>>() {
                    @Override
                    public Observable<HttpClientResponse<ByteBuf>> call(final HttpClient<ByteBuf, ByteBuf> client) {
                        return Observable.timer(1, TimeUnit.MILLISECONDS)
                                .flatMap(new Func1<Long, Observable<HttpClientResponse<ByteBuf>>>() {
                                    @Override
                                    public Observable<HttpClientResponse<ByteBuf>> call(Long aLong) {
                                        return client.submit(HttpClientRequest.createGet("/"));
                                    }
                                });
                    }
                }).build().start();
        try {
            HttpClient<ByteBuf, ByteBuf> testClient = RxNetty.createHttpClient("localhost", server.getServerPort());

            String reqId = "testWithThreadSwitchNegative";
            sendTestRequest(testClient, reqId);
        } finally {
            server.shutdown();
        }
    }

    /**
     * The backend call is made on an executor thread, but the callable is wrapped with
     * {@code DEFAULT_CORRELATOR.makeClosure(...)}, so the contexts are expected to reach the backend.
     */
    @Test
    public void testWithThreadSwitch() throws Exception {
        final ExecutorService executor = Executors.newSingleThreadExecutor();
        try {
            HttpServer<ByteBuf, ByteBuf> server = newTestServerBuilder(
                    new Func1<HttpClient<ByteBuf, ByteBuf>, Observable<HttpClientResponse<ByteBuf>>>() {
                        @Override
                        public Observable<HttpClientResponse<ByteBuf>> call(
                                final HttpClient<ByteBuf, ByteBuf> client) {
                            Callable<HttpClientResponse<ByteBuf>> ctxAware = RxContexts.DEFAULT_CORRELATOR
                                    .makeClosure(new Callable<HttpClientResponse<ByteBuf>>() {
                                        @Override
                                        public HttpClientResponse<ByteBuf> call() throws Exception {
                                            return client.submit(HttpClientRequest.createGet("/"))
                                                    .toBlocking().last();
                                        }
                                    });
                            Future<HttpClientResponse<ByteBuf>> submit = executor.submit(ctxAware);
                            return Observable.from(submit);
                        }
                    }).build().start();
            try {
                HttpClient<ByteBuf, ByteBuf> testClient =
                        RxNetty.createHttpClient("localhost", server.getServerPort());

                String reqId = "testWithThreadSwitch";
                sendTestRequest(testClient, reqId);
            } finally {
                server.shutdown();
            }
        } finally {
            // The executor's worker thread is non-daemon; do not leak it past the test.
            executor.shutdownNow();
        }
    }

    /**
     * Two requests over a single pooled connection ({@code withMaxConnections(1)}) must both carry the contexts.
     */
    @Test
    public void testWithPooledConnections() throws Exception {
        HttpClient<ByteBuf, ByteBuf> testClient = RxContexts
                .<ByteBuf, ByteBuf>newHttpClientBuilder("localhost", mockServer.getServerPort(),
                        REQUEST_ID_HEADER_NAME, RxContexts.DEFAULT_CORRELATOR)
                .withMaxConnections(1).withIdleConnectionsTimeoutMillis(100000).build();
        ContextsContainer container = new ContextsContainerImpl(new MapBackedKeySupplier());
        container.addContext(CTX_1_NAME, CTX_1_VAL);
        container.addContext(CTX_2_NAME, CTX_2_VAL, new TestContextSerializer());

        String reqId = "testWithPooledConnections";
        RxContexts.DEFAULT_CORRELATOR.onNewServerRequest(reqId, container);

        invokeMockServer(testClient, reqId, false);

        invokeMockServer(testClient, reqId, true);
    }

    /**
     * After the first request ends server processing (clearing the thread state), a second request that
     * reuses the same pooled connection must NOT carry the stale contexts. The backend therefore rejects it.
     */
    @Test(expected = MockBackendRequestFailedException.class)
    public void testNoStateLeakOnThreadReuse() throws Exception {
        HttpClient<ByteBuf, ByteBuf> testClient = RxContexts
                .<ByteBuf, ByteBuf>newHttpClientBuilder("localhost", mockServer.getServerPort(),
                        REQUEST_ID_HEADER_NAME, RxContexts.DEFAULT_CORRELATOR)
                .withMaxConnections(1).withIdleConnectionsTimeoutMillis(100000).build();

        ContextsContainer container = new ContextsContainerImpl(new MapBackedKeySupplier());
        container.addContext(CTX_1_NAME, CTX_1_VAL);
        container.addContext(CTX_2_NAME, CTX_2_VAL, new TestContextSerializer());

        String reqId = "testNoStateLeakOnThreadReuse";
        RxContexts.DEFAULT_CORRELATOR.onNewServerRequest(reqId, container);

        try {
            invokeMockServer(testClient, reqId, true);
        } catch (MockBackendRequestFailedException e) {
            // Must not surface as the expected exception type; keep the original failure as the cause.
            AssertionError firstFailed =
                    new AssertionError("First request to mock backend failed. Error: " + e.getMessage());
            firstFailed.initCause(e);
            throw firstFailed;
        }

        invokeMockServer(testClient, reqId, false);
    }

    /**
     * Builds a context-aware test server. For every request it verifies that a request id and a context
     * container are bound to the current thread. It then adds ctx1/ctx2 to that container, calls the mock
     * backend via {@code clientInvoker}, and replies with the backend's status code.
     *
     * @param clientInvoker issues the backend request using the supplied context-aware client.
     * @return an unstarted server builder.
     */
    private HttpServerBuilder<ByteBuf, ByteBuf> newTestServerBuilder(
            final Func1<HttpClient<ByteBuf, ByteBuf>, Observable<HttpClientResponse<ByteBuf>>> clientInvoker) {
        return RxContexts.newHttpServerBuilder(0, new RequestHandler<ByteBuf, ByteBuf>() {
            @Override
            public Observable<Void> handle(HttpServerRequest<ByteBuf> request,
                    final HttpServerResponse<ByteBuf> serverResponse) {
                String reqId = getCurrentRequestId();
                if (null == reqId) {
                    return Observable.error(new AssertionError("Request Id not found at server."));
                }
                ContextsContainer container = getCurrentContextContainer();
                if (null == container) {
                    return Observable.error(new AssertionError("Context container not found by server."));
                }
                container.addContext(CTX_1_NAME, CTX_1_VAL);
                container.addContext(CTX_2_NAME, CTX_2_VAL, new TestContextSerializer());

                // NOTE(review): a new client is built per request and never shut down.
                HttpClient<ByteBuf, ByteBuf> client = RxContexts
                        .<ByteBuf, ByteBuf>newHttpClientBuilder("localhost", mockServer.getServerPort(),
                                REQUEST_ID_HEADER_NAME, RxContexts.DEFAULT_CORRELATOR)
                        .withMaxConnections(1).enableWireLogging(LogLevel.DEBUG).build();

                return clientInvoker.call(client)
                        .flatMap(new Func1<HttpClientResponse<ByteBuf>, Observable<Void>>() {
                            @Override
                            public Observable<Void> call(HttpClientResponse<ByteBuf> response) {
                                // Relay only the backend's status; tests assert on it.
                                serverResponse.setStatus(response.getStatus());
                                return serverResponse.close(true);
                            }
                        });
            }
        }, REQUEST_ID_HEADER_NAME, RxContexts.DEFAULT_CORRELATOR);
    }

    /**
     * Sends one request to the mock backend. If {@code finishServerProcessing} is set, it then signals the end
     * of server processing for {@code requestId} and asserts that the current thread's state was cleared.
     */
    private static void invokeMockServer(HttpClient<ByteBuf, ByteBuf> testClient, final String requestId,
            boolean finishServerProcessing) throws MockBackendRequestFailedException, InterruptedException {
        try {
            sendTestRequest(testClient, requestId);
        } finally {
            if (finishServerProcessing) {
                RxContexts.DEFAULT_CORRELATOR.onServerProcessingEnd(requestId);
                System.err.println("Sent server processing end callback to correlator.");
                RxContexts.DEFAULT_CORRELATOR.dumpThreadState(System.err);
            }
        }

        if (finishServerProcessing) {
            Assert.assertNull("Current request id not cleared from thread.", getCurrentRequestId());
            Assert.assertNull("Current context not cleared from thread.", getCurrentContextContainer());
        }
    }

    /**
     * Sends a GET carrying {@code requestId} in the request-id header and waits up to one minute for the reply.
     *
     * @throws MockBackendRequestFailedException if the status is not 200 OK, or the response does not echo
     *         {@code requestId} in the request-id header.
     * @throws AssertionError if the wait times out or no response was received. A request error, if any,
     *         is attached as the cause.
     */
    private static void sendTestRequest(HttpClient<ByteBuf, ByteBuf> testClient, final String requestId)
            throws MockBackendRequestFailedException, InterruptedException {
        System.err.println("Sending test request to mock server, with request id: " + requestId);
        RxContexts.DEFAULT_CORRELATOR.dumpThreadState(System.err);
        final CountDownLatch finishLatch = new CountDownLatch(1);
        // Both holders are written on the client's thread. They are only read here after a successful
        // await(), because countDown() happens-before await() returning true.
        final List<HttpClientResponse<ByteBuf>> responseHolder = new ArrayList<HttpClientResponse<ByteBuf>>();
        final List<Throwable> errorHolder = new ArrayList<Throwable>();
        testClient.submit(HttpClientRequest.createGet("").withHeader(REQUEST_ID_HEADER_NAME, requestId))
                .finallyDo(new Action0() {
                    @Override
                    public void call() {
                        finishLatch.countDown();
                    }
                }).subscribe(new Action1<HttpClientResponse<ByteBuf>>() {
                    @Override
                    public void call(HttpClientResponse<ByteBuf> response) {
                        responseHolder.add(response);
                    }
                }, new Action1<Throwable>() {
                    @Override
                    public void call(Throwable throwable) {
                        // Capture instead of letting it surface as OnErrorNotImplementedException elsewhere.
                        errorHolder.add(throwable);
                    }
                });

        if (!finishLatch.await(1, TimeUnit.MINUTES)) {
            throw new AssertionError("Timed out waiting for response with request id: " + requestId);
        }
        if (responseHolder.isEmpty()) {
            AssertionError notReceived = new AssertionError("Response not received.");
            if (!errorHolder.isEmpty()) {
                notReceived.initCause(errorHolder.get(0));
            }
            throw notReceived;
        }

        HttpClientResponse<ByteBuf> response = responseHolder.get(0);

        System.err.println("Received response from mock server, with request id: " + requestId + ", status: "
                + response.getStatus());

        if (response.getStatus().code() != HttpResponseStatus.OK.code()) {
            throw new MockBackendRequestFailedException(
                    "Test request failed. Status: " + response.getStatus().code());
        }

        String requestIdGot = response.getHeaders().get(REQUEST_ID_HEADER_NAME);

        // Previously compared requestId with itself, so this check could never fail.
        if (!requestId.equals(requestIdGot)) {
            throw new MockBackendRequestFailedException(
                    "Request Id not sent from mock server. Expected: " + requestId + ", got: " + requestIdGot);
        }
    }

    /** Signals that the backend rejected or mis-answered a test request. */
    private static class MockBackendRequestFailedException extends Exception {

        private static final long serialVersionUID = 5033661188956567940L;

        private MockBackendRequestFailedException(String message) {
            super(message);
        }
    }
}