io.spring.initializr.web.test.MockMvcClientHttpRequestFactory.java Source code

Java tutorial

Introduction

Here is the source code for io.spring.initializr.web.test.MockMvcClientHttpRequestFactory.java

Source

/*
 * Copyright 2012-2017 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.spring.initializr.web.test;

import java.io.IOException;
import java.lang.reflect.Method;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import javax.servlet.RequestDispatcher;

import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.mock.http.client.MockClientHttpRequest;
import org.springframework.mock.http.client.MockClientHttpResponse;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.restdocs.snippet.Snippet;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.ResultActions;
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder;
import org.springframework.util.Assert;

import static org.springframework.restdocs.mockmvc.MockMvcRestDocumentation.document;
import static org.springframework.restdocs.operation.preprocess.Preprocessors.preprocessResponse;
import static org.springframework.restdocs.operation.preprocess.Preprocessors.prettyPrint;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.request;

/**
 * @author Dave Syer
 */
public class MockMvcClientHttpRequestFactory implements ClientHttpRequestFactory {

    private final MockMvc mockMvc;

    private String label = "UNKNOWN";

    private List<String> fields = new ArrayList<>();

    public MockMvcClientHttpRequestFactory(MockMvc mockMvc) {
        Assert.notNull(mockMvc, "MockMvc must not be null");
        this.mockMvc = mockMvc;
    }

    @Override
    public ClientHttpRequest createRequest(final URI uri, final HttpMethod httpMethod) throws IOException {
        return new MockClientHttpRequest(httpMethod, uri) {
            @Override
            public ClientHttpResponse executeInternal() throws IOException {
                try {
                    MockHttpServletRequestBuilder requestBuilder = request(httpMethod, uri.toString());
                    requestBuilder.content(getBodyAsBytes());
                    requestBuilder.headers(getHeaders());
                    MockHttpServletResponse servletResponse = actions(requestBuilder).andReturn().getResponse();
                    HttpStatus status = HttpStatus.valueOf(servletResponse.getStatus());
                    if (status.value() >= 400) {
                        requestBuilder = request(HttpMethod.GET, "/error")
                                .requestAttr(RequestDispatcher.ERROR_STATUS_CODE, status.value())
                                .requestAttr(RequestDispatcher.ERROR_REQUEST_URI, uri.toString());
                        if (servletResponse.getErrorMessage() != null) {
                            requestBuilder.requestAttr(RequestDispatcher.ERROR_MESSAGE,
                                    servletResponse.getErrorMessage());
                        }
                        // Overwrites the snippets from the first request
                        servletResponse = actions(requestBuilder).andReturn().getResponse();
                    }
                    byte[] body = servletResponse.getContentAsByteArray();
                    HttpHeaders headers = getResponseHeaders(servletResponse);
                    MockClientHttpResponse clientResponse = new MockClientHttpResponse(body, status);
                    clientResponse.getHeaders().putAll(headers);
                    return clientResponse;
                } catch (Exception ex) {
                    throw new IllegalStateException(ex);
                }
            }

        };
    }

    private ResultActions actions(MockHttpServletRequestBuilder requestBuilder) throws Exception {
        ResultActions actions = MockMvcClientHttpRequestFactory.this.mockMvc.perform(requestBuilder);
        List<Snippet> snippets = new ArrayList<>();
        for (String field : this.fields) {
            snippets.add(new ResponseFieldSnippet(field));
        }
        actions.andDo(document(this.label, preprocessResponse(prettyPrint()), snippets.toArray(new Snippet[0])));
        this.fields = new ArrayList<>();
        return actions;
    }

    private HttpHeaders getResponseHeaders(MockHttpServletResponse response) {
        HttpHeaders headers = new HttpHeaders();
        for (String name : response.getHeaderNames()) {
            List<String> values = response.getHeaders(name);
            for (String value : values) {
                headers.add(name, value);
            }
        }
        return headers;
    }

    public void setTest(Class<?> testClass, Method testMethod) {
        this.label = testMethod.getName();
    }

    public void setFields(String... fields) {
        this.fields = Arrays.asList(fields);
    }

}