Skip to content

Commit

Permalink
Remove OAuth2AuthorizationRequest when too many per registration id
Browse files Browse the repository at this point in the history
  • Loading branch information
candrews committed Apr 13, 2021
1 parent 18978e6 commit 5f720ff
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.stream.Collectors;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
Expand Down Expand Up @@ -59,7 +61,7 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository

private Duration authorizationRequestTimeToLive = Duration.ofSeconds(120);

private int maxActiveAuthorizationRequestsPerSession = 10;
private int maxActiveAuthorizationRequestsPerRegistrationIdPerSession = 3;

@Override
public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
Expand Down Expand Up @@ -87,10 +89,14 @@ public void saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationReq
Map<String, OAuth2AuthorizationRequestReference> authorizationRequests = this.getAuthorizationRequests(request);
authorizationRequests.put(state, new OAuth2AuthorizationRequestReference(authorizationRequest,
this.clock.instant().plus(this.authorizationRequestTimeToLive)));
if (authorizationRequests.size() > this.maxActiveAuthorizationRequestsPerSession) {
authorizationRequests.entrySet().stream()
.sorted((e, f) -> e.getValue().expiresAt.compareTo(f.getValue().expiresAt)).findFirst()
.map(Entry::getKey).ifPresent(authorizationRequests::remove);
for (String registrationId : authorizationRequests.values().stream().map((r) -> r.getRegistrationId())
.distinct().collect(Collectors.toList())) {
List<OAuth2AuthorizationRequestReference> references = authorizationRequests.values().stream()
.filter((r) -> Objects.equals(registrationId, r.getRegistrationId())).collect(Collectors.toList());
if (references.size() > this.maxActiveAuthorizationRequestsPerRegistrationIdPerSession) {
references.stream().sorted((a, b) -> a.expiresAt.compareTo(b.expiresAt)).findFirst()
.map((r) -> r.getState()).ifPresent(authorizationRequests::remove);
}
}
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
}
Expand Down Expand Up @@ -177,14 +183,16 @@ void setAuthorizationRequestTimeToLive(Duration authorizationRequestTimeToLive)

/**
* Sets the maximum number of {@link OAuth2AuthorizationRequest} that can be
* stored/active for a session. If the maximum number are present in a session when an
* attempt is made to save another one, then the oldest will be removed.
* stored/active per registration id for a session. If the maximum number are present
* in a session when an attempt is made to save another one, then the oldest will be
* removed.
* @param maxActiveAuthorizationRequestsPerSession must not be negative.
*/
void setMaxActiveAuthorizationRequestsPerSession(int maxActiveAuthorizationRequestsPerSession) {
Assert.state(maxActiveAuthorizationRequestsPerSession > 0,
"maxActiveAuthorizationRequestsPerSession must be greater than zero");
this.maxActiveAuthorizationRequestsPerSession = maxActiveAuthorizationRequestsPerSession;
void setMaxActiveAuthorizationRequestsPerRegistrationIdPerSession(
int maxActiveAuthorizationRequestsPerRegistrationIdPerSession) {
Assert.state(maxActiveAuthorizationRequestsPerRegistrationIdPerSession > 0,
"maxActiveAuthorizationRequestsPerRegistrationIdPerSession must be greater than zero");
this.maxActiveAuthorizationRequestsPerRegistrationIdPerSession = maxActiveAuthorizationRequestsPerRegistrationIdPerSession;
}

private static final class OAuth2AuthorizationRequestReference implements Serializable {
Expand All @@ -202,6 +210,14 @@ private OAuth2AuthorizationRequestReference(OAuth2AuthorizationRequest authoriza
this.authorizationRequest = authorizationRequest;
}

private String getRegistrationId() {
return this.authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID);
}

private String getState() {
return this.authorizationRequest.getState();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -270,17 +271,21 @@ public void removeAuthorizationRequestWhenExpired() {

@Test
public void removeOldestAuthorizationRequestWhenMoreThanMax() {
this.authorizationRequestRepository.setMaxActiveAuthorizationRequestsPerSession(2);
String registrationId = "registration-id-1";
this.authorizationRequestRepository.setMaxActiveAuthorizationRequestsPerRegistrationIdPerSession(2);
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
String state1 = "state-1122";
OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build();
OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1)
.attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, registrationId)).build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);
String state2 = "state-3344";
OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build();
OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2)
.attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, registrationId)).build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);
String state3 = "state-4455";
OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3).build();
OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3)
.attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, registrationId)).build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);
request.addParameter(OAuth2ParameterNames.STATE, state1);
OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository
Expand All @@ -298,6 +303,42 @@ public void removeOldestAuthorizationRequestWhenMoreThanMax() {
assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
}

@Test
public void doNotremoveOldestAuthorizationRequestWhenLessThanMax() {
this.authorizationRequestRepository.setMaxActiveAuthorizationRequestsPerRegistrationIdPerSession(2);
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
String state1 = "state-1122";
OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1)
.attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, "registration-id-1"))
.build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);
String state2 = "state-3344";
OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2)
.attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, "registration-id-2"))
.build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);
String state3 = "state-4455";
OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3)
.attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, "registration-id-3"))
.build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);
request.addParameter(OAuth2ParameterNames.STATE, state1);
OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository
.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest1).isEqualTo(authorizationRequest1);
request.removeParameter(OAuth2ParameterNames.STATE);
request.addParameter(OAuth2ParameterNames.STATE, state2);
OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository
.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2);
request.removeParameter(OAuth2ParameterNames.STATE);
request.addParameter(OAuth2ParameterNames.STATE, state3);
OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository
.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
}

private OAuth2AuthorizationRequest.Builder createAuthorizationRequest() {
return OAuth2AuthorizationRequest.authorizationCode().authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id-1234").state("state-1234");
Expand Down

0 comments on commit 5f720ff

Please sign in to comment.