Skip to content

Commit

Permalink
Puts client response input stream in a request attribute.
Browse files Browse the repository at this point in the history
Fixes gh-3405
  • Loading branch information
spencergibb committed May 23, 2024
1 parent 58c8441 commit a54297b
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ public abstract class MvcUtils {
*/
public static final String CACHED_REQUEST_BODY_ATTR = qualify("cachedRequestBody");

/**
* Client response input stream key.
*/
public static final String CLIENT_RESPONSE_INPUT_STREAM_ATTR = qualify("cachedClientResponseBody");

/**
* CircuitBreaker execution exception attribute name.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
package org.springframework.cloud.gateway.server.mvc.handler;

import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;

import org.springframework.cloud.gateway.server.mvc.common.MvcUtils;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpResponse;
Expand All @@ -41,10 +43,18 @@ public ServerResponse exchange(Request request) {
// copy body from request to clientHttpRequest
StreamUtils.copy(request.getServerRequest().servletRequest().getInputStream(), clientHttpRequest.getBody());
ClientHttpResponse clientHttpResponse = clientHttpRequest.execute();
InputStream body = clientHttpResponse.getBody();
// put the body input stream in a request attribute so filters can read it.
MvcUtils.putAttribute(request.getServerRequest(), MvcUtils.CLIENT_RESPONSE_INPUT_STREAM_ATTR, body);
ServerResponse serverResponse = GatewayServerResponse.status(clientHttpResponse.getStatusCode())
.build((req, httpServletResponse) -> {
try (clientHttpResponse) {
StreamUtils.copy(clientHttpResponse.getBody(), httpServletResponse.getOutputStream());
// get input stream from request attribute in case it was
// modified.
InputStream inputStream = MvcUtils.getAttribute(request.getServerRequest(),
MvcUtils.CLIENT_RESPONSE_INPUT_STREAM_ATTR);
// copy body from request to clientHttpRequest
StreamUtils.copy(inputStream, httpServletResponse.getOutputStream());
}
return null;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
package org.springframework.cloud.gateway.server.mvc.handler;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

import org.springframework.cloud.gateway.server.mvc.common.MvcUtils;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.StreamUtils;
import org.springframework.web.client.RestClient;
Expand All @@ -45,11 +47,18 @@ private static int copyBody(Request request, OutputStream outputStream) throws I
}

private static ServerResponse doExchange(Request request, ClientHttpResponse clientResponse) throws IOException {
InputStream body = clientResponse.getBody();
// put the body input stream in a request attribute so filters can read it.
MvcUtils.putAttribute(request.getServerRequest(), MvcUtils.CLIENT_RESPONSE_INPUT_STREAM_ATTR, body);
ServerResponse serverResponse = GatewayServerResponse.status(clientResponse.getStatusCode())
.build((req, httpServletResponse) -> {
try (clientResponse) {
// get input stream from request attribute in case it was
// modified.
InputStream inputStream = MvcUtils.getAttribute(request.getServerRequest(),
MvcUtils.CLIENT_RESPONSE_INPUT_STREAM_ATTR);
// copy body from request to clientHttpRequest
StreamUtils.copy(clientResponse.getBody(), httpServletResponse.getOutputStream());
StreamUtils.copy(inputStream, httpServletResponse.getOutputStream());
}
return null;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

package org.springframework.cloud.gateway.server.mvc;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
Expand Down Expand Up @@ -75,6 +77,7 @@
import org.springframework.test.context.ContextConfiguration;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StreamUtils;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
Expand Down Expand Up @@ -696,6 +699,15 @@ public void queryParamWithSpecialCharactersWorks() {
});
}

@Test
public void clientResponseBodyAttributeWorks() {
restClient.get().uri("/anything/readresponsebody").header("X-Foo", "fooval").exchange().expectStatus().isOk()
.expectBody(Map.class).consumeWith(res -> {
Map<String, Object> headers = getMap(res.getResponseBody(), "headers");
assertThat(headers).containsEntry("X-Foo", "FOOVAL");
});
}

@SpringBootConfiguration
@EnableAutoConfiguration
@LoadBalancerClient(name = "httpbin", configuration = TestLoadBalancerConfig.Httpbin.class)
Expand Down Expand Up @@ -1289,6 +1301,32 @@ public RouterFunction<ServerResponse> gatewayRouterFunctionsQuery() {
// @formatter:on
}

@Bean
public RouterFunction<ServerResponse> gatewayRouterFunctionsReadResponseBody() {
// @formatter:off
return route("testClientResponseBodyAttribute")
.GET("/anything/readresponsebody", http())
.before(new HttpbinUriResolver())
.after((request, response) -> {
Object o = request.attributes().get(MvcUtils.CLIENT_RESPONSE_INPUT_STREAM_ATTR);
if (o instanceof InputStream) {
try {
byte[] bytes = StreamUtils.copyToByteArray((InputStream) o);
String s = new String(bytes, StandardCharsets.UTF_8);
String replace = s.replace("fooval", "FOOVAL");
ByteArrayInputStream bais = new ByteArrayInputStream(replace.getBytes());
request.attributes().put(MvcUtils.CLIENT_RESPONSE_INPUT_STREAM_ATTR, bais);
}
catch (IOException e) {
throw new RuntimeException(e);
}
}
return response;
})
.build();
// @formatter:on
}

@Bean
public FilterRegistrationBean myFilter() {
FilterRegistrationBean<MyFilter> reg = new FilterRegistrationBean<>(new MyFilter());
Expand Down

0 comments on commit a54297b

Please sign in to comment.