Skip to content

Commit

Permalink
HttpSessionOAuth2AuthorizationRequestRepository handle multiple OAuth…
Browse files Browse the repository at this point in the history
…2AuthorizationRequest per session

Fixes spring-projectsgh-5110
  • Loading branch information
jgrandja committed Mar 19, 2018
1 parent 1851aaa commit ee57e71
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,11 +16,14 @@
package org.springframework.security.oauth2.client.web;

import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.util.HashMap;
import java.util.Map;

/**
* An implementation of an {@link AuthorizationRequestRepository} that stores
Expand All @@ -39,9 +42,10 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
@Override
public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
Assert.notNull(request, "request cannot be null");
HttpSession session = request.getSession(false);
if (session != null) {
return (OAuth2AuthorizationRequest) session.getAttribute(this.sessionAttributeName);
Assert.hasText(request.getParameter(OAuth2ParameterNames.STATE), "state parameter cannot be empty");
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
if (authorizationRequests != null) {
return authorizationRequests.get(request.getParameter(OAuth2ParameterNames.STATE));
}
return null;
}
Expand All @@ -55,16 +59,36 @@ public void saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationReq
this.removeAuthorizationRequest(request);
return;
}
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequest);
Assert.hasText(authorizationRequest.getState(), "authorizationRequest.state cannot be empty");
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request, true);
authorizationRequests.put(authorizationRequest.getState(), authorizationRequest);
}

@Override
public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest request) {
Assert.notNull(request, "request cannot be null");
OAuth2AuthorizationRequest authorizationRequest = this.loadAuthorizationRequest(request);
if (authorizationRequest != null) {
request.getSession().removeAttribute(this.sessionAttributeName);
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
authorizationRequests.remove(authorizationRequest.getState());
}
return authorizationRequest;
}

private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request) {
return this.getAuthorizationRequests(request, false);
}

private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request, boolean createSession) {
Map<String, OAuth2AuthorizationRequest> authorizationRequests = null;
HttpSession session = request.getSession(createSession);
if (session != null) {
authorizationRequests = (Map<String, OAuth2AuthorizationRequest>) session.getAttribute(this.sessionAttributeName);
if (authorizationRequests == null) {
authorizationRequests = new HashMap<>();
session.setAttribute(this.sessionAttributeName, authorizationRequests);
}
}
return authorizationRequests;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
* Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}.
Expand All @@ -44,8 +46,10 @@ public void loadAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegal

@Test
public void loadAuthorizationRequestWhenNotSavedThenReturnNull() {
MockHttpServletRequest request = new MockHttpServletRequest();
request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
OAuth2AuthorizationRequest authorizationRequest =
this.authorizationRequestRepository.loadAuthorizationRequest(new MockHttpServletRequest());
this.authorizationRequestRepository.loadAuthorizationRequest(request);

assertThat(authorizationRequest).isNull();
}
Expand All @@ -54,15 +58,69 @@ public void loadAuthorizationRequestWhenNotSavedThenReturnNull() {
public void loadAuthorizationRequestWhenSavedThenReturnAuthorizationRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();

OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
when(authorizationRequest.getState()).thenReturn("state-1234");

this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
OAuth2AuthorizationRequest loadedAuthorizationRequest =
this.authorizationRequestRepository.loadAuthorizationRequest(request);

assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest);
}

// gh-5110
@Test
public void loadAuthorizationRequestWhenMultipleSavedThenReturnMatchingAuthorizationRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();

String state1 = "state-1122";
OAuth2AuthorizationRequest authorizationRequest1 = mock(OAuth2AuthorizationRequest.class);
when(authorizationRequest1.getState()).thenReturn(state1);
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);

String state2 = "state-3344";
OAuth2AuthorizationRequest authorizationRequest2 = mock(OAuth2AuthorizationRequest.class);
when(authorizationRequest2.getState()).thenReturn(state2);
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);

String state3 = "state-5566";
OAuth2AuthorizationRequest authorizationRequest3 = mock(OAuth2AuthorizationRequest.class);
when(authorizationRequest3.getState()).thenReturn(state3);
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);

request.addParameter(OAuth2ParameterNames.STATE, state1);
OAuth2AuthorizationRequest loadedAuthorizationRequest1 =
this.authorizationRequestRepository.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest1).isEqualTo(authorizationRequest1);

request.removeParameter(OAuth2ParameterNames.STATE);
request.addParameter(OAuth2ParameterNames.STATE, state2);
OAuth2AuthorizationRequest loadedAuthorizationRequest2 =
this.authorizationRequestRepository.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2);

request.removeParameter(OAuth2ParameterNames.STATE);
request.addParameter(OAuth2ParameterNames.STATE, state3);
OAuth2AuthorizationRequest loadedAuthorizationRequest3 =
this.authorizationRequestRepository.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
}

@Test(expected = IllegalArgumentException.class)
public void loadAuthorizationRequestWhenSavedAndStateParameterNullThenThrowIllegalArgumentException() {
MockHttpServletRequest request = new MockHttpServletRequest();

OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
when(authorizationRequest.getState()).thenReturn("state-1234");
this.authorizationRequestRepository.saveAuthorizationRequest(
authorizationRequest, request, new MockHttpServletResponse());

this.authorizationRequestRepository.loadAuthorizationRequest(request);
}

@Test(expected = IllegalArgumentException.class)
public void saveAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() {
this.authorizationRequestRepository.saveAuthorizationRequest(
Expand All @@ -75,13 +133,22 @@ public void saveAuthorizationRequestWhenHttpServletResponseIsNullThenThrowIllega
mock(OAuth2AuthorizationRequest.class), new MockHttpServletRequest(), null);
}

@Test(expected = IllegalArgumentException.class)
public void saveAuthorizationRequestWhenStateNullThenThrowIllegalArgumentException() {
this.authorizationRequestRepository.saveAuthorizationRequest(
mock(OAuth2AuthorizationRequest.class), new MockHttpServletRequest(), new MockHttpServletResponse());
}

@Test
public void saveAuthorizationRequestWhenNotNullThenSaved() {
MockHttpServletRequest request = new MockHttpServletRequest();
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);

OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
when(authorizationRequest.getState()).thenReturn("state-1234");
this.authorizationRequestRepository.saveAuthorizationRequest(
authorizationRequest, request, new MockHttpServletResponse());

request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
OAuth2AuthorizationRequest loadedAuthorizationRequest =
this.authorizationRequestRepository.loadAuthorizationRequest(request);

Expand All @@ -92,12 +159,17 @@ public void saveAuthorizationRequestWhenNotNullThenSaved() {
public void saveAuthorizationRequestWhenNullThenRemoved() {
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();

OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
when(authorizationRequest.getState()).thenReturn("state-1234");

this.authorizationRequestRepository.saveAuthorizationRequest( // Save
authorizationRequest, request, response);

request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
this.authorizationRequestRepository.saveAuthorizationRequest( // Null value removes
null, request, response);

OAuth2AuthorizationRequest loadedAuthorizationRequest =
this.authorizationRequestRepository.loadAuthorizationRequest(request);

Expand All @@ -113,10 +185,14 @@ public void removeAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIlleg
public void removeAuthorizationRequestWhenSavedThenRemoved() {
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();

OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
when(authorizationRequest.getState()).thenReturn("state-1234");

this.authorizationRequestRepository.saveAuthorizationRequest(
authorizationRequest, request, response);

request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
OAuth2AuthorizationRequest removedAuthorizationRequest =
this.authorizationRequestRepository.removeAuthorizationRequest(request);
OAuth2AuthorizationRequest loadedAuthorizationRequest =
Expand All @@ -129,6 +205,7 @@ public void removeAuthorizationRequestWhenSavedThenRemoved() {
@Test
public void removeAuthorizationRequestWhenNotSavedThenNotRemoved() {
MockHttpServletRequest request = new MockHttpServletRequest();
request.addParameter(OAuth2ParameterNames.STATE, "state-1234");

OAuth2AuthorizationRequest removedAuthorizationRequest =
this.authorizationRequestRepository.removeAuthorizationRequest(request);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
Expand All @@ -26,8 +27,6 @@
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;

import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
Expand Down Expand Up @@ -153,7 +152,7 @@ public void doFilterWhenAuthorizationRequestAuthorizationCodeGrantThenRedirectFo
}

@Test
public void doFilterWhenAuthorizationRequestAuthorizationCodeGrantThenAuthorizationRequestSavedInSession() throws Exception {
public void doFilterWhenAuthorizationRequestAuthorizationCodeGrantThenAuthorizationRequestSaved() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
"/" + this.registration2.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
Expand All @@ -162,31 +161,14 @@ public void doFilterWhenAuthorizationRequestAuthorizationCodeGrantThenAuthorizat
FilterChain filterChain = mock(FilterChain.class);

AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
new HttpSessionOAuth2AuthorizationRequestRepository();
mock(AuthorizationRequestRepository.class);
this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);

this.filter.doFilter(request, response, filterChain);

verifyZeroInteractions(filterChain);

OAuth2AuthorizationRequest authorizationRequest = authorizationRequestRepository.loadAuthorizationRequest(request);

assertThat(authorizationRequest).isNotNull();
assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo(
this.registration2.getProviderDetails().getAuthorizationUri());
assertThat(authorizationRequest.getGrantType()).isEqualTo(
this.registration2.getAuthorizationGrantType());
assertThat(authorizationRequest.getResponseType()).isEqualTo(
OAuth2AuthorizationResponseType.CODE);
assertThat(authorizationRequest.getClientId()).isEqualTo(
this.registration2.getClientId());
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
"http://localhost/login/oauth2/code/registration-2");
assertThat(authorizationRequest.getScopes()).isEqualTo(
this.registration2.getScopes());
assertThat(authorizationRequest.getState()).isNotNull();
assertThat(authorizationRequest.getAdditionalParameters()
.get(OAuth2ParameterNames.REGISTRATION_ID)).isEqualTo(this.registration2.getRegistrationId());
verify(authorizationRequestRepository).saveAuthorizationRequest(
any(OAuth2AuthorizationRequest.class), any(HttpServletRequest.class), any(HttpServletResponse.class));
}

@Test
Expand All @@ -206,7 +188,7 @@ public void doFilterWhenAuthorizationRequestImplicitGrantThenRedirectForAuthoriz
}

@Test
public void doFilterWhenAuthorizationRequestImplicitGrantThenAuthorizationRequestNotSavedInSession() throws Exception {
public void doFilterWhenAuthorizationRequestImplicitGrantThenAuthorizationRequestNotSaved() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
"/" + this.registration3.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
Expand All @@ -215,16 +197,14 @@ public void doFilterWhenAuthorizationRequestImplicitGrantThenAuthorizationReques
FilterChain filterChain = mock(FilterChain.class);

AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
new HttpSessionOAuth2AuthorizationRequestRepository();
mock(AuthorizationRequestRepository.class);
this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);

this.filter.doFilter(request, response, filterChain);

verifyZeroInteractions(filterChain);

OAuth2AuthorizationRequest authorizationRequest = authorizationRequestRepository.loadAuthorizationRequest(request);

assertThat(authorizationRequest).isNull();
verify(authorizationRequestRepository, times(0)).saveAuthorizationRequest(
any(OAuth2AuthorizationRequest.class), any(HttpServletRequest.class), any(HttpServletResponse.class));
}

@Test
Expand Down Expand Up @@ -255,14 +235,19 @@ public void doFilterWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriE
FilterChain filterChain = mock(FilterChain.class);

AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
new HttpSessionOAuth2AuthorizationRequestRepository();
mock(AuthorizationRequestRepository.class);
this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);

this.filter.doFilter(request, response, filterChain);

ArgumentCaptor<OAuth2AuthorizationRequest> authorizationRequestArgCaptor =
ArgumentCaptor.forClass(OAuth2AuthorizationRequest.class);

verifyZeroInteractions(filterChain);
verify(authorizationRequestRepository).saveAuthorizationRequest(
authorizationRequestArgCaptor.capture(), any(HttpServletRequest.class), any(HttpServletResponse.class));

OAuth2AuthorizationRequest authorizationRequest = authorizationRequestRepository.loadAuthorizationRequest(request);
OAuth2AuthorizationRequest authorizationRequest = authorizationRequestArgCaptor.getValue();

assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(
this.registration2.getRedirectUriTemplate());
Expand Down
Loading

0 comments on commit ee57e71

Please sign in to comment.