org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactoryTests.java Source code

Java tutorial

Introduction

Here is the source code for org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactoryTests.java

Source

/*
 * Copyright 2012-2019 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.boot.web.reactive.server;

import java.io.File;
import java.io.FileInputStream;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.time.Duration;
import java.util.Arrays;

import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLException;
import javax.net.ssl.X509KeyManager;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import org.junit.After;
import org.junit.Rule;
import org.junit.Test;
import reactor.core.publisher.Mono;
import reactor.netty.NettyPipeline;
import reactor.netty.http.client.HttpClient;
import reactor.test.StepVerifier;

import org.springframework.boot.testsupport.rule.OutputCapture;
import org.springframework.boot.web.server.Compression;
import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.server.WebServer;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.reactive.ReactorClientHttpConnector;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.SocketUtils;
import org.springframework.util.unit.DataSize;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
 * Base for testing classes that extends {@link AbstractReactiveWebServerFactory}.
 * <p>
 * Subclasses supply the concrete factory through {@link #getFactory()}. Each test
 * configures that factory, starts a {@link WebServer} from it, talks to it with a
 * Reactor Netty based {@link WebClient}, and relies on {@link #tearDown()} to stop it.
 *
 * @author Brian Clozel
 */
public abstract class AbstractReactiveWebServerFactoryTests {

    @Rule
    public OutputCapture output = new OutputCapture();

    /**
     * The server started by the current test, or {@code null} if none was started.
     * Stopped after every test by {@link #tearDown()}.
     */
    protected WebServer webServer;

    /**
     * Stops the server started by the test, if any. Failures while stopping are
     * deliberately ignored so that cleanup problems cannot mask the test's own outcome.
     */
    @After
    public void tearDown() {
        if (this.webServer != null) {
            try {
                this.webServer.stop();
            } catch (Exception ignored) {
                // Best-effort cleanup: a failure to stop must not hide the test result.
            }
        }
    }

    /**
     * Returns the factory for the server implementation under test.
     * @return a factory that the calling test may configure and start
     */
    protected abstract AbstractReactiveWebServerFactory getFactory();

    @Test
    public void specificPort() {
        AbstractReactiveWebServerFactory factory = getFactory();
        int specificPort = SocketUtils.findAvailableTcpPort(41000);
        factory.setPort(specificPort);
        this.webServer = factory.getWebServer(new EchoHandler());
        this.webServer.start();
        Mono<String> result = getWebClient().build().post().uri("/test").contentType(MediaType.TEXT_PLAIN)
                .body(BodyInserters.fromObject("Hello World")).exchange()
                .flatMap((response) -> response.bodyToMono(String.class));
        assertThat(result.block(Duration.ofSeconds(30))).isEqualTo("Hello World");
        assertThat(this.webServer.getPort()).isEqualTo(specificPort);
    }

    @Test
    public void basicSslFromClassPath() {
        testBasicSslWithKeyStore("classpath:test.jks", "password");
    }

    @Test
    public void basicSslFromFileSystem() {
        testBasicSslWithKeyStore("src/test/resources/test.jks", "password");
    }

    /**
     * Starts an HTTPS echo server using the given key store and checks that a
     * trust-all client gets its request body echoed back.
     * @param keyStore the key store location passed to {@link Ssl#setKeyStore(String)}
     * @param keyPassword the key password passed to {@link Ssl#setKeyPassword(String)}
     */
    protected final void testBasicSslWithKeyStore(String keyStore, String keyPassword) {
        AbstractReactiveWebServerFactory factory = getFactory();
        Ssl ssl = new Ssl();
        ssl.setKeyStore(keyStore);
        ssl.setKeyPassword(keyPassword);
        factory.setSsl(ssl);
        this.webServer = factory.getWebServer(new EchoHandler());
        this.webServer.start();
        ReactorClientHttpConnector connector = buildTrustAllSslConnector();
        WebClient client = WebClient.builder().baseUrl("https://localhost:" + this.webServer.getPort())
                .clientConnector(connector).build();
        Mono<String> result = client.post().uri("/test").contentType(MediaType.TEXT_PLAIN)
                .body(BodyInserters.fromObject("Hello World")).exchange()
                .flatMap((response) -> response.bodyToMono(String.class));
        assertThat(result.block(Duration.ofSeconds(30))).isEqualTo("Hello World");
    }

    /**
     * Builds a client connector that uses the JDK SSL provider and trusts every server
     * certificate. It presents no client certificate.
     * @return the connector
     */
    protected ReactorClientHttpConnector buildTrustAllSslConnector() {
        SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
                .trustManager(InsecureTrustManagerFactory.INSTANCE);
        HttpClient client = HttpClient.create().wiretap(true)
                .secure((sslContextSpec) -> sslContextSpec.sslContext(builder));
        return new ReactorClientHttpConnector(client);
    }

    @Test
    public void sslWantsClientAuthenticationSucceedsWithClientCertificate() throws Exception {
        Ssl ssl = new Ssl();
        ssl.setClientAuth(Ssl.ClientAuth.WANT);
        ssl.setKeyStore("classpath:test.jks");
        ssl.setKeyPassword("password");
        ssl.setTrustStore("classpath:test.jks");
        testClientAuthSuccess(ssl, buildTrustAllSslWithClientKeyConnector());
    }

    @Test
    public void sslWantsClientAuthenticationSucceedsWithoutClientCertificate() {
        Ssl ssl = new Ssl();
        ssl.setClientAuth(Ssl.ClientAuth.WANT);
        ssl.setKeyStore("classpath:test.jks");
        ssl.setKeyPassword("password");
        ssl.setTrustStore("classpath:test.jks");
        testClientAuthSuccess(ssl, buildTrustAllSslConnector());
    }

    /**
     * Builds a trust-all client connector that also presents the {@code spring-boot}
     * key and certificate chain from {@code src/test/resources/test.jks} as its client
     * certificate.
     * @return the connector
     * @throws Exception if the key store cannot be read or the key manager initialized
     * @throws IllegalStateException if no {@link X509KeyManager} holds a private key
     * under the {@code spring-boot} alias
     */
    protected ReactorClientHttpConnector buildTrustAllSslWithClientKeyConnector() throws Exception {
        KeyStore clientKeyStore = KeyStore.getInstance(KeyStore.getDefaultType());
        // try-with-resources releases the file handle once the store has been read.
        // The store password ("secret") differs from the key password ("password") below.
        try (FileInputStream keyStoreStream = new FileInputStream(new File("src/test/resources/test.jks"))) {
            clientKeyStore.load(keyStoreStream, "secret".toCharArray());
        }
        KeyManagerFactory clientKeyManagerFactory = KeyManagerFactory
                .getInstance(KeyManagerFactory.getDefaultAlgorithm());
        clientKeyManagerFactory.init(clientKeyStore, "password".toCharArray());
        for (KeyManager keyManager : clientKeyManagerFactory.getKeyManagers()) {
            if (keyManager instanceof X509KeyManager) {
                X509KeyManager x509KeyManager = (X509KeyManager) keyManager;
                PrivateKey privateKey = x509KeyManager.getPrivateKey("spring-boot");
                if (privateKey != null) {
                    X509Certificate[] certificateChain = x509KeyManager.getCertificateChain("spring-boot");
                    SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
                            .trustManager(InsecureTrustManagerFactory.INSTANCE)
                            .keyManager(privateKey, certificateChain);
                    HttpClient client = HttpClient.create().wiretap(true)
                            .secure((sslContextSpec) -> sslContextSpec.sslContext(builder));
                    return new ReactorClientHttpConnector(client);
                }
            }
        }
        throw new IllegalStateException("Key with alias 'spring-boot' not found");
    }

    /**
     * Starts an HTTPS echo server with the given SSL configuration and asserts that a
     * request made through {@code clientConnector} is echoed back successfully.
     * @param sslConfiguration the server SSL configuration
     * @param clientConnector the connector the client uses
     */
    protected void testClientAuthSuccess(Ssl sslConfiguration, ReactorClientHttpConnector clientConnector) {
        AbstractReactiveWebServerFactory factory = getFactory();
        factory.setSsl(sslConfiguration);
        this.webServer = factory.getWebServer(new EchoHandler());
        this.webServer.start();
        WebClient client = WebClient.builder().baseUrl("https://localhost:" + this.webServer.getPort())
                .clientConnector(clientConnector).build();
        Mono<String> result = client.post().uri("/test").contentType(MediaType.TEXT_PLAIN)
                .body(BodyInserters.fromObject("Hello World")).exchange()
                .flatMap((response) -> response.bodyToMono(String.class));
        assertThat(result.block(Duration.ofSeconds(30))).isEqualTo("Hello World");
    }

    @Test
    public void sslNeedsClientAuthenticationSucceedsWithClientCertificate() throws Exception {
        Ssl ssl = new Ssl();
        ssl.setClientAuth(Ssl.ClientAuth.NEED);
        ssl.setKeyStore("classpath:test.jks");
        ssl.setKeyPassword("password");
        ssl.setTrustStore("classpath:test.jks");
        testClientAuthSuccess(ssl, buildTrustAllSslWithClientKeyConnector());
    }

    @Test
    public void sslNeedsClientAuthenticationFailsWithoutClientCertificate() {
        Ssl ssl = new Ssl();
        ssl.setClientAuth(Ssl.ClientAuth.NEED);
        ssl.setKeyStore("classpath:test.jks");
        ssl.setKeyPassword("password");
        ssl.setTrustStore("classpath:test.jks");
        testClientAuthFailure(ssl, buildTrustAllSslConnector());
    }

    /**
     * Starts an HTTPS echo server with the given SSL configuration and asserts that a
     * request made through {@code clientConnector} fails with an {@link SSLException}
     * within ten seconds.
     * @param sslConfiguration the server SSL configuration
     * @param clientConnector the connector the client uses
     */
    protected void testClientAuthFailure(Ssl sslConfiguration, ReactorClientHttpConnector clientConnector) {
        AbstractReactiveWebServerFactory factory = getFactory();
        factory.setSsl(sslConfiguration);
        this.webServer = factory.getWebServer(new EchoHandler());
        this.webServer.start();
        WebClient client = WebClient.builder().baseUrl("https://localhost:" + this.webServer.getPort())
                .clientConnector(clientConnector).build();
        Mono<String> result = client.post().uri("/test").contentType(MediaType.TEXT_PLAIN)
                .body(BodyInserters.fromObject("Hello World")).exchange()
                .flatMap((response) -> response.bodyToMono(String.class));
        StepVerifier.create(result).expectError(SSLException.class).verify(Duration.ofSeconds(10));
    }

    /**
     * Returns a plain-HTTP client builder for the running {@link #webServer}, backed by
     * a wiretapped Reactor Netty {@link HttpClient}.
     * @return the builder
     */
    protected WebClient.Builder getWebClient() {
        return getWebClient(HttpClient.create().wiretap(true));
    }

    /**
     * Returns a plain-HTTP client builder for the running {@link #webServer}, backed by
     * the given Reactor Netty client.
     * @param client the underlying HTTP client
     * @return the builder, with its base URL pointing at the server's port
     */
    protected WebClient.Builder getWebClient(HttpClient client) {
        // NOTE(review): InetSocketAddress(int) denotes the wildcard address, so the host
        // string used here is the wildcard literal rather than "localhost".
        InetSocketAddress address = new InetSocketAddress(this.webServer.getPort());
        String baseUrl = "http://" + address.getHostString() + ":" + address.getPort();
        return WebClient.builder().clientConnector(new ReactorClientHttpConnector(client)).baseUrl(baseUrl);
    }

    @Test
    public void compressionOfResponseToGetRequest() {
        WebClient client = prepareCompressionTest();
        ResponseEntity<Void> response = client.get().exchange().flatMap((res) -> res.toEntity(Void.class))
                .block(Duration.ofSeconds(30));
        assertResponseIsCompressed(response);
    }

    @Test
    public void compressionOfResponseToPostRequest() {
        WebClient client = prepareCompressionTest();
        ResponseEntity<Void> response = client.post().exchange().flatMap((res) -> res.toEntity(Void.class))
                .block(Duration.ofSeconds(30));
        assertResponseIsCompressed(response);
    }

    @Test
    public void noCompressionForSmallResponse() {
        Compression compression = new Compression();
        compression.setEnabled(true);
        compression.setMinResponseSize(DataSize.ofBytes(3001));
        WebClient client = prepareCompressionTest(compression);
        ResponseEntity<Void> response = client.get().exchange().flatMap((res) -> res.toEntity(Void.class))
                .block(Duration.ofSeconds(30));
        assertResponseIsNotCompressed(response);
    }

    @Test
    public void noCompressionForMimeType() {
        Compression compression = new Compression();
        // Compression must be enabled; otherwise the assertion below holds trivially and
        // the MIME-type exclusion is never exercised.
        compression.setEnabled(true);
        compression.setMimeTypes(new String[] { "application/json" });
        WebClient client = prepareCompressionTest(compression);
        ResponseEntity<Void> response = client.get().exchange().flatMap((res) -> res.toEntity(Void.class))
                .block(Duration.ofSeconds(30));
        assertResponseIsNotCompressed(response);
    }

    @Test
    public void noCompressionForUserAgent() {
        Compression compression = new Compression();
        compression.setEnabled(true);
        compression.setExcludedUserAgents(new String[] { "testUserAgent" });
        WebClient client = prepareCompressionTest(compression);
        ResponseEntity<Void> response = client.get().header("User-Agent", "testUserAgent").exchange()
                .flatMap((res) -> res.toEntity(Void.class)).block(Duration.ofSeconds(30));
        assertResponseIsNotCompressed(response);
    }

    @Test
    public void whenSslIsEnabledAndNoKeyStoreIsConfiguredThenServerFailsToStart() {
        assertThatThrownBy(() -> testBasicSslWithKeyStore(null, null))
                .hasMessageContaining("Could not load key store 'null'");
    }

    /**
     * Prepares a compression test with compression enabled and otherwise default
     * settings.
     * @return a client for the started server
     * @see #prepareCompressionTest(Compression)
     */
    protected WebClient prepareCompressionTest() {
        Compression compression = new Compression();
        compression.setEnabled(true);
        return prepareCompressionTest(compression);
    }

    /**
     * Starts a server that answers every request with 3000 {@code text/plain} bytes,
     * using the given compression settings. It returns a client that marks compressed
     * responses with an {@code X-Test-Compressed} header.
     * @param compression the server compression settings
     * @return a client for the started server
     */
    protected WebClient prepareCompressionTest(Compression compression) {
        AbstractReactiveWebServerFactory factory = getFactory();
        factory.setCompression(compression);
        this.webServer = factory.getWebServer(new CharsHandler(3000, MediaType.TEXT_PLAIN));
        this.webServer.start();
        // The detection handler goes ahead of the client's decompressor so that it
        // inspects the response before decompression takes place.
        HttpClient client = HttpClient.create().wiretap(true).compress(true)
                .tcpConfiguration((tcpClient) -> tcpClient.doOnConnected(
                        (connection) -> connection.channel().pipeline().addBefore(NettyPipeline.HttpDecompressor,
                                "CompressionTest", new CompressionDetectionHandler())));
        return getWebClient(client).build();
    }

    /**
     * Asserts that {@link CompressionDetectionHandler} flagged the response as gzip
     * encoded.
     * @param response the response to check
     */
    protected void assertResponseIsCompressed(ResponseEntity<Void> response) {
        assertThat(response.getHeaders().getFirst("X-Test-Compressed")).isEqualTo("true");
    }

    /**
     * Asserts that {@link CompressionDetectionHandler} did not flag the response as
     * gzip encoded.
     * @param response the response to check
     */
    protected void assertResponseIsNotCompressed(ResponseEntity<Void> response) {
        assertThat(response.getHeaders().keySet()).doesNotContain("X-Test-Compressed");
    }

    /**
     * Starts a server from the given factory and asserts that a request carrying
     * {@code X-Forwarded-Proto: https} is seen by the handler with an {@code https}
     * request URI scheme.
     * @param factory the factory, already configured by the caller
     */
    protected void assertForwardHeaderIsUsed(AbstractReactiveWebServerFactory factory) {
        this.webServer = factory.getWebServer(new XForwardedHandler());
        this.webServer.start();
        String body = getWebClient().build().get().header("X-Forwarded-Proto", "https").retrieve()
                .bodyToMono(String.class).block(Duration.ofSeconds(30));
        assertThat(body).isEqualTo("https");
    }

    /**
     * Handler that responds {@code 200 OK} and streams the request body back as the
     * response body.
     */
    protected static class EchoHandler implements HttpHandler {

        public EchoHandler() {
        }

        @Override
        public Mono<Void> handle(ServerHttpRequest request, ServerHttpResponse response) {
            response.setStatusCode(HttpStatus.OK);
            return response.writeWith(request.getBody());
        }

    }

    /**
     * Client-side Netty handler that adds an {@code X-Test-Compressed: true} header to
     * any HTTP response whose {@code Content-Encoding} is {@code gzip}
     * (case-insensitive), then passes the message on unchanged otherwise.
     */
    protected static class CompressionDetectionHandler extends ChannelInboundHandlerAdapter {

        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) {
            if (msg instanceof HttpResponse) {
                HttpResponse response = (HttpResponse) msg;
                boolean compressed = response.headers().contains(HttpHeaderNames.CONTENT_ENCODING, "gzip", true);
                if (compressed) {
                    response.headers().set("X-Test-Compressed", "true");
                }
            }
            ctx.fireChannelRead(msg);
        }

    }

    /**
     * Handler that responds {@code 200 OK} with a fixed body of {@code contentSize}
     * {@code 'F'} characters (UTF-8 encoded), the given content type, and an explicit
     * {@code Content-Length}.
     */
    protected static class CharsHandler implements HttpHandler {

        private final byte[] bytes;

        private final MediaType mediaType;

        public CharsHandler(int contentSize, MediaType mediaType) {
            char[] chars = new char[contentSize];
            Arrays.fill(chars, 'F');
            this.bytes = new String(chars).getBytes(StandardCharsets.UTF_8);
            this.mediaType = mediaType;
        }

        @Override
        public Mono<Void> handle(ServerHttpRequest request, ServerHttpResponse response) {
            response.setStatusCode(HttpStatus.OK);
            response.getHeaders().setContentType(this.mediaType);
            response.getHeaders().setContentLength(this.bytes.length);
            // Wrap the bytes afresh for every request. A DataBuffer carries a read
            // position, so one instance shared across requests could be drained by the
            // first write and leave later responses empty.
            DataBuffer buffer = response.bufferFactory().wrap(this.bytes);
            return response.writeWith(Mono.just(buffer));
        }

    }

    /**
     * Handler that writes the scheme of the request URI, as seen by the server, as the
     * UTF-8 response body.
     */
    protected static class XForwardedHandler implements HttpHandler {

        @Override
        public Mono<Void> handle(ServerHttpRequest request, ServerHttpResponse response) {
            String scheme = request.getURI().getScheme();
            DataBufferFactory bufferFactory = new DefaultDataBufferFactory();
            DataBuffer buffer = bufferFactory.wrap(scheme.getBytes(StandardCharsets.UTF_8));
            return response.writeWith(Mono.just(buffer));
        }

    }

}