Skip to content

Commit

Permalink
HttpSessionOAuth2AuthorizationRequestRepository: store one request by…
Browse files Browse the repository at this point in the history
… default

Add setAllowMultipleAuthorizationRequests allowing applications to
revert to the previous functionality should they need to do so.

Closes gh-5145
Intentionally regresses gh-5110
  • Loading branch information
candrews authored and jgrandja committed May 14, 2021
1 parent 362855b commit ecb4a57
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 56 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,6 +31,7 @@
*
* @author Joe Grandja
* @author Rob Winch
* @author Craig Andrews
* @since 5.0
* @see AuthorizationRequestRepository
* @see OAuth2AuthorizationRequest
Expand All @@ -41,6 +42,8 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au

private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME;

private boolean allowMultipleAuthorizationRequests;

@Override
public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
Assert.notNull(request, "request cannot be null");
Expand All @@ -63,9 +66,14 @@ public void saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationReq
}
String state = authorizationRequest.getState();
Assert.hasText(state, "authorizationRequest.state cannot be empty");
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
authorizationRequests.put(state, authorizationRequest);
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
if (this.allowMultipleAuthorizationRequests) {
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
authorizationRequests.put(state, authorizationRequest);
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
}
else {
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequest);
}
}

@Override
Expand All @@ -77,11 +85,16 @@ public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest
}
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(stateParameter);
if (!authorizationRequests.isEmpty()) {
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
} else {
if (authorizationRequests.size() == 0) {
request.getSession().removeAttribute(this.sessionAttributeName);
}
else if (authorizationRequests.size() == 1) {
request.getSession().setAttribute(this.sessionAttributeName,
authorizationRequests.values().iterator().next());
}
else {
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
}
return originalRequest;
}

Expand All @@ -107,11 +120,38 @@ private String getStateParameter(HttpServletRequest request) {
*/
private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request) {
HttpSession session = request.getSession(false);
Map<String, OAuth2AuthorizationRequest> authorizationRequests = session == null ? null :
(Map<String, OAuth2AuthorizationRequest>) session.getAttribute(this.sessionAttributeName);
if (authorizationRequests == null) {
Object sessionAttributeValue = (session != null) ? session.getAttribute(this.sessionAttributeName) : null;
if (sessionAttributeValue == null) {
return new HashMap<>();
}
return authorizationRequests;
else if (sessionAttributeValue instanceof OAuth2AuthorizationRequest) {
OAuth2AuthorizationRequest auth2AuthorizationRequest = (OAuth2AuthorizationRequest) sessionAttributeValue;
Map<String, OAuth2AuthorizationRequest> authorizationRequests = new HashMap<>(1);
authorizationRequests.put(auth2AuthorizationRequest.getState(), auth2AuthorizationRequest);
return authorizationRequests;
}
else if (sessionAttributeValue instanceof Map) {
@SuppressWarnings("unchecked")
Map<String, OAuth2AuthorizationRequest> authorizationRequests = (Map<String, OAuth2AuthorizationRequest>) sessionAttributeValue;
return authorizationRequests;
}
else {
throw new IllegalStateException(
"authorizationRequests is supposed to be a Map or OAuth2AuthorizationRequest but actually is a "
+ sessionAttributeValue.getClass());
}
}

/**
* Configure if multiple {@link OAuth2AuthorizationRequest}s should be stored per
* session. Default is false (not allow multiple {@link OAuth2AuthorizationRequest}
* per session).
* @param allowMultipleAuthorizationRequests true allows more than one
* {@link OAuth2AuthorizationRequest} to be stored per session.
* @since 5.5
*/
@Deprecated
public void setAllowMultipleAuthorizationRequests(boolean allowMultipleAuthorizationRequests) {
this.allowMultipleAuthorizationRequests = allowMultipleAuthorizationRequests;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client.web;

import org.junit.Before;
import org.junit.Test;

import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository} when
* {@link HttpSessionOAuth2AuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)}
* is enabled.
*
* @author Joe Grandja
* @author Craig Andrews
*/
public class HttpSessionOAuth2AuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests
extends HttpSessionOAuth2AuthorizationRequestRepositoryTests {

@Before
public void setup() {
this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
this.authorizationRequestRepository.setAllowMultipleAuthorizationRequests(true);
}

// gh-5110
@Test
public void loadAuthorizationRequestWhenMultipleSavedThenReturnMatchingAuthorizationRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
String state1 = "state-1122";
OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);
String state2 = "state-3344";
OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);
String state3 = "state-5566";
OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3).build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);
request.addParameter(OAuth2ParameterNames.STATE, state1);
OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository
.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest1).isEqualTo(authorizationRequest1);
request.removeParameter(OAuth2ParameterNames.STATE);
request.addParameter(OAuth2ParameterNames.STATE, state2);
OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository
.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2);
request.removeParameter(OAuth2ParameterNames.STATE);
request.addParameter(OAuth2ParameterNames.STATE, state3);
OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository
.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client.web;

import org.junit.Before;
import org.junit.Test;

import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository} when
* {@link HttpSessionOAuth2AuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)}
* is disabled.
*
* @author Joe Grandja
* @author Craig Andrews
*/
public class HttpSessionOAuth2AuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests
extends HttpSessionOAuth2AuthorizationRequestRepositoryTests {

@Before
public void setup() {
this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
this.authorizationRequestRepository.setAllowMultipleAuthorizationRequests(false);
}

// gh-5145
@Test
public void loadAuthorizationRequestWhenMultipleSavedThenReturnLastAuthorizationRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
String state1 = "state-1122";
OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);
String state2 = "state-3344";
OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);
String state3 = "state-5566";
OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3).build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);
request.addParameter(OAuth2ParameterNames.STATE, state1);
OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository
.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest1).isNull();
request.removeParameter(OAuth2ParameterNames.STATE);
request.addParameter(OAuth2ParameterNames.STATE, state2);
OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository
.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest2).isNull();
request.removeParameter(OAuth2ParameterNames.STATE);
request.addParameter(OAuth2ParameterNames.STATE, state3);
OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository
.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,11 +34,12 @@
* Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}.
*
* @author Joe Grandja
* @author Craig Andrews
*/
@RunWith(MockitoJUnitRunner.class)
public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
private HttpSessionOAuth2AuthorizationRequestRepository authorizationRequestRepository =
new HttpSessionOAuth2AuthorizationRequestRepository();
public abstract class HttpSessionOAuth2AuthorizationRequestRepositoryTests {

protected HttpSessionOAuth2AuthorizationRequestRepository authorizationRequestRepository;

@Test(expected = IllegalArgumentException.class)
public void loadAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() {
Expand Down Expand Up @@ -70,42 +71,6 @@ public void loadAuthorizationRequestWhenSavedThenReturnAuthorizationRequest() {
assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest);
}

// gh-5110
@Test
public void loadAuthorizationRequestWhenMultipleSavedThenReturnMatchingAuthorizationRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();

String state1 = "state-1122";
OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);

String state2 = "state-3344";
OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);

String state3 = "state-5566";
OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3).build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);

request.addParameter(OAuth2ParameterNames.STATE, state1);
OAuth2AuthorizationRequest loadedAuthorizationRequest1 =
this.authorizationRequestRepository.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest1).isEqualTo(authorizationRequest1);

request.removeParameter(OAuth2ParameterNames.STATE);
request.addParameter(OAuth2ParameterNames.STATE, state2);
OAuth2AuthorizationRequest loadedAuthorizationRequest2 =
this.authorizationRequestRepository.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2);

request.removeParameter(OAuth2ParameterNames.STATE);
request.addParameter(OAuth2ParameterNames.STATE, state3);
OAuth2AuthorizationRequest loadedAuthorizationRequest3 =
this.authorizationRequestRepository.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
}

@Test
public void loadAuthorizationRequestWhenSavedAndStateParameterNullThenReturnNull() {
MockHttpServletRequest request = new MockHttpServletRequest();
Expand Down Expand Up @@ -284,11 +249,9 @@ public void removeAuthorizationRequestWhenNotSavedThenNotRemoved() {
assertThat(removedAuthorizationRequest).isNull();
}

private OAuth2AuthorizationRequest.Builder createAuthorizationRequest() {
return OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id-1234")
.state("state-1234");
protected OAuth2AuthorizationRequest.Builder createAuthorizationRequest() {
return OAuth2AuthorizationRequest.authorizationCode().authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id-1234").state("state-1234");
}

static class MockDistributedHttpSession extends MockHttpSession {
Expand Down

0 comments on commit ecb4a57

Please sign in to comment.