Skip to content

Commit

Permalink
TestDispatcherServlet unwraps to find mock request
Browse files Browse the repository at this point in the history
Issue: SPR-16695
  • Loading branch information
rstoyanchev committed Apr 6, 2018
1 parent d3acf45 commit 6deee3e
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,6 +25,8 @@
import javax.servlet.http.HttpServletResponse;

import org.springframework.mock.web.MockAsyncContext;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.util.Assert;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.context.request.async.CallableProcessingInterceptorAdapter;
Expand All @@ -35,6 +37,7 @@
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.util.WebUtils;

/**
* A sub-class of {@code DispatcherServlet} that saves the result in an
Expand Down Expand Up @@ -64,8 +67,24 @@ protected void service(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {

registerAsyncResultInterceptors(request);

super.service(request, response);
initAsyncDispatchLatch(request);

if (request.getAsyncContext() != null) {
MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class);
Assert.notNull(mockRequest, "Expected MockHttpServletRequest");
MockAsyncContext mockAsyncContext = ((MockAsyncContext) mockRequest.getAsyncContext());
Assert.notNull(mockAsyncContext, "MockAsyncContext not found. Did request wrapper not delegate startAsync?");

final CountDownLatch dispatchLatch = new CountDownLatch(1);
mockAsyncContext.addDispatchHandler(new Runnable() {
@Override
public void run() {
dispatchLatch.countDown();
}
});
getMvcResult(request).setAsyncDispatchLatch(dispatchLatch);
}
}

private void registerAsyncResultInterceptors(final HttpServletRequest request) {
Expand All @@ -84,19 +103,6 @@ public <T> void postProcess(NativeWebRequest r, DeferredResult<T> result, Object
});
}

private void initAsyncDispatchLatch(HttpServletRequest request) {
if (request.getAsyncContext() != null) {
final CountDownLatch dispatchLatch = new CountDownLatch(1);
((MockAsyncContext) request.getAsyncContext()).addDispatchHandler(new Runnable() {
@Override
public void run() {
dispatchLatch.countDown();
}
});
getMvcResult(request).setAsyncDispatchLatch(dispatchLatch);
}
}

protected DefaultMvcResult getMvcResult(ServletRequest request) {
return (DefaultMvcResult) request.getAttribute(MockMvc.MVC_RESULT_ATTRIBUTE);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,9 +18,15 @@

import java.io.IOException;
import java.security.Principal;
import java.util.concurrent.CompletableFuture;
import javax.servlet.AsyncContext;
import javax.servlet.AsyncListener;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
Expand All @@ -29,17 +35,23 @@

import org.junit.Test;

import org.springframework.http.MediaType;
import org.springframework.stereotype.Controller;
import org.springframework.test.web.Person;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.validation.Errors;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.mvc.support.RedirectAttributes;

import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request;
import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*;

/**
Expand Down Expand Up @@ -107,6 +119,22 @@ public void filterWrapsRequestResponse() throws Exception {
.andExpect(model().attribute("principal", WrappingRequestResponseFilter.PRINCIPAL_NAME));
}

@Test // SPR-16695
public void filterWrapsRequestResponseAndPerformsAsyncDispatch() throws Exception {
MockMvc mockMvc = standaloneSetup(new PersonController())
.addFilters(new WrappingRequestResponseFilter())
.build();

MvcResult mvcResult = mockMvc.perform(get("/persons/1").accept(MediaType.APPLICATION_JSON))
.andExpect(request().asyncStarted())
.andExpect(request().asyncResult(new Person("Lukas")))
.andReturn();

mockMvc.perform(asyncDispatch(mvcResult))
.andExpect(status().isOk())
.andExpect(content().string("{\"name\":\"Lukas\",\"someDouble\":0.0,\"someBoolean\":false}"));
}


@Controller
private static class PersonController {
Expand All @@ -129,6 +157,12 @@ public ModelAndView user(Principal principal) {
public String forward() {
return "forward:/persons";
}

@GetMapping("persons/{id}")
@ResponseBody
public CompletableFuture<Person> getPerson() {
return CompletableFuture.completedFuture(new Person("Lukas"));
}
}

private class ContinueFilter extends OncePerRequestFilter {
Expand All @@ -149,15 +183,20 @@ protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {

filterChain.doFilter(new HttpServletRequestWrapper(request) {

@Override
public Principal getUserPrincipal() {
return new Principal() {
@Override
public String getName() {
return PRINCIPAL_NAME;
}
};
return () -> PRINCIPAL_NAME;
}

// Like Spring Security does in HttpServlet3RequestFactory..

@Override
public AsyncContext getAsyncContext() {
return super.getAsyncContext() != null ?
new AsyncContextWrapper(super.getAsyncContext()) : null;
}

}, new HttpServletResponseWrapper(response));
}
}
Expand All @@ -170,4 +209,80 @@ protected void doFilterInternal(HttpServletRequest request,
response.sendRedirect("/login");
}
}


private static class AsyncContextWrapper implements AsyncContext {

private final AsyncContext delegate;

public AsyncContextWrapper(AsyncContext delegate) {
this.delegate = delegate;
}

@Override
public ServletRequest getRequest() {
return this.delegate.getRequest();
}

@Override
public ServletResponse getResponse() {
return this.delegate.getResponse();
}

@Override
public boolean hasOriginalRequestAndResponse() {
return this.delegate.hasOriginalRequestAndResponse();
}

@Override
public void dispatch() {
this.delegate.dispatch();
}

@Override
public void dispatch(String path) {
this.delegate.dispatch(path);
}

@Override
public void dispatch(ServletContext context, String path) {
this.delegate.dispatch(context, path);
}

@Override
public void complete() {
this.delegate.complete();
}

@Override
public void start(Runnable run) {
this.delegate.start(run);
}

@Override
public void addListener(AsyncListener listener) {
this.delegate.addListener(listener);
}

@Override
public void addListener(AsyncListener listener, ServletRequest req, ServletResponse res) {
this.delegate.addListener(listener, req, res);
}

@Override
public <T extends AsyncListener> T createListener(Class<T> clazz) throws ServletException {
return this.delegate.createListener(clazz);
}

@Override
public void setTimeout(long timeout) {
this.delegate.setTimeout(timeout);
}

@Override
public long getTimeout() {
return this.delegate.getTimeout();
}
}

}

0 comments on commit 6deee3e

Please sign in to comment.