From a54297bd55013dc0055ceaa26290bf32bc217204 Mon Sep 17 00:00:00 2001 From: sgibb Date: Thu, 23 May 2024 16:18:49 -0400 Subject: [PATCH] Puts client response input stream in a request attribute. Fixes gh-3405 --- .../gateway/server/mvc/common/MvcUtils.java | 5 +++ ...ClientHttpRequestFactoryProxyExchange.java | 12 +++++- .../mvc/handler/RestClientProxyExchange.java | 11 +++++- .../server/mvc/ServerMvcIntegrationTests.java | 38 +++++++++++++++++++ 4 files changed, 64 insertions(+), 2 deletions(-) diff --git a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/common/MvcUtils.java b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/common/MvcUtils.java index 1afb1c52cc..51a1edd1f6 100644 --- a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/common/MvcUtils.java +++ b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/common/MvcUtils.java @@ -57,6 +57,11 @@ public abstract class MvcUtils { */ public static final String CACHED_REQUEST_BODY_ATTR = qualify("cachedRequestBody"); + /** + * Client response input stream key. + */ + public static final String CLIENT_RESPONSE_INPUT_STREAM_ATTR = qualify("cachedClientResponseBody"); + /** * CircuitBreaker execution exception attribute name. */ diff --git a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/ClientHttpRequestFactoryProxyExchange.java b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/ClientHttpRequestFactoryProxyExchange.java index fd8b85fa7f..0a35a59cfd 100644 --- a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/ClientHttpRequestFactoryProxyExchange.java +++ b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/ClientHttpRequestFactoryProxyExchange.java @@ -17,8 +17,10 @@ package org.springframework.cloud.gateway.server.mvc.handler; import java.io.IOException; +import java.io.InputStream; import java.io.UncheckedIOException; +import org.springframework.cloud.gateway.server.mvc.common.MvcUtils; import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpResponse; @@ -41,10 +43,18 @@ public ServerResponse exchange(Request request) { // copy body from request to clientHttpRequest StreamUtils.copy(request.getServerRequest().servletRequest().getInputStream(), clientHttpRequest.getBody()); ClientHttpResponse clientHttpResponse = clientHttpRequest.execute(); + InputStream body = clientHttpResponse.getBody(); + // put the body input stream in a request attribute so filters can read it. + MvcUtils.putAttribute(request.getServerRequest(), MvcUtils.CLIENT_RESPONSE_INPUT_STREAM_ATTR, body); ServerResponse serverResponse = GatewayServerResponse.status(clientHttpResponse.getStatusCode()) .build((req, httpServletResponse) -> { try (clientHttpResponse) { - StreamUtils.copy(clientHttpResponse.getBody(), httpServletResponse.getOutputStream()); + // get input stream from request attribute in case it was + // modified. + InputStream inputStream = MvcUtils.getAttribute(request.getServerRequest(), + MvcUtils.CLIENT_RESPONSE_INPUT_STREAM_ATTR); + // copy body from request to clientHttpRequest + StreamUtils.copy(inputStream, httpServletResponse.getOutputStream()); } return null; }); diff --git a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/RestClientProxyExchange.java b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/RestClientProxyExchange.java index 1f24992472..598dba051f 100644 --- a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/RestClientProxyExchange.java +++ b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/RestClientProxyExchange.java @@ -17,8 +17,10 @@ package org.springframework.cloud.gateway.server.mvc.handler; import java.io.IOException; +import java.io.InputStream; import java.io.OutputStream; +import org.springframework.cloud.gateway.server.mvc.common.MvcUtils; import org.springframework.http.client.ClientHttpResponse; import org.springframework.util.StreamUtils; import org.springframework.web.client.RestClient; @@ -45,11 +47,18 @@ private static int copyBody(Request request, OutputStream outputStream) throws I } private static ServerResponse doExchange(Request request, ClientHttpResponse clientResponse) throws IOException { + InputStream body = clientResponse.getBody(); + // put the body input stream in a request attribute so filters can read it. + MvcUtils.putAttribute(request.getServerRequest(), MvcUtils.CLIENT_RESPONSE_INPUT_STREAM_ATTR, body); ServerResponse serverResponse = GatewayServerResponse.status(clientResponse.getStatusCode()) .build((req, httpServletResponse) -> { try (clientResponse) { + // get input stream from request attribute in case it was + // modified. + InputStream inputStream = MvcUtils.getAttribute(request.getServerRequest(), + MvcUtils.CLIENT_RESPONSE_INPUT_STREAM_ATTR); // copy body from request to clientHttpRequest - StreamUtils.copy(clientResponse.getBody(), httpServletResponse.getOutputStream()); + StreamUtils.copy(inputStream, httpServletResponse.getOutputStream()); } return null; }); diff --git a/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/ServerMvcIntegrationTests.java b/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/ServerMvcIntegrationTests.java index 45fabc15bd..06bfb43745 100644 --- a/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/ServerMvcIntegrationTests.java +++ b/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/ServerMvcIntegrationTests.java @@ -16,7 +16,9 @@ package org.springframework.cloud.gateway.server.mvc; +import java.io.ByteArrayInputStream; import java.io.IOException; +import java.io.InputStream; import java.net.URI; import java.nio.charset.StandardCharsets; import java.time.Duration; @@ -75,6 +77,7 @@ import org.springframework.test.context.ContextConfiguration; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.util.StreamUtils; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; @@ -696,6 +699,15 @@ public void queryParamWithSpecialCharactersWorks() { }); } + @Test + public void clientResponseBodyAttributeWorks() { + restClient.get().uri("/anything/readresponsebody").header("X-Foo", "fooval").exchange().expectStatus().isOk() + .expectBody(Map.class).consumeWith(res -> { + Map headers = getMap(res.getResponseBody(), "headers"); + assertThat(headers).containsEntry("X-Foo", "FOOVAL"); + }); + } + @SpringBootConfiguration @EnableAutoConfiguration @LoadBalancerClient(name = "httpbin", configuration = TestLoadBalancerConfig.Httpbin.class) @@ -1289,6 +1301,32 @@ public RouterFunction gatewayRouterFunctionsQuery() { // @formatter:on } + @Bean + public RouterFunction gatewayRouterFunctionsReadResponseBody() { + // @formatter:off + return route("testClientResponseBodyAttribute") + .GET("/anything/readresponsebody", http()) + .before(new HttpbinUriResolver()) + .after((request, response) -> { + Object o = request.attributes().get(MvcUtils.CLIENT_RESPONSE_INPUT_STREAM_ATTR); + if (o instanceof InputStream) { + try { + byte[] bytes = StreamUtils.copyToByteArray((InputStream) o); + String s = new String(bytes, StandardCharsets.UTF_8); + String replace = s.replace("fooval", "FOOVAL"); + ByteArrayInputStream bais = new ByteArrayInputStream(replace.getBytes()); + request.attributes().put(MvcUtils.CLIENT_RESPONSE_INPUT_STREAM_ATTR, bais); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + return response; + }) + .build(); + // @formatter:on + } + @Bean public FilterRegistrationBean myFilter() { FilterRegistrationBean reg = new FilterRegistrationBean<>(new MyFilter());