From 5f720ff473e6d971cca8ffa05f69de237fa94a12 Mon Sep 17 00:00:00 2001 From: Craig Andrews Date: Tue, 13 Apr 2021 17:24:11 -0400 Subject: [PATCH] Remove OAuth2AuthorizationRequest when too many per registration id --- ...nOAuth2AuthorizationRequestRepository.java | 40 ++++++++++----- ...h2AuthorizationRequestRepositoryTests.java | 49 +++++++++++++++++-- 2 files changed, 73 insertions(+), 16 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java index d1f60efc89e..8a8d89c2286 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java @@ -21,8 +21,10 @@ import java.time.Duration; import java.time.Instant; import java.util.HashMap; +import java.util.List; import java.util.Map; -import java.util.Map.Entry; +import java.util.Objects; +import java.util.stream.Collectors; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -59,7 +61,7 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository private Duration authorizationRequestTimeToLive = Duration.ofSeconds(120); - private int maxActiveAuthorizationRequestsPerSession = 10; + private int maxActiveAuthorizationRequestsPerRegistrationIdPerSession = 3; @Override public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) { @@ -87,10 +89,14 @@ public void saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationReq Map authorizationRequests = this.getAuthorizationRequests(request); authorizationRequests.put(state, new OAuth2AuthorizationRequestReference(authorizationRequest, this.clock.instant().plus(this.authorizationRequestTimeToLive))); - if (authorizationRequests.size() > this.maxActiveAuthorizationRequestsPerSession) { - authorizationRequests.entrySet().stream() - .sorted((e, f) -> e.getValue().expiresAt.compareTo(f.getValue().expiresAt)).findFirst() - .map(Entry::getKey).ifPresent(authorizationRequests::remove); + for (String registrationId : authorizationRequests.values().stream().map((r) -> r.getRegistrationId()) + .distinct().collect(Collectors.toList())) { + List references = authorizationRequests.values().stream() + .filter((r) -> Objects.equals(registrationId, r.getRegistrationId())).collect(Collectors.toList()); + if (references.size() > this.maxActiveAuthorizationRequestsPerRegistrationIdPerSession) { + references.stream().sorted((a, b) -> a.expiresAt.compareTo(b.expiresAt)).findFirst() + .map((r) -> r.getState()).ifPresent(authorizationRequests::remove); + } } request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests); } @@ -177,14 +183,16 @@ void setAuthorizationRequestTimeToLive(Duration authorizationRequestTimeToLive) /** * Sets the maximum number of {@link OAuth2AuthorizationRequest} that can be - * stored/active for a session. If the maximum number are present in a session when an - * attempt is made to save another one, then the oldest will be removed. + * stored/active per registration id for a session. If the maximum number are present + * in a session when an attempt is made to save another one, then the oldest will be + * removed. * @param maxActiveAuthorizationRequestsPerSession must not be negative. */ - void setMaxActiveAuthorizationRequestsPerSession(int maxActiveAuthorizationRequestsPerSession) { - Assert.state(maxActiveAuthorizationRequestsPerSession > 0, - "maxActiveAuthorizationRequestsPerSession must be greater than zero"); - this.maxActiveAuthorizationRequestsPerSession = maxActiveAuthorizationRequestsPerSession; + void setMaxActiveAuthorizationRequestsPerRegistrationIdPerSession( + int maxActiveAuthorizationRequestsPerRegistrationIdPerSession) { + Assert.state(maxActiveAuthorizationRequestsPerRegistrationIdPerSession > 0, + "maxActiveAuthorizationRequestsPerRegistrationIdPerSession must be greater than zero"); + this.maxActiveAuthorizationRequestsPerRegistrationIdPerSession = maxActiveAuthorizationRequestsPerRegistrationIdPerSession; } private static final class OAuth2AuthorizationRequestReference implements Serializable { @@ -202,6 +210,14 @@ private OAuth2AuthorizationRequestReference(OAuth2AuthorizationRequest authoriza this.authorizationRequest = authorizationRequest; } + private String getRegistrationId() { + return this.authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID); + } + + private String getState() { + return this.authorizationRequest.getState(); + } + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java index 7c3cc9c95f5..a5ad1155e8c 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java @@ -20,6 +20,7 @@ import java.time.Duration; import java.time.Instant; import java.time.ZoneId; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -270,17 +271,21 @@ public void removeAuthorizationRequestWhenExpired() { @Test public void removeOldestAuthorizationRequestWhenMoreThanMax() { - this.authorizationRequestRepository.setMaxActiveAuthorizationRequestsPerSession(2); + String registrationId = "registration-id-1"; + this.authorizationRequestRepository.setMaxActiveAuthorizationRequestsPerRegistrationIdPerSession(2); MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); String state1 = "state-1122"; - OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build(); + OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1) + .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, registrationId)).build(); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response); String state2 = "state-3344"; - OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build(); + OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2) + .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, registrationId)).build(); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response); String state3 = "state-4455"; - OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3).build(); + OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3) + .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, registrationId)).build(); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response); request.addParameter(OAuth2ParameterNames.STATE, state1); OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository @@ -298,6 +303,42 @@ public void removeOldestAuthorizationRequestWhenMoreThanMax() { assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3); } + @Test + public void doNotremoveOldestAuthorizationRequestWhenLessThanMax() { + this.authorizationRequestRepository.setMaxActiveAuthorizationRequestsPerRegistrationIdPerSession(2); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + String state1 = "state-1122"; + OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1) + .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, "registration-id-1")) + .build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response); + String state2 = "state-3344"; + OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2) + .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, "registration-id-2")) + .build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response); + String state3 = "state-4455"; + OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3) + .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, "registration-id-3")) + .build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response); + request.addParameter(OAuth2ParameterNames.STATE, state1); + OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest1).isEqualTo(authorizationRequest1); + request.removeParameter(OAuth2ParameterNames.STATE); + request.addParameter(OAuth2ParameterNames.STATE, state2); + OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2); + request.removeParameter(OAuth2ParameterNames.STATE); + request.addParameter(OAuth2ParameterNames.STATE, state3); + OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3); + } + private OAuth2AuthorizationRequest.Builder createAuthorizationRequest() { return OAuth2AuthorizationRequest.authorizationCode().authorizationUri("https://example.com/oauth2/authorize") .clientId("client-id-1234").state("state-1234");