Skip to content

Commit

Permalink
Expose user name attribute name in OAuth2UserAuthority
Browse files Browse the repository at this point in the history
  • Loading branch information
filiphr authored and sjohnr committed Jun 3, 2024
1 parent b41ec0a commit 99aee99
Show file tree
Hide file tree
Showing 13 changed files with 92 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,18 @@ static boolean shouldRetrieveUserInfo(OidcUserRequest userRequest) {

static OidcUser getUser(OidcUserRequest userRequest, OidcUserInfo userInfo) {
Set<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo));
ClientRegistration.ProviderDetails providerDetails = userRequest.getClientRegistration().getProviderDetails();
String userNameAttributeName = providerDetails.getUserInfoEndpoint().getUserNameAttributeName();
if (StringUtils.hasLength(userNameAttributeName)) {
authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo, userNameAttributeName));
}
else {
authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo));
}
OAuth2AccessToken token = userRequest.getAccessToken();
for (String scope : token.getScopes()) {
authorities.add(new SimpleGrantedAuthority("SCOPE_" + scope));
}
ClientRegistration.ProviderDetails providerDetails = userRequest.getClientRegistration().getProviderDetails();
String userNameAttributeName = providerDetails.getUserInfoEndpoint().getUserNameAttributeName();
if (StringUtils.hasText(userNameAttributeName)) {
return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo, userNameAttributeName);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2Authentic
ResponseEntity<Map<String, Object>> response = getResponse(userRequest, request);
OAuth2AccessToken token = userRequest.getAccessToken();
Map<String, Object> attributes = this.attributesConverter.convert(userRequest).convert(response.getBody());
Collection<GrantedAuthority> authorities = getAuthorities(token, attributes);
Collection<GrantedAuthority> authorities = getAuthorities(token, attributes, userNameAttributeName);
return new DefaultOAuth2User(authorities, attributes, userNameAttributeName);
}

Expand Down Expand Up @@ -187,9 +187,10 @@ private String getUserNameAttributeName(OAuth2UserRequest userRequest) {
return userNameAttributeName;
}

private Collection<GrantedAuthority> getAuthorities(OAuth2AccessToken token, Map<String, Object> attributes) {
private Collection<GrantedAuthority> getAuthorities(OAuth2AccessToken token, Map<String, Object> attributes,
String userNameAttributeName) {
Collection<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OAuth2UserAuthority(attributes));
authorities.add(new OAuth2UserAuthority(attributes, userNameAttributeName));
for (String authority : token.getScopes()) {
authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public Mono<OAuth2User> loadUser(OAuth2UserRequest userRequest) throws OAuth2Aut
.bodyToMono(DefaultReactiveOAuth2UserService.STRING_OBJECT_MAP)
.mapNotNull((attributes) -> this.attributesConverter.convert(userRequest).convert(attributes));
return userAttributes.map((attrs) -> {
GrantedAuthority authority = new OAuth2UserAuthority(attrs);
GrantedAuthority authority = new OAuth2UserAuthority(attrs, userNameAttributeName);
Set<GrantedAuthority> authorities = new HashSet<>();
authorities.add(authority);
OAuth2AccessToken token = userRequest.getAccessToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ private static String asJson(OAuth2UserAuthority oauth2UserAuthority) {
return "{\n" +
" \"@class\": \"org.springframework.security.oauth2.core.user.OAuth2UserAuthority\",\n" +
" \"authority\": \"" + oauth2UserAuthority.getAuthority() + "\",\n" +
" \"userNameAttributeName\": \"username\",\n" +
" \"attributes\": {\n" +
" \"@class\": \"java.util.Collections$UnmodifiableMap\",\n" +
" \"username\": \"user\"\n" +
Expand All @@ -260,6 +261,7 @@ private static String asJson(OidcUserAuthority oidcUserAuthority) {
return "{\n" +
" \"@class\": \"org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority\",\n" +
" \"authority\": \"" + oidcUserAuthority.getAuthority() + "\",\n" +
" \"userNameAttributeName\": \"" + oidcUserAuthority.getUserNameAttributeName() + "\",\n" +
" \"idToken\": " + asJson(oidcUserAuthority.getIdToken()) + ",\n" +
" \"userInfo\": " + asJson(oidcUserAuthority.getUserInfo()) + "\n" +
" }";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() {
OAuth2UserAuthority userAuthority = (OAuth2UserAuthority) user.getAuthorities().iterator().next();
assertThat(userAuthority.getAuthority()).isEqualTo("OIDC_USER");
assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes());
assertThat(userAuthority.getUserNameAttributeName()).isEqualTo("id");
}

@Test
Expand Down Expand Up @@ -361,6 +362,7 @@ public void loadUserWhenNestedUserInfoSuccessThenReturnUser() throws IOException
OAuth2UserAuthority userAuthority = (OAuth2UserAuthority) user.getAuthorities().iterator().next();
assertThat(userAuthority.getAuthority()).isEqualTo("OIDC_USER");
assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes());
assertThat(userAuthority.getUserNameAttributeName()).isEqualTo("user-name");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,7 @@ public void loadUserWhenNestedUserInfoSuccessThenReturnUser() {
OAuth2UserAuthority userAuthority = (OAuth2UserAuthority) user.getAuthorities().iterator().next();
assertThat(userAuthority.getAuthority()).isEqualTo("OIDC_USER");
assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes());
assertThat(userAuthority.getUserNameAttributeName()).isEqualTo("user-name");
}

private MockResponse jsonResponse(String json) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
OAuth2UserAuthority userAuthority = (OAuth2UserAuthority) user.getAuthorities().iterator().next();
assertThat(userAuthority.getAuthority()).isEqualTo("OAUTH2_USER");
assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes());
assertThat(userAuthority.getUserNameAttributeName()).isEqualTo("user-name");
}

@Test
Expand Down Expand Up @@ -196,6 +197,7 @@ public void loadUserWhenNestedUserInfoSuccessThenReturnUser() {
OAuth2UserAuthority userAuthority = (OAuth2UserAuthority) user.getAuthorities().iterator().next();
assertThat(userAuthority.getAuthority()).isEqualTo("OAUTH2_USER");
assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes());
assertThat(userAuthority.getUserNameAttributeName()).isEqualTo("user-name");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
OAuth2UserAuthority userAuthority = (OAuth2UserAuthority) user.getAuthorities().iterator().next();
assertThat(userAuthority.getAuthority()).isEqualTo("OAUTH2_USER");
assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes());
assertThat(userAuthority.getUserNameAttributeName()).isEqualTo("id");
}

// gh-9336
Expand Down Expand Up @@ -203,6 +204,7 @@ public void loadUserWhenNestedUserInfoSuccessThenReturnUser() {
OAuth2UserAuthority userAuthority = (OAuth2UserAuthority) user.getAuthorities().iterator().next();
assertThat(userAuthority.getAuthority()).isEqualTo("OAUTH2_USER");
assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes());
assertThat(userAuthority.getUserNameAttributeName()).isEqualTo("user-name");
}

// gh-5500
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Map;

import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
Expand Down Expand Up @@ -57,6 +58,19 @@ public OidcUserAuthority(OidcIdToken idToken, OidcUserInfo userInfo) {
this("OIDC_USER", idToken, userInfo);
}

/**
* Constructs a {@code OidcUserAuthority} using the provided parameters and defaults
* {@link #getAuthority()} to {@code OIDC_USER}.
* @param idToken the {@link OidcIdToken ID Token} containing claims about the user
* @param userInfo the {@link OidcUserInfo UserInfo} containing claims about the user,
* may be {@code null}
* @param userNameAttributeName the attribute name used to access the user's name from
* the attributes
*/
public OidcUserAuthority(OidcIdToken idToken, OidcUserInfo userInfo, String userNameAttributeName) {
this("OIDC_USER", idToken, userInfo, userNameAttributeName);
}

/**
* Constructs a {@code OidcUserAuthority} using the provided parameters.
* @param authority the authority granted to the user
Expand All @@ -65,7 +79,21 @@ public OidcUserAuthority(OidcIdToken idToken, OidcUserInfo userInfo) {
* may be {@code null}
*/
public OidcUserAuthority(String authority, OidcIdToken idToken, OidcUserInfo userInfo) {
super(authority, collectClaims(idToken, userInfo));
this(authority, idToken, userInfo, IdTokenClaimNames.SUB);
}

/**
* Constructs a {@code OidcUserAuthority} using the provided parameters.
* @param authority the authority granted to the user
* @param idToken the {@link OidcIdToken ID Token} containing claims about the user
* @param userInfo the {@link OidcUserInfo UserInfo} containing claims about the user,
* may be {@code null}
* @param userNameAttributeName the attribute name used to access the user's name from
* the attributes
*/
public OidcUserAuthority(String authority, OidcIdToken idToken, OidcUserInfo userInfo,
String userNameAttributeName) {
super(authority, collectClaims(idToken, userInfo), userNameAttributeName);
this.idToken = idToken;
this.userInfo = userInfo;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Map;
import java.util.Objects;

import org.springframework.lang.Nullable;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.SpringSecurityCoreVersion;
import org.springframework.util.Assert;
Expand All @@ -41,6 +42,8 @@ public class OAuth2UserAuthority implements GrantedAuthority {

private final Map<String, Object> attributes;

private final String userNameAttributeName;

/**
* Constructs a {@code OAuth2UserAuthority} using the provided parameters and defaults
* {@link #getAuthority()} to {@code OAUTH2_USER}.
Expand All @@ -50,16 +53,39 @@ public OAuth2UserAuthority(Map<String, Object> attributes) {
this("OAUTH2_USER", attributes);
}

/**
* Constructs a {@code OAuth2UserAuthority} using the provided parameters and defaults
* {@link #getAuthority()} to {@code OAUTH2_USER}.
* @param attributes the attributes about the user
* @param userNameAttributeName the attribute name used to access the user's name from
* the attributes
*/
public OAuth2UserAuthority(Map<String, Object> attributes, @Nullable String userNameAttributeName) {
this("OAUTH2_USER", attributes, userNameAttributeName);
}

/**
* Constructs a {@code OAuth2UserAuthority} using the provided parameters.
* @param authority the authority granted to the user
* @param attributes the attributes about the user
*/
public OAuth2UserAuthority(String authority, Map<String, Object> attributes) {
this(authority, attributes, null);
}

/**
* Constructs a {@code OAuth2UserAuthority} using the provided parameters.
* @param authority the authority granted to the user
* @param attributes the attributes about the user
* @param userNameAttributeName the attribute name used to access the user's name from
* the attributes
*/
public OAuth2UserAuthority(String authority, Map<String, Object> attributes, String userNameAttributeName) {
Assert.hasText(authority, "authority cannot be empty");
Assert.notEmpty(attributes, "attributes cannot be empty");
this.authority = authority;
this.attributes = Collections.unmodifiableMap(new LinkedHashMap<>(attributes));
this.userNameAttributeName = userNameAttributeName;
}

@Override
Expand All @@ -75,6 +101,15 @@ public Map<String, Object> getAttributes() {
return this.attributes;
}

/**
* Returns the attribute name used to access the user's name from the attributes.
* @return the attribute name used to access the user's name from the attributes
*/
@Nullable
public String getUserNameAttributeName() {
return this.userNameAttributeName;
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@ public static DefaultOAuth2User create() {
String nameAttributeKey = "username";
Map<String, Object> attributes = new HashMap<>();
attributes.put(nameAttributeKey, "user");
Collection<GrantedAuthority> authorities = authorities(attributes);
Collection<GrantedAuthority> authorities = authorities(attributes, nameAttributeKey);
return new DefaultOAuth2User(authorities, attributes, nameAttributeKey);
}

private static Collection<GrantedAuthority> authorities(Map<String, Object> attributes) {
return new LinkedHashSet<>(Arrays.asList(new OAuth2UserAuthority(attributes),
private static Collection<GrantedAuthority> authorities(Map<String, Object> attributes,
String userNameAttributeName) {
return new LinkedHashSet<>(Arrays.asList(new OAuth2UserAuthority(attributes, userNameAttributeName),
new SimpleGrantedAuthority("SCOPE_read"), new SimpleGrantedAuthority("SCOPE_write")));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ private ClientRegistration.Builder clientRegistrationBuilder() {

private Collection<GrantedAuthority> defaultAuthorities() {
Set<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OAuth2UserAuthority(this.attributes.get()));
authorities.add(new OAuth2UserAuthority(this.attributes.get(), this.nameAttributeKey));
for (String authority : this.accessToken.getScopes()) {
authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1376,7 +1376,7 @@ private ClientRegistration.Builder clientRegistrationBuilder() {

private Collection<GrantedAuthority> defaultAuthorities() {
Set<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OAuth2UserAuthority(this.attributes.get()));
authorities.add(new OAuth2UserAuthority(this.attributes.get(), this.nameAttributeKey));
for (String authority : this.accessToken.getScopes()) {
authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
}
Expand Down

0 comments on commit 99aee99

Please sign in to comment.