From 3d64dfb460fc4a78d7bcf798551190b085a81b00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Walkiewicz?= Date: Tue, 7 Dec 2021 09:39:58 +0100 Subject: [PATCH] Extract identity building to bearer authenticators --- .../security/AbstractBearerAuthenticator.java | 22 +++--------------- .../server/security/jwt/JwtAuthenticator.java | 21 ++++++++++++----- .../security/oauth2/OAuth2Authenticator.java | 23 +++++++++++++------ 3 files changed, 34 insertions(+), 32 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/server/security/AbstractBearerAuthenticator.java b/core/trino-main/src/main/java/io/trino/server/security/AbstractBearerAuthenticator.java index dfcc30519450..611099332534 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/AbstractBearerAuthenticator.java +++ b/core/trino-main/src/main/java/io/trino/server/security/AbstractBearerAuthenticator.java @@ -18,24 +18,15 @@ import javax.ws.rs.container.ContainerRequestContext; -import java.security.Principal; import java.util.List; import java.util.Optional; import static com.google.common.net.HttpHeaders.AUTHORIZATION; import static java.lang.String.format; -import static java.util.Objects.requireNonNull; public abstract class AbstractBearerAuthenticator implements Authenticator { - private final UserMapping userMapping; - - protected AbstractBearerAuthenticator(UserMapping userMapping) - { - this.userMapping = requireNonNull(userMapping, "userMapping is null"); - } - @Override public Identity authenticate(ContainerRequestContext request) throws AuthenticationException @@ -47,15 +38,7 @@ public Identity authenticate(ContainerRequestContext request, String token) throws AuthenticationException { try { - Optional principal = extractPrincipalFromToken(token); - if (principal.isEmpty()) { - throw needAuthentication(request, "Invalid credentials"); - } - - String authenticatedUser = userMapping.mapUser(principal.get().getName()); - return Identity.forUser(authenticatedUser) - .withPrincipal(principal.get()) - .build(); + return createIdentity(token).orElseThrow(() -> needAuthentication(request, "Invalid credentials")); } catch (JwtException | UserMappingException e) { throw needAuthentication(request, e.getMessage()); @@ -88,7 +71,8 @@ public String extractToken(ContainerRequestContext request) return token; } - protected abstract Optional extractPrincipalFromToken(String token); + protected abstract Optional createIdentity(String token) + throws UserMappingException; protected abstract AuthenticationException needAuthentication(ContainerRequestContext request, String message); } diff --git a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java index 1fb38a9b348b..5989470f0384 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java +++ b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java @@ -19,12 +19,14 @@ import io.jsonwebtoken.SigningKeyResolver; import io.trino.server.security.AbstractBearerAuthenticator; import io.trino.server.security.AuthenticationException; +import io.trino.server.security.UserMapping; +import io.trino.server.security.UserMappingException; import io.trino.spi.security.BasicPrincipal; +import io.trino.spi.security.Identity; import javax.inject.Inject; import javax.ws.rs.container.ContainerRequestContext; -import java.security.Principal; import java.util.Optional; import static io.trino.server.security.UserMapping.createUserMapping; @@ -34,11 +36,11 @@ public class JwtAuthenticator { private final JwtParser jwtParser; private final String principalField; + private final UserMapping userMapping; @Inject public JwtAuthenticator(JwtAuthenticatorConfig config, SigningKeyResolver signingKeyResolver) { - super(createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile())); principalField = config.getPrincipalField(); JwtParserBuilder jwtParser = Jwts.parserBuilder() @@ -51,15 +53,22 @@ public JwtAuthenticator(JwtAuthenticatorConfig config, SigningKeyResolver signin jwtParser.requireAudience(config.getRequiredAudience()); } this.jwtParser = jwtParser.build(); + userMapping = createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile()); } @Override - protected Optional extractPrincipalFromToken(String token) + protected Optional createIdentity(String token) + throws UserMappingException { - return Optional.ofNullable(jwtParser.parseClaimsJws(token) + Optional principal = Optional.ofNullable(jwtParser.parseClaimsJws(token) .getBody() - .get(principalField, String.class)) - .map(BasicPrincipal::new); + .get(principalField, String.class)); + if (principal.isEmpty()) { + return Optional.empty(); + } + return Optional.of(Identity.forUser(userMapping.mapUser(principal.get())) + .withPrincipal(new BasicPrincipal(principal.get())) + .build()); } @Override diff --git a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2Authenticator.java b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2Authenticator.java index 0482b9393099..2d076912b729 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2Authenticator.java +++ b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2Authenticator.java @@ -15,13 +15,16 @@ import io.trino.server.security.AbstractBearerAuthenticator; import io.trino.server.security.AuthenticationException; +import io.trino.server.security.UserMapping; +import io.trino.server.security.UserMappingException; import io.trino.spi.security.BasicPrincipal; +import io.trino.spi.security.Identity; import javax.inject.Inject; import javax.ws.rs.container.ContainerRequestContext; import java.net.URI; -import java.security.Principal; +import java.util.Map; import java.util.Optional; import java.util.UUID; @@ -36,23 +39,29 @@ public class OAuth2Authenticator { private final OAuth2Service service; private final String principalField; + private final UserMapping userMapping; @Inject public OAuth2Authenticator(OAuth2Service service, OAuth2Config config) { - super(createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile())); this.service = requireNonNull(service, "service is null"); this.principalField = config.getPrincipalField(); + userMapping = createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile()); } @Override - protected Optional extractPrincipalFromToken(String token) + protected Optional createIdentity(String token) + throws UserMappingException { try { - return service.convertTokenToClaims(token) - .map(claims -> claims.get(principalField)) - .map(String.class::cast) - .map(BasicPrincipal::new); + Optional> claims = service.convertTokenToClaims(token); + if (claims.isEmpty()) { + return Optional.empty(); + } + String principal = (String) claims.get().get(principalField); + return Optional.of(Identity.forUser(userMapping.mapUser(principal)) + .withPrincipal(new BasicPrincipal(principal)) + .build()); } catch (ChallengeFailedException e) { return Optional.empty();