diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java index d0a3a2599af8..111bd21a3639 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,8 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.mock.web.MockAsyncContext; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.util.Assert; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.async.CallableProcessingInterceptorAdapter; @@ -35,6 +37,7 @@ import org.springframework.web.servlet.DispatcherServlet; import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.ModelAndView; +import org.springframework.web.util.WebUtils; /** * A sub-class of {@code DispatcherServlet} that saves the result in an @@ -64,8 +67,24 @@ protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { registerAsyncResultInterceptors(request); + super.service(request, response); - initAsyncDispatchLatch(request); + + if (request.getAsyncContext() != null) { + MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class); + Assert.notNull(mockRequest, "Expected MockHttpServletRequest"); + MockAsyncContext mockAsyncContext = ((MockAsyncContext) mockRequest.getAsyncContext()); + Assert.notNull(mockAsyncContext, "MockAsyncContext not found. Did request wrapper not delegate startAsync?"); + + final CountDownLatch dispatchLatch = new CountDownLatch(1); + mockAsyncContext.addDispatchHandler(new Runnable() { + @Override + public void run() { + dispatchLatch.countDown(); + } + }); + getMvcResult(request).setAsyncDispatchLatch(dispatchLatch); + } } private void registerAsyncResultInterceptors(final HttpServletRequest request) { @@ -84,19 +103,6 @@ public void postProcess(NativeWebRequest r, DeferredResult result, Object }); } - private void initAsyncDispatchLatch(HttpServletRequest request) { - if (request.getAsyncContext() != null) { - final CountDownLatch dispatchLatch = new CountDownLatch(1); - ((MockAsyncContext) request.getAsyncContext()).addDispatchHandler(new Runnable() { - @Override - public void run() { - dispatchLatch.countDown(); - } - }); - getMvcResult(request).setAsyncDispatchLatch(dispatchLatch); - } - } - protected DefaultMvcResult getMvcResult(ServletRequest request) { return (DefaultMvcResult) request.getAttribute(MockMvc.MVC_RESULT_ATTRIBUTE); } diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java index e122f717600b..c2a67192ad88 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,9 +18,15 @@ import java.io.IOException; import java.security.Principal; +import java.util.concurrent.CompletableFuture; +import javax.servlet.AsyncContext; +import javax.servlet.AsyncListener; import javax.servlet.Filter; import javax.servlet.FilterChain; +import javax.servlet.ServletContext; import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; @@ -29,17 +35,23 @@ import org.junit.Test; +import org.springframework.http.MediaType; import org.springframework.stereotype.Controller; import org.springframework.test.web.Person; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; import org.springframework.validation.Errors; +import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.mvc.support.RedirectAttributes; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request; import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; /** @@ -107,6 +119,22 @@ public void filterWrapsRequestResponse() throws Exception { .andExpect(model().attribute("principal", WrappingRequestResponseFilter.PRINCIPAL_NAME)); } + @Test // SPR-16695 + public void filterWrapsRequestResponseAndPerformsAsyncDispatch() throws Exception { + MockMvc mockMvc = standaloneSetup(new PersonController()) + .addFilters(new WrappingRequestResponseFilter()) + .build(); + + MvcResult mvcResult = mockMvc.perform(get("/persons/1").accept(MediaType.APPLICATION_JSON)) + .andExpect(request().asyncStarted()) + .andExpect(request().asyncResult(new Person("Lukas"))) + .andReturn(); + + mockMvc.perform(asyncDispatch(mvcResult)) + .andExpect(status().isOk()) + .andExpect(content().string("{\"name\":\"Lukas\",\"someDouble\":0.0,\"someBoolean\":false}")); + } + @Controller private static class PersonController { @@ -129,6 +157,12 @@ public ModelAndView user(Principal principal) { public String forward() { return "forward:/persons"; } + + @GetMapping("persons/{id}") + @ResponseBody + public CompletableFuture getPerson() { + return CompletableFuture.completedFuture(new Person("Lukas")); + } } private class ContinueFilter extends OncePerRequestFilter { @@ -149,15 +183,20 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { filterChain.doFilter(new HttpServletRequestWrapper(request) { + @Override public Principal getUserPrincipal() { - return new Principal() { - @Override - public String getName() { - return PRINCIPAL_NAME; - } - }; + return () -> PRINCIPAL_NAME; } + + // Like Spring Security does in HttpServlet3RequestFactory.. + + @Override + public AsyncContext getAsyncContext() { + return super.getAsyncContext() != null ? + new AsyncContextWrapper(super.getAsyncContext()) : null; + } + }, new HttpServletResponseWrapper(response)); } } @@ -170,4 +209,80 @@ protected void doFilterInternal(HttpServletRequest request, response.sendRedirect("/login"); } } + + + private static class AsyncContextWrapper implements AsyncContext { + + private final AsyncContext delegate; + + public AsyncContextWrapper(AsyncContext delegate) { + this.delegate = delegate; + } + + @Override + public ServletRequest getRequest() { + return this.delegate.getRequest(); + } + + @Override + public ServletResponse getResponse() { + return this.delegate.getResponse(); + } + + @Override + public boolean hasOriginalRequestAndResponse() { + return this.delegate.hasOriginalRequestAndResponse(); + } + + @Override + public void dispatch() { + this.delegate.dispatch(); + } + + @Override + public void dispatch(String path) { + this.delegate.dispatch(path); + } + + @Override + public void dispatch(ServletContext context, String path) { + this.delegate.dispatch(context, path); + } + + @Override + public void complete() { + this.delegate.complete(); + } + + @Override + public void start(Runnable run) { + this.delegate.start(run); + } + + @Override + public void addListener(AsyncListener listener) { + this.delegate.addListener(listener); + } + + @Override + public void addListener(AsyncListener listener, ServletRequest req, ServletResponse res) { + this.delegate.addListener(listener, req, res); + } + + @Override + public T createListener(Class clazz) throws ServletException { + return this.delegate.createListener(clazz); + } + + @Override + public void setTimeout(long timeout) { + this.delegate.setTimeout(timeout); + } + + @Override + public long getTimeout() { + return this.delegate.getTimeout(); + } + } + }