From 87de6cea1bff2e8700400f68186b014c1b75d32f Mon Sep 17 00:00:00 2001
From: Josh Cummings <3627351+jzheaux@users.noreply.github.com>
Date: Fri, 6 Dec 2024 15:13:48 -0700
Subject: [PATCH] Use Reactive JSON Encoder

Closes gh-16177
---
 .../web/server/HttpMessageConverters.java     | 65 +++++++++++++
 .../config/web/server/OAuth2ErrorEncoder.java | 95 +++++++++++++++++++
 .../OidcBackChannelLogoutWebFilter.java       | 33 +++----
 .../OidcBackChannelServerLogoutHandler.java   | 31 +++---
 4 files changed, 190 insertions(+), 34 deletions(-)
 create mode 100644 config/src/main/java/org/springframework/security/config/web/server/HttpMessageConverters.java
 create mode 100644 config/src/main/java/org/springframework/security/config/web/server/OAuth2ErrorEncoder.java

diff --git a/config/src/main/java/org/springframework/security/config/web/server/HttpMessageConverters.java b/config/src/main/java/org/springframework/security/config/web/server/HttpMessageConverters.java
new file mode 100644
index 00000000000..cda41044cac
--- /dev/null
+++ b/config/src/main/java/org/springframework/security/config/web/server/HttpMessageConverters.java
@@ -0,0 +1,65 @@
+/*
+ * Copyright 2002-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.web.server;
+
+import org.springframework.http.converter.GenericHttpMessageConverter;
+import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.http.converter.json.GsonHttpMessageConverter;
+import org.springframework.http.converter.json.JsonbHttpMessageConverter;
+import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
+import org.springframework.util.ClassUtils;
+
+/**
+ * Utility methods for {@link HttpMessageConverter}'s.
+ *
+ * @author Joe Grandja
+ * @author luamas
+ * @since 5.1
+ */
+final class HttpMessageConverters {
+
+	private static final boolean jackson2Present;
+
+	private static final boolean gsonPresent;
+
+	private static final boolean jsonbPresent;
+
+	static {
+		ClassLoader classLoader = HttpMessageConverters.class.getClassLoader();
+		jackson2Present = ClassUtils.isPresent("com.fasterxml.jackson.databind.ObjectMapper", classLoader)
+				&& ClassUtils.isPresent("com.fasterxml.jackson.core.JsonGenerator", classLoader);
+		gsonPresent = ClassUtils.isPresent("com.google.gson.Gson", classLoader);
+		jsonbPresent = ClassUtils.isPresent("jakarta.json.bind.Jsonb", classLoader);
+	}
+
+	private HttpMessageConverters() {
+	}
+
+	static GenericHttpMessageConverter<Object> getJsonMessageConverter() {
+		if (jackson2Present) {
+			return new MappingJackson2HttpMessageConverter();
+		}
+		if (gsonPresent) {
+			return new GsonHttpMessageConverter();
+		}
+		if (jsonbPresent) {
+			return new JsonbHttpMessageConverter();
+		}
+		return null;
+	}
+
+}
diff --git a/config/src/main/java/org/springframework/security/config/web/server/OAuth2ErrorEncoder.java b/config/src/main/java/org/springframework/security/config/web/server/OAuth2ErrorEncoder.java
new file mode 100644
index 00000000000..784782344cf
--- /dev/null
+++ b/config/src/main/java/org/springframework/security/config/web/server/OAuth2ErrorEncoder.java
@@ -0,0 +1,95 @@
+/*
+ * Copyright 2002-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.web.server;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import org.jetbrains.annotations.NotNull;
+import org.reactivestreams.Publisher;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+
+import org.springframework.core.ResolvableType;
+import org.springframework.core.io.buffer.DataBuffer;
+import org.springframework.core.io.buffer.DataBufferFactory;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpOutputMessage;
+import org.springframework.http.MediaType;
+import org.springframework.http.codec.HttpMessageEncoder;
+import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.util.MimeType;
+
+class OAuth2ErrorEncoder implements HttpMessageEncoder<OAuth2Error> {
+
+	private final HttpMessageConverter<Object> messageConverter = HttpMessageConverters.getJsonMessageConverter();
+
+	@NotNull
+	@Override
+	public List<MediaType> getStreamingMediaTypes() {
+		return List.of();
+	}
+
+	@Override
+	public boolean canEncode(ResolvableType elementType, MimeType mimeType) {
+		return getEncodableMimeTypes().contains(mimeType);
+	}
+
+	@NotNull
+	@Override
+	public Flux<DataBuffer> encode(Publisher<? extends OAuth2Error> error, DataBufferFactory bufferFactory,
+			ResolvableType elementType, MimeType mimeType, Map<String, Object> hints) {
+		return Mono.from(error).flatMap((data) -> {
+			ByteArrayHttpOutputMessage bytes = new ByteArrayHttpOutputMessage();
+			try {
+				this.messageConverter.write(data, MediaType.APPLICATION_JSON, bytes);
+				return Mono.just(bytes.getBody().toByteArray());
+			}
+			catch (IOException ex) {
+				return Mono.error(ex);
+			}
+		}).map(bufferFactory::wrap).flux();
+	}
+
+	@NotNull
+	@Override
+	public List<MimeType> getEncodableMimeTypes() {
+		return List.of(MediaType.APPLICATION_JSON);
+	}
+
+	private static class ByteArrayHttpOutputMessage implements HttpOutputMessage {
+
+		private final ByteArrayOutputStream body = new ByteArrayOutputStream();
+
+		@NotNull
+		@Override
+		public ByteArrayOutputStream getBody() {
+			return this.body;
+		}
+
+		@NotNull
+		@Override
+		public HttpHeaders getHeaders() {
+			return new HttpHeaders();
+		}
+
+	}
+
+}
diff --git a/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutWebFilter.java b/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutWebFilter.java
index 74f5f32e687..8f1788c498f 100644
--- a/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutWebFilter.java
+++ b/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutWebFilter.java
@@ -16,16 +16,17 @@
 
 package org.springframework.security.config.web.server;
 
-import java.nio.charset.StandardCharsets;
+import java.util.Collections;
 
 import jakarta.servlet.http.HttpServletResponse;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
-import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
 
-import org.springframework.core.io.buffer.DataBuffer;
-import org.springframework.http.server.reactive.ServerHttpResponse;
+import org.springframework.core.ResolvableType;
+import org.springframework.http.MediaType;
+import org.springframework.http.codec.EncoderHttpMessageWriter;
+import org.springframework.http.codec.HttpMessageWriter;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.AuthenticationServiceException;
 import org.springframework.security.authentication.ReactiveAuthenticationManager;
@@ -62,6 +63,9 @@ class OidcBackChannelLogoutWebFilter implements WebFilter {
 
 	private ServerLogoutHandler logoutHandler = new OidcBackChannelServerLogoutHandler();
 
+	private final HttpMessageWriter<OAuth2Error> errorHttpMessageConverter = new EncoderHttpMessageWriter<>(
+			new OAuth2ErrorEncoder());
+
 	/**
 	 * Construct an {@link OidcBackChannelLogoutWebFilter}
 	 * @param authenticationConverter the {@link AuthenticationConverter} for deriving
@@ -84,7 +88,7 @@ public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
 			if (ex instanceof AuthenticationServiceException) {
 				return Mono.error(ex);
 			}
-			return handleAuthenticationFailure(exchange.getResponse(), ex).then(Mono.empty());
+			return handleAuthenticationFailure(exchange, ex).then(Mono.empty());
 		})
 			.switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
 			.flatMap(this.authenticationManager::authenticate)
@@ -93,7 +97,7 @@ public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
 				if (ex instanceof AuthenticationServiceException) {
 					return Mono.error(ex);
 				}
-				return handleAuthenticationFailure(exchange.getResponse(), ex).then(Mono.empty());
+				return handleAuthenticationFailure(exchange, ex).then(Mono.empty());
 			})
 			.flatMap((authentication) -> {
 				WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain);
@@ -101,19 +105,12 @@ public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
 			});
 	}
 
-	private Mono<Void> handleAuthenticationFailure(ServerHttpResponse response, Exception ex) {
+	private Mono<Void> handleAuthenticationFailure(ServerWebExchange exchange, Exception ex) {
 		this.logger.debug("Failed to process OIDC Back-Channel Logout", ex);
-		response.setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST);
-		OAuth2Error error = oauth2Error(ex);
-		byte[] bytes = String.format("""
-				{
-					"error_code": "%s",
-					"error_description": "%s",
-					"error_uri: "%s"
-				}
-				""", error.getErrorCode(), error.getDescription(), error.getUri()).getBytes(StandardCharsets.UTF_8);
-		DataBuffer buffer = response.bufferFactory().wrap(bytes);
-		return response.writeWith(Flux.just(buffer));
+		exchange.getResponse().setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST);
+		return this.errorHttpMessageConverter.write(Mono.just(oauth2Error(ex)), ResolvableType.forClass(Object.class),
+				ResolvableType.forClass(Object.class), MediaType.APPLICATION_JSON, exchange.getRequest(),
+				exchange.getResponse(), Collections.emptyMap());
 	}
 
 	private OAuth2Error oauth2Error(Exception ex) {
diff --git a/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelServerLogoutHandler.java b/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelServerLogoutHandler.java
index 5312a6da7c4..c0c1e73bc61 100644
--- a/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelServerLogoutHandler.java
+++ b/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelServerLogoutHandler.java
@@ -16,8 +16,8 @@
 
 package org.springframework.security.config.web.server;
 
-import java.nio.charset.StandardCharsets;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -25,14 +25,15 @@
 import jakarta.servlet.http.HttpServletResponse;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
-import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
 
-import org.springframework.core.io.buffer.DataBuffer;
+import org.springframework.core.ResolvableType;
 import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
 import org.springframework.http.ResponseEntity;
+import org.springframework.http.codec.EncoderHttpMessageWriter;
+import org.springframework.http.codec.HttpMessageWriter;
 import org.springframework.http.server.reactive.ServerHttpRequest;
-import org.springframework.http.server.reactive.ServerHttpResponse;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken;
 import org.springframework.security.oauth2.client.oidc.server.session.InMemoryReactiveOidcSessionRegistry;
@@ -44,6 +45,7 @@
 import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler;
 import org.springframework.util.Assert;
 import org.springframework.web.reactive.function.client.WebClient;
+import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.util.UriComponents;
 import org.springframework.web.util.UriComponentsBuilder;
 
@@ -63,6 +65,9 @@ final class OidcBackChannelServerLogoutHandler implements ServerLogoutHandler {
 
 	private ReactiveOidcSessionRegistry sessionRegistry = new InMemoryReactiveOidcSessionRegistry();
 
+	private final HttpMessageWriter<OAuth2Error> errorHttpMessageConverter = new EncoderHttpMessageWriter<>(
+			new OAuth2ErrorEncoder());
+
 	private WebClient web = WebClient.create();
 
 	private String logoutUri = "{baseScheme}://localhost{basePort}/logout";
@@ -97,7 +102,7 @@ public Mono<Void> logout(WebFilterExchange exchange, Authentication authenticati
 						totalCount.intValue()));
 			}
 			if (!list.isEmpty()) {
-				return handleLogoutFailure(exchange.getExchange().getResponse(), oauth2Error(list));
+				return handleLogoutFailure(exchange.getExchange(), oauth2Error(list));
 			}
 			else {
 				return Mono.empty();
@@ -148,17 +153,11 @@ private OAuth2Error oauth2Error(Collection<?> errors) {
 				"https://openid.net/specs/openid-connect-backchannel-1_0.html#Validation");
 	}
 
-	private Mono<Void> handleLogoutFailure(ServerHttpResponse response, OAuth2Error error) {
-		response.setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST);
-		byte[] bytes = String.format("""
-				{
-					"error_code": "%s",
-					"error_description": "%s",
-					"error_uri: "%s"
-				}
-				""", error.getErrorCode(), error.getDescription(), error.getUri()).getBytes(StandardCharsets.UTF_8);
-		DataBuffer buffer = response.bufferFactory().wrap(bytes);
-		return response.writeWith(Flux.just(buffer));
+	private Mono<Void> handleLogoutFailure(ServerWebExchange exchange, OAuth2Error error) {
+		exchange.getResponse().setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST);
+		return this.errorHttpMessageConverter.write(Mono.just(error), ResolvableType.forClass(Object.class),
+				ResolvableType.forClass(Object.class), MediaType.APPLICATION_JSON, exchange.getRequest(),
+				exchange.getResponse(), Collections.emptyMap());
 	}
 
 	/**