Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support passing groups in OAuth access token claim #10262

Merged
merged 11 commits into from
Jan 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,15 @@

import javax.ws.rs.container.ContainerRequestContext;

import java.security.Principal;
import java.util.List;
import java.util.Optional;

import static com.google.common.net.HttpHeaders.AUTHORIZATION;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public abstract class AbstractBearerAuthenticator
implements Authenticator
{
private final UserMapping userMapping;

protected AbstractBearerAuthenticator(UserMapping userMapping)
{
this.userMapping = requireNonNull(userMapping, "userMapping is null");
}

@Override
public Identity authenticate(ContainerRequestContext request)
throws AuthenticationException
Expand All @@ -47,15 +38,7 @@ public Identity authenticate(ContainerRequestContext request, String token)
throws AuthenticationException
{
try {
Optional<Principal> principal = extractPrincipalFromToken(token);
if (principal.isEmpty()) {
throw needAuthentication(request, "Invalid credentials");
}

String authenticatedUser = userMapping.mapUser(principal.get().getName());
return Identity.forUser(authenticatedUser)
.withPrincipal(principal.get())
.build();
return createIdentity(token).orElseThrow(() -> needAuthentication(request, "Invalid credentials"));
}
catch (JwtException | UserMappingException e) {
throw needAuthentication(request, e.getMessage());
Expand Down Expand Up @@ -88,7 +71,8 @@ public String extractToken(ContainerRequestContext request)
return token;
}

protected abstract Optional<Principal> extractPrincipalFromToken(String token);
protected abstract Optional<Identity> createIdentity(String token)
throws UserMappingException;

protected abstract AuthenticationException needAuthentication(ContainerRequestContext request, String message);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
import io.jsonwebtoken.SigningKeyResolver;
import io.trino.server.security.AbstractBearerAuthenticator;
import io.trino.server.security.AuthenticationException;
import io.trino.server.security.UserMapping;
import io.trino.server.security.UserMappingException;
import io.trino.spi.security.BasicPrincipal;
import io.trino.spi.security.Identity;

import javax.inject.Inject;
import javax.ws.rs.container.ContainerRequestContext;

import java.security.Principal;
import java.util.Optional;

import static io.trino.server.security.UserMapping.createUserMapping;
Expand All @@ -34,11 +36,11 @@ public class JwtAuthenticator
{
private final JwtParser jwtParser;
private final String principalField;
private final UserMapping userMapping;

@Inject
public JwtAuthenticator(JwtAuthenticatorConfig config, SigningKeyResolver signingKeyResolver)
{
super(createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile()));
principalField = config.getPrincipalField();

JwtParserBuilder jwtParser = Jwts.parserBuilder()
Expand All @@ -51,15 +53,22 @@ public JwtAuthenticator(JwtAuthenticatorConfig config, SigningKeyResolver signin
jwtParser.requireAudience(config.getRequiredAudience());
}
this.jwtParser = jwtParser.build();
userMapping = createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile());
}

@Override
protected Optional<Principal> extractPrincipalFromToken(String token)
protected Optional<Identity> createIdentity(String token)
throws UserMappingException
{
return Optional.ofNullable(jwtParser.parseClaimsJws(token)
Optional<String> principal = Optional.ofNullable(jwtParser.parseClaimsJws(token)
.getBody()
.get(principalField, String.class))
.map(BasicPrincipal::new);
.get(principalField, String.class));
if (principal.isEmpty()) {
return Optional.empty();
}
return Optional.of(Identity.forUser(userMapping.mapUser(principal.get()))
.withPrincipal(new BasicPrincipal(principal.get()))
.build());
Comment on lines +66 to +71
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can continue the chain and avoid the if statement here by using map.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm afraid it would not work. User mapping throws UserMappingException which is checked exception.

}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,20 @@
*/
package io.trino.server.security.oauth2;

import com.google.common.collect.ImmutableSet;
import io.trino.server.security.AbstractBearerAuthenticator;
import io.trino.server.security.AuthenticationException;
import io.trino.server.security.UserMapping;
import io.trino.server.security.UserMappingException;
import io.trino.spi.security.BasicPrincipal;
import io.trino.spi.security.Identity;

import javax.inject.Inject;
import javax.ws.rs.container.ContainerRequestContext;

import java.net.URI;
import java.security.Principal;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;

Expand All @@ -36,23 +41,33 @@ public class OAuth2Authenticator
{
private final OAuth2Service service;
private final String principalField;
private final Optional<String> groupsField;
private final UserMapping userMapping;

@Inject
public OAuth2Authenticator(OAuth2Service service, OAuth2Config config)
{
super(createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile()));
this.service = requireNonNull(service, "service is null");
this.principalField = config.getPrincipalField();
groupsField = requireNonNull(config.getGroupsField(), "groupsField is null");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can skip this rnn here and in other places as groups field is an empty Optional by default in the config.

userMapping = createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile());
}

@Override
protected Optional<Principal> extractPrincipalFromToken(String token)
protected Optional<Identity> createIdentity(String token)
throws UserMappingException
{
try {
return service.convertTokenToClaims(token)
.map(claims -> claims.get(principalField))
.map(String.class::cast)
.map(BasicPrincipal::new);
Optional<Map<String, Object>> claims = service.convertTokenToClaims(token);
if (claims.isEmpty()) {
return Optional.empty();
}
String principal = (String) claims.get().get(principalField);
Identity.Builder builder = Identity.forUser(userMapping.mapUser(principal));
builder.withPrincipal(new BasicPrincipal(principal));
groupsField.flatMap(field -> Optional.ofNullable((List<String>) claims.get().get(field)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for flatMap as map handles null correctly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional.map(o -> Optional.of(o)) returns Optional<Optional<>>
whereas
Optional.flatMap(o -> Optional.of(o)) returns Optional<>

.ifPresent(groups -> builder.withGroups(ImmutableSet.copyOf(groups)));
return Optional.of(builder.build());
}
catch (ChallengeFailedException e) {
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public class OAuth2Config
private String clientSecret;
private Set<String> scopes = ImmutableSet.of(OPENID_SCOPE);
private String principalField = "sub";
private Optional<String> groupsField = Optional.empty();
private List<String> additionalAudiences = Collections.emptyList();
private Duration challengeTimeout = new Duration(15, TimeUnit.MINUTES);
private Optional<String> userMappingPattern = Optional.empty();
Expand Down Expand Up @@ -222,6 +223,19 @@ public OAuth2Config setPrincipalField(String principalField)
return this;
}

public Optional<String> getGroupsField()
{
return groupsField;
}

@Config("http-server.authentication.oauth2.groups-field")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add changes to docs too

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's already done:

   * - ``http-server.authentication.oauth2.groups-field``
     - The field of the access token used for Trino groups. The corresponding claim value must be an array.

Would like to add something?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. Thanks.

@ConfigDescription("Groups field in the claim")
public OAuth2Config setGroupsField(String groupsField)
{
this.groupsField = Optional.ofNullable(groupsField);
return this;
}

@MinDuration("1ms")
@NotNull
public Duration getChallengeTimeout()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.core.Response;

import java.util.List;
import java.util.Map;
import java.util.Optional;

Expand All @@ -50,6 +51,7 @@ public class OAuth2WebUiAuthenticationFilter
private final String principalField;
private final OAuth2Service service;
private final UserMapping userMapping;
private final Optional<String> groupsField;

@Inject
public OAuth2WebUiAuthenticationFilter(OAuth2Service service, OAuth2Config oauth2Config)
Expand All @@ -58,6 +60,7 @@ public OAuth2WebUiAuthenticationFilter(OAuth2Service service, OAuth2Config oauth
requireNonNull(oauth2Config, "oauth2Config is null");
this.userMapping = UserMapping.createUserMapping(oauth2Config.getUserMappingPattern(), oauth2Config.getUserMappingFile());
this.principalField = oauth2Config.getPrincipalField();
groupsField = requireNonNull(oauth2Config.getGroupsField(), "groupsField is null");
}

@Override
Expand Down Expand Up @@ -101,9 +104,11 @@ public void filter(ContainerRequestContext request)
return;
}
String principalName = (String) principal;
setAuthenticatedIdentity(request, Identity.forUser(userMapping.mapUser(principalName))
.withPrincipal(new BasicPrincipal(principalName))
.build());
Identity.Builder builder = Identity.forUser(userMapping.mapUser(principalName));
builder.withPrincipal(new BasicPrincipal(principalName));
groupsField.flatMap(field -> Optional.ofNullable((List<String>) claims.get().get(field)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again map here is cleaner and clearer.

Copy link
Member Author

@lukasz-walkiewicz lukasz-walkiewicz Dec 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

userMapping.mapUser(principalName) throws checked UserMappingException

.ifPresent(groups -> builder.withGroups(ImmutableSet.copyOf(groups)));
setAuthenticatedIdentity(request, builder.build());
}
catch (UserMappingException e) {
sendErrorMessage(request, UNAUTHORIZED, firstNonNull(e.getMessage(), "Unauthorized"));
Expand Down
Loading