org.springframework.web.socket.sockjs.client.AbstractSockJsIntegrationTests.java Source code

Java tutorial

Introduction

Here is the source code for org.springframework.web.socket.sockjs.client.AbstractSockJsIntegrationTests.java

Source

/*
 * Copyright 2002-2015 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.socket.sockjs.client;

import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;
import java.util.function.Supplier;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.tests.Assume;
import org.springframework.tests.TestGroup;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.WebSocketTestServer;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * Abstract base class for integration tests using the
 * {@link org.springframework.web.socket.sockjs.client.SockJsClient SockJsClient}
 * against actual SockJS server endpoints.
 *
 * <p>Subclasses supply the server under test, a configuration class registered
 * alongside {@link TestConfig}, and the WebSocket and XHR client transports to
 * exercise. Every test is guarded by an assumption that the
 * {@link TestGroup#PERFORMANCE} test group is enabled.
 *
 * @author Rossen Stoyanchev
 * @author Sam Brannen
 */
public abstract class AbstractSockJsIntegrationTests {

    /** Interval, in milliseconds, between condition checks in {@link #awaitEvent}. */
    private static final long POLL_INTERVAL_MILLIS = 200;

    @Rule
    public final TestName testName = new TestName();

    protected Log logger = LogFactory.getLog(getClass());

    private SockJsClient sockJsClient;

    private WebSocketTestServer server;

    private AnnotationConfigWebApplicationContext wac;

    private TestFilter testFilter;

    private String baseUrl;

    /**
     * Assume that the {@code PERFORMANCE} test group is enabled; JUnit skips the
     * tests of this class when the assumption does not hold.
     */
    @BeforeClass
    public static void performanceTestGroupAssumption() throws Exception {
        Assume.group(TestGroup.PERFORMANCE);
    }

    /**
     * Create the web application context and the test server, deploy the context
     * together with a {@link TestFilter}, start the server, and finally refresh the
     * context against the server's servlet context.
     */
    @Before
    public void setup() throws Exception {
        logger.debug("Setting up '" + this.testName.getMethodName() + "'");
        this.testFilter = new TestFilter();

        this.wac = new AnnotationConfigWebApplicationContext();
        this.wac.register(TestConfig.class, upgradeStrategyConfigClass());

        this.server = createWebSocketTestServer();
        this.server.setup();
        this.server.deployConfig(this.wac, this.testFilter);
        this.server.start();

        this.wac.setServletContext(this.server.getServletContext());
        this.wac.refresh();

        this.baseUrl = "http://localhost:" + this.server.getPort();
    }

    /**
     * Release everything created by {@link #setup()} and {@link #initSockJsClient}.
     * Each step is attempted independently and failures are only logged, so one
     * failing step does not prevent the remaining cleanup. Fields are checked for
     * {@code null} because setup may have failed part-way, or a test may have
     * failed before creating a client.
     */
    @After
    public void teardown() throws Exception {
        if (this.sockJsClient != null) {
            try {
                this.sockJsClient.stop();
            } catch (Throwable ex) {
                logger.error("Failed to stop SockJsClient", ex);
            }
        }
        if (this.server != null) {
            try {
                this.server.undeployConfig();
            } catch (Throwable ex) {
                logger.error("Failed to undeploy application config", ex);
            }
            try {
                this.server.stop();
            } catch (Throwable ex) {
                logger.error("Failed to stop server", ex);
            }
        }
        if (this.wac != null) {
            try {
                this.wac.close();
            } catch (Throwable ex) {
                logger.error("Failed to close WebApplicationContext", ex);
            }
        }
    }

    /**
     * Return a configuration class that is registered together with {@link TestConfig}.
     * Presumably this contributes the {@link RequestUpgradeStrategy} bean that
     * {@code TestConfig} autowires. TODO confirm against subclasses.
     */
    protected abstract Class<?> upgradeStrategyConfigClass();

    /** Create the server that hosts the SockJS endpoints under test. */
    protected abstract WebSocketTestServer createWebSocketTestServer();

    /** Create the client-side WebSocket transport to test. */
    protected abstract Transport createWebSocketTransport();

    /** Create the client-side XHR transport to test. */
    protected abstract AbstractXhrTransport createXhrTransport();

    /**
     * Create and start the {@link SockJsClient} used by the current test.
     * @param transports the transports the client may use, in the given order
     */
    protected void initSockJsClient(Transport... transports) {
        this.sockJsClient = new SockJsClient(Arrays.asList(transports));
        this.sockJsClient.start();
    }

    @Test
    public void echoWebSocket() throws Exception {
        testEcho(100, createWebSocketTransport(), null);
    }

    @Test
    public void echoXhrStreaming() throws Exception {
        testEcho(100, createXhrTransport(), null);
    }

    @Test
    public void echoXhr() throws Exception {
        AbstractXhrTransport xhrTransport = createXhrTransport();
        xhrTransport.setXhrStreamingDisabled(true);
        testEcho(100, xhrTransport, null);
    }

    /**
     * Regression test for SPR-13254: custom handshake headers must be present on
     * every HTTP request that the test filter records.
     */
    @Test
    public void echoXhrWithHeaders() throws Exception {
        AbstractXhrTransport xhrTransport = createXhrTransport();
        xhrTransport.setXhrStreamingDisabled(true);

        WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
        headers.add("auth", "123");
        testEcho(10, xhrTransport, headers);

        // Guard against a vacuous pass of the loop below.
        assertTrue("No requests were recorded by the test filter", !this.testFilter.requests.isEmpty());
        for (Map.Entry<String, HttpHeaders> entry : this.testFilter.requests.entrySet()) {
            HttpHeaders httpHeaders = entry.getValue();
            assertEquals("No auth header for: " + entry.getKey(), "123", httpHeaders.getFirst("auth"));
        }
    }

    @Test
    public void receiveOneMessageWebSocket() throws Exception {
        testReceiveOneMessage(createWebSocketTransport(), null);
    }

    @Test
    public void receiveOneMessageXhrStreaming() throws Exception {
        testReceiveOneMessage(createXhrTransport(), null);
    }

    @Test
    public void receiveOneMessageXhr() throws Exception {
        AbstractXhrTransport xhrTransport = createXhrTransport();
        xhrTransport.setXhrStreamingDisabled(true);
        testReceiveOneMessage(xhrTransport, null);
    }

    /** An error response to the "/info" request must make the handshake future fail. */
    @Test
    public void infoRequestFailure() throws Exception {
        TestClientHandler handler = new TestClientHandler();
        this.testFilter.sendErrorMap.put("/info", 500);
        CountDownLatch latch = new CountDownLatch(1);
        initSockJsClient(createWebSocketTransport());
        this.sockJsClient.doHandshake(handler, this.baseUrl + "/echo")
                .addCallback(new ListenableFutureCallback<WebSocketSession>() {
                    @Override
                    public void onSuccess(WebSocketSession result) {
                    }

                    @Override
                    public void onFailure(Throwable ex) {
                        latch.countDown();
                    }
                });
        assertTrue(latch.await(5000, TimeUnit.MILLISECONDS));
    }

    /**
     * With the "/websocket" and "/xhr_streaming" requests rejected by the filter,
     * the client is expected to end up with an {@link XhrClientSockJsSession}
     * that can still echo a message.
     */
    @Test
    public void fallbackAfterTransportFailure() throws Exception {
        this.testFilter.sendErrorMap.put("/websocket", 200);
        this.testFilter.sendErrorMap.put("/xhr_streaming", 500);
        TestClientHandler handler = new TestClientHandler();
        initSockJsClient(createWebSocketTransport(), createXhrTransport());
        WebSocketSession session = this.sockJsClient.doHandshake(handler, this.baseUrl + "/echo").get();
        assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, session.getClass());
        TextMessage message = new TextMessage("message1");
        session.sendMessage(message);
        handler.awaitMessage(message, 5000);
        session.close();
    }

    /**
     * With "/xhr_streaming" requests delayed by 10 seconds and a connect-timeout
     * scheduler configured, the client is expected to fall back and still echo a
     * message within the 5 second test timeout.
     */
    @Test(timeout = 5000)
    public void fallbackAfterConnectTimeout() throws Exception {
        TestClientHandler clientHandler = new TestClientHandler();
        this.testFilter.sleepDelayMap.put("/xhr_streaming", 10000L);
        this.testFilter.sendErrorMap.put("/xhr_streaming", 503);
        initSockJsClient(createXhrTransport());
        this.sockJsClient.setConnectTimeoutScheduler(this.wac.getBean(ThreadPoolTaskScheduler.class));
        WebSocketSession clientSession = this.sockJsClient.doHandshake(clientHandler, this.baseUrl + "/echo").get();
        assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, clientSession.getClass());
        TextMessage message = new TextMessage("message1");
        clientSession.sendMessage(message);
        clientHandler.awaitMessage(message, 5000);
        clientSession.close();
    }

    /**
     * Send {@code messageCount} distinct messages to the "/echo" endpoint and assert
     * that exactly those messages, and nothing else, are echoed back.
     */
    private void testEcho(int messageCount, Transport transport, WebSocketHttpHeaders headers) throws Exception {
        List<TextMessage> messages = new ArrayList<>();
        for (int i = 0; i < messageCount; i++) {
            messages.add(new TextMessage("m" + i));
        }
        TestClientHandler handler = new TestClientHandler();
        initSockJsClient(transport);
        URI url = new URI(this.baseUrl + "/echo");
        WebSocketSession session = this.sockJsClient.doHandshake(handler, headers, url).get();
        for (TextMessage message : messages) {
            session.sendMessage(message);
        }
        handler.awaitMessageCount(messageCount, 5000);
        for (TextMessage message : messages) {
            assertTrue("Message not received: " + message, handler.receivedMessages.remove(message));
        }
        assertEquals("Remaining messages: " + handler.receivedMessages, 0, handler.receivedMessages.size());
        session.close();
    }

    /**
     * Connect to the "/test" endpoint, wait for the server-side session, push one
     * message from the server, and assert that the client receives it.
     */
    private void testReceiveOneMessage(Transport transport, WebSocketHttpHeaders headers) throws Exception {
        TestClientHandler clientHandler = new TestClientHandler();
        initSockJsClient(transport);
        this.sockJsClient.doHandshake(clientHandler, headers, new URI(this.baseUrl + "/test")).get();
        TestServerHandler serverHandler = this.wac.getBean(TestServerHandler.class);

        assertNotNull("afterConnectionEstablished should have been called", clientHandler.session);
        WebSocketSession serverSession = serverHandler.awaitSession(5000);

        TextMessage message = new TextMessage("message1");
        serverSession.sendMessage(message);
        clientHandler.awaitMessage(message, 5000);
    }

    /**
     * Poll {@code condition} until it returns {@code true} or {@code timeToWait}
     * milliseconds have elapsed. The loop re-checks the condition after every sleep,
     * including the last one, so an event arriving just before the deadline is seen.
     * @param condition the condition to poll
     * @param timeToWait the maximum time to wait, in milliseconds
     * @param description lazily evaluated description of the awaited event; it is
     * only computed on failure, so it reflects the state at that moment
     * @throws IllegalStateException if the wait times out or the thread is interrupted
     */
    private static void awaitEvent(BooleanSupplier condition, long timeToWait, Supplier<String> description) {
        long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeToWait);
        while (!condition.getAsBoolean()) {
            long remainingMillis = TimeUnit.NANOSECONDS.toMillis(deadline - System.nanoTime());
            if (remainingMillis <= 0) {
                throw new IllegalStateException("Timed out waiting for " + description.get());
            }
            try {
                Thread.sleep(Math.min(POLL_INTERVAL_MILLIS, remainingMillis));
            } catch (InterruptedException ex) {
                // Restore the interrupt flag so code further up the stack can observe it.
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting for " + description.get(), ex);
            }
        }
    }

    /**
     * Registers the "/echo" and "/test" SockJS endpoints, both using a
     * {@link DefaultHandshakeHandler} built from the autowired upgrade strategy.
     */
    @Configuration
    @EnableWebSocket
    static class TestConfig implements WebSocketConfigurer {

        @Autowired
        private RequestUpgradeStrategy upgradeStrategy;

        @Override
        public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
            HandshakeHandler handshakeHandler = new DefaultHandshakeHandler(this.upgradeStrategy);
            registry.addHandler(new EchoHandler(), "/echo").setHandshakeHandler(handshakeHandler).withSockJS();
            registry.addHandler(testServerHandler(), "/test").setHandshakeHandler(handshakeHandler).withSockJS();
        }

        @Bean
        public TestServerHandler testServerHandler() {
            return new TestServerHandler();
        }
    }

    /**
     * Client-side handler that records the session, every received text message,
     * and the last transport error.
     */
    private static class TestClientHandler extends TextWebSocketHandler {

        private final BlockingQueue<TextMessage> receivedMessages = new LinkedBlockingQueue<>();

        private volatile WebSocketSession session;

        private volatile Throwable transportError;

        @Override
        public void afterConnectionEstablished(WebSocketSession session) throws Exception {
            this.session = session;
        }

        @Override
        protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
            this.receivedMessages.add(message);
        }

        @Override
        public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
            this.transportError = exception;
        }

        /** Wait until at least {@code count} messages have been received. */
        public void awaitMessageCount(final int count, long timeToWait) throws Exception {
            // The description is a supplier so "received so far" shows the state at failure time.
            awaitEvent(() -> this.receivedMessages.size() >= count, timeToWait,
                    () -> count + " number of messages. Received so far: " + this.receivedMessages);
        }

        /**
         * Take the next received message and assert that it equals {@code expected}.
         * On timeout, fail with the recorded transport error if there is one.
         */
        public void awaitMessage(TextMessage expected, long timeToWait) throws InterruptedException {
            TextMessage actual = this.receivedMessages.poll(timeToWait, TimeUnit.MILLISECONDS);
            if (actual != null) {
                assertEquals(expected, actual);
            } else if (this.transportError != null) {
                throw new AssertionError("Transport error", this.transportError);
            } else {
                fail("Timed out waiting for [" + expected + "]");
            }
        }
    }

    /** Server-side handler that sends every received text message back to its sender. */
    private static class EchoHandler extends TextWebSocketHandler {

        @Override
        protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
            session.sendMessage(message);
        }
    }

    /** Server-side handler that records the established session for the test to use. */
    private static class TestServerHandler extends TextWebSocketHandler {

        // Written by a server thread and polled by the test thread, hence volatile.
        private volatile WebSocketSession session;

        @Override
        public void afterConnectionEstablished(WebSocketSession session) throws Exception {
            this.session = session;
        }

        /** Wait for the session to be established and return it. */
        public WebSocketSession awaitSession(long timeToWait) throws InterruptedException {
            awaitEvent(() -> this.session != null, timeToWait, () -> "session");
            return this.session;
        }
    }

    /**
     * Servlet filter that records the headers of each request by URI. It can also
     * delay, or answer with an error status, any request whose URI ends with a
     * configured suffix. The maps are concurrent because the test thread populates
     * and reads them while server threads handle requests.
     */
    private static class TestFilter implements Filter {

        private final Map<String, HttpHeaders> requests = new ConcurrentHashMap<>();

        private final Map<String, Long> sleepDelayMap = new ConcurrentHashMap<>();

        private final Map<String, Integer> sendErrorMap = new ConcurrentHashMap<>();

        @Override
        public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
                throws IOException, ServletException {

            HttpServletRequest httpRequest = (HttpServletRequest) request;
            String uri = httpRequest.getRequestURI();
            HttpHeaders headers = new ServletServerHttpRequest(httpRequest).getHeaders();
            this.requests.put(uri, headers);

            // Apply at most one configured delay.
            for (Map.Entry<String, Long> entry : this.sleepDelayMap.entrySet()) {
                if (uri.endsWith(entry.getKey())) {
                    try {
                        Thread.sleep(entry.getValue());
                    } catch (InterruptedException ex) {
                        Thread.currentThread().interrupt();
                    }
                    break;
                }
            }
            for (Map.Entry<String, Integer> entry : this.sendErrorMap.entrySet()) {
                if (uri.endsWith(entry.getKey())) {
                    ((HttpServletResponse) response).sendError(entry.getValue());
                    return;
                }
            }
            chain.doFilter(request, response);
        }

        @Override
        public void init(FilterConfig filterConfig) {
        }

        @Override
        public void destroy() {
        }
    }

}