diff --git a/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/OidcRequestContextProperties.java b/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/OidcRequestContextProperties.java index cdf75431cd09d..d7a1f620a48af 100644 --- a/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/OidcRequestContextProperties.java +++ b/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/OidcRequestContextProperties.java @@ -4,13 +4,25 @@ public class OidcRequestContextProperties { + public static String TOKEN = "token"; + public static String TOKEN_CREDENTIAL = "token_credential"; + private final Map properties; public OidcRequestContextProperties(Map properties) { this.properties = properties; } - public Object getProperty(String name) { + public Object get(String name) { return properties.get(name); } + + public String getString(String name) { + return (String) get(name); + } + + public T get(String name, Class type) { + return type.cast(get(name)); + } + } diff --git a/extensions/oidc/runtime/pom.xml b/extensions/oidc/runtime/pom.xml index 2833e27e5ffd9..1c5145df46d56 100644 --- a/extensions/oidc/runtime/pom.xml +++ b/extensions/oidc/runtime/pom.xml @@ -50,6 +50,11 @@ quarkus-junit5-internal test + + org.awaitility + awaitility + test + diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/OidcTenantConfig.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/OidcTenantConfig.java index db4eea48e8d00..ccee7c9ca76c5 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/OidcTenantConfig.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/OidcTenantConfig.java @@ -381,6 +381,81 @@ public void setCleanUpTimerInterval(Duration cleanUpTimerInterval) { } } + /** + * Configuration for controlling how JsonWebKeySet containing verification keys should be acquired and managed. + */ + @ConfigItem + public Jwks jwks = new Jwks(); + + @ConfigGroup + public static class Jwks { + + /** + * If JWK verification keys should be fetched at the moment a connection to the OIDC provider + * is initialized. + *

+ * Disabling this property will delay the key acquisition until the moment the current token + * has to be verified. Typically it can only be necessary if the token or other telated request properties + * provide an additional context which is required to resolve the keys correctly. + */ + @ConfigItem(defaultValue = "true") + public boolean resolveEarly = true; + + /** + * Maximum number of JWK keys that can be cached. + * This property will be ignored if the {@link #resolveEarly} property is set to true. + */ + @ConfigItem(defaultValue = "10") + public int cacheSize = 10; + + /** + * Number of minutes a JWK key can be cached for. + * This property will be ignored if the {@link #resolveEarly} property is set to true. + */ + @ConfigItem(defaultValue = "10M") + public Duration cacheTimeToLive = Duration.ofMinutes(10); + + /** + * Cache timer interval. + * If this property is set then a timer will check and remove the stale entries periodically. + * This property will be ignored if the {@link #resolveEarly} property is set to true. + */ + @ConfigItem + public Optional cleanUpTimerInterval = Optional.empty(); + + public int getCacheSize() { + return cacheSize; + } + + public void setCacheSize(int cacheSize) { + this.cacheSize = cacheSize; + } + + public Duration getCacheTimeToLive() { + return cacheTimeToLive; + } + + public void setCacheTimeToLive(Duration cacheTimeToLive) { + this.cacheTimeToLive = cacheTimeToLive; + } + + public Optional getCleanUpTimerInterval() { + return cleanUpTimerInterval; + } + + public void setCleanUpTimerInterval(Duration cleanUpTimerInterval) { + this.cleanUpTimerInterval = Optional.of(cleanUpTimerInterval); + } + + public boolean isResolveEarly() { + return resolveEarly; + } + + public void setResolveEarly(boolean resolveEarly) { + this.resolveEarly = resolveEarly; + } + } + @ConfigGroup public static class Frontchannel { /** diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/BackChannelLogoutTokenCache.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/BackChannelLogoutTokenCache.java index 4150096851cf2..180c374ac7631 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/BackChannelLogoutTokenCache.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/BackChannelLogoutTokenCache.java @@ -1,103 +1,33 @@ package io.quarkus.oidc.runtime; -import java.util.Iterator; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicInteger; +import jakarta.enterprise.event.Observes; import io.quarkus.oidc.OidcTenantConfig; -import io.vertx.core.Handler; +import io.quarkus.runtime.ShutdownEvent; import io.vertx.core.Vertx; public class BackChannelLogoutTokenCache { - private OidcTenantConfig oidcConfig; - private Map cacheMap = new ConcurrentHashMap<>();; - private AtomicInteger size = new AtomicInteger(); + final MemoryCache cache; public BackChannelLogoutTokenCache(OidcTenantConfig oidcTenantConfig, Vertx vertx) { - this.oidcConfig = oidcTenantConfig; - init(vertx); - } - - private void init(Vertx vertx) { - cacheMap = new ConcurrentHashMap<>(); - if (oidcConfig.logout.backchannel.cleanUpTimerInterval.isPresent()) { - vertx.setPeriodic(oidcConfig.logout.backchannel.cleanUpTimerInterval.get().toMillis(), new Handler() { - @Override - public void handle(Long event) { - // Remove all the entries which have expired - removeInvalidEntries(); - } - }); - } + cache = new MemoryCache(vertx, oidcTenantConfig.logout.backchannel.cleanUpTimerInterval, + oidcTenantConfig.logout.backchannel.tokenCacheTimeToLive, oidcTenantConfig.logout.backchannel.tokenCacheSize); } public void addTokenVerification(String token, TokenVerificationResult result) { - if (!prepareSpaceForNewCacheEntry()) { - clearCache(); - } - cacheMap.put(token, new CacheEntry(result)); + cache.add(token, result); } public TokenVerificationResult removeTokenVerification(String token) { - CacheEntry entry = removeCacheEntry(token); - return entry == null ? null : entry.result; + return cache.remove(token); } public boolean containsTokenVerification(String token) { - return cacheMap.containsKey(token); - } - - public void clearCache() { - cacheMap.clear(); - size.set(0); - } - - private void removeInvalidEntries() { - long now = now(); - for (Iterator> it = cacheMap.entrySet().iterator(); it.hasNext();) { - Map.Entry next = it.next(); - if (isEntryExpired(next.getValue(), now)) { - it.remove(); - size.decrementAndGet(); - } - } - } - - private boolean prepareSpaceForNewCacheEntry() { - int currentSize; - do { - currentSize = size.get(); - if (currentSize == oidcConfig.logout.backchannel.tokenCacheSize) { - return false; - } - } while (!size.compareAndSet(currentSize, currentSize + 1)); - return true; + return cache.containsKey(token); } - private CacheEntry removeCacheEntry(String token) { - CacheEntry entry = cacheMap.remove(token); - if (entry != null) { - size.decrementAndGet(); - } - return entry; - } - - private boolean isEntryExpired(CacheEntry entry, long now) { - return entry.createdTime + oidcConfig.logout.backchannel.tokenCacheTimeToLive.toMillis() < now; - } - - private static long now() { - return System.currentTimeMillis(); - } - - private static class CacheEntry { - volatile TokenVerificationResult result; - long createdTime = System.currentTimeMillis(); - - public CacheEntry(TokenVerificationResult result) { - this.result = result; - } + void shutdown(@Observes ShutdownEvent event, Vertx vertx) { + cache.stopTimer(vertx); } } diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/DefaultTokenIntrospectionUserInfoCache.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/DefaultTokenIntrospectionUserInfoCache.java index 29c22f26430b5..214436d5e6988 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/DefaultTokenIntrospectionUserInfoCache.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/DefaultTokenIntrospectionUserInfoCache.java @@ -1,10 +1,6 @@ package io.quarkus.oidc.runtime; -import java.util.Collections; -import java.util.Iterator; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicInteger; +import jakarta.enterprise.event.Observes; import io.quarkus.oidc.OidcRequestContext; import io.quarkus.oidc.OidcTenantConfig; @@ -12,9 +8,8 @@ import io.quarkus.oidc.TokenIntrospectionCache; import io.quarkus.oidc.UserInfo; import io.quarkus.oidc.UserInfoCache; -import io.quarkus.oidc.runtime.OidcConfig.TokenCache; +import io.quarkus.runtime.ShutdownEvent; import io.smallrye.mutiny.Uni; -import io.vertx.core.Handler; import io.vertx.core.Vertx; /** @@ -31,43 +26,21 @@ public class DefaultTokenIntrospectionUserInfoCache implements TokenIntrospectio private static final Uni NULL_INTROSPECTION_UNI = Uni.createFrom().nullItem(); private static final Uni NULL_USERINFO_UNI = Uni.createFrom().nullItem(); - private TokenCache cacheConfig; - - private Map cacheMap; - private AtomicInteger size = new AtomicInteger(); + final MemoryCache cache; public DefaultTokenIntrospectionUserInfoCache(OidcConfig oidcConfig, Vertx vertx) { - this.cacheConfig = oidcConfig.tokenCache; - init(vertx); - } - - private void init(Vertx vertx) { - if (cacheConfig.maxSize > 0) { - cacheMap = new ConcurrentHashMap<>(); - if (cacheConfig.cleanUpTimerInterval.isPresent()) { - vertx.setPeriodic(cacheConfig.cleanUpTimerInterval.get().toMillis(), new Handler() { - @Override - public void handle(Long event) { - // Remove all the entries which have expired - removeInvalidEntries(); - } - }); - } - } else { - cacheMap = Collections.emptyMap(); - } + cache = new MemoryCache(vertx, oidcConfig.tokenCache.cleanUpTimerInterval, + oidcConfig.tokenCache.timeToLive, oidcConfig.tokenCache.maxSize); } @Override public Uni addIntrospection(String token, TokenIntrospection introspection, OidcTenantConfig oidcTenantConfig, OidcRequestContext requestContext) { - if (cacheConfig.maxSize > 0) { - CacheEntry entry = findValidCacheEntry(token); - if (entry != null) { - entry.introspection = introspection; - } else if (prepareSpaceForNewCacheEntry()) { - cacheMap.put(token, new CacheEntry(introspection)); - } + CacheEntry entry = cache.get(token); + if (entry != null) { + entry.introspection = introspection; + } else { + cache.add(token, new CacheEntry(introspection)); } return CodeAuthenticationMechanism.VOID_UNI; @@ -76,20 +49,18 @@ public Uni addIntrospection(String token, TokenIntrospection introspection @Override public Uni getIntrospection(String token, OidcTenantConfig oidcConfig, OidcRequestContext requestContext) { - CacheEntry entry = findValidCacheEntry(token); + CacheEntry entry = cache.get(token); return entry == null ? NULL_INTROSPECTION_UNI : Uni.createFrom().item(entry.introspection); } @Override public Uni addUserInfo(String token, UserInfo userInfo, OidcTenantConfig oidcTenantConfig, OidcRequestContext requestContext) { - if (cacheConfig.maxSize > 0) { - CacheEntry entry = findValidCacheEntry(token); - if (entry != null) { - entry.userInfo = userInfo; - } else if (prepareSpaceForNewCacheEntry()) { - cacheMap.put(token, new CacheEntry(userInfo)); - } + CacheEntry entry = cache.get(token); + if (entry != null) { + entry.userInfo = userInfo; + } else { + cache.add(token, new CacheEntry(userInfo)); } return CodeAuthenticationMechanism.VOID_UNI; @@ -98,67 +69,13 @@ public Uni addUserInfo(String token, UserInfo userInfo, OidcTenantConfig o @Override public Uni getUserInfo(String token, OidcTenantConfig oidcConfig, OidcRequestContext requestContext) { - CacheEntry entry = findValidCacheEntry(token); + CacheEntry entry = cache.get(token); return entry == null ? NULL_USERINFO_UNI : Uni.createFrom().item(entry.userInfo); } - public int getCacheSize() { - return cacheMap.size(); - } - - public void clearCache() { - cacheMap.clear(); - size.set(0); - } - - private void removeInvalidEntries() { - long now = now(); - for (Iterator> it = cacheMap.entrySet().iterator(); it.hasNext();) { - Map.Entry next = it.next(); - if (isEntryExpired(next.getValue(), now)) { - it.remove(); - size.decrementAndGet(); - } - } - } - - private boolean prepareSpaceForNewCacheEntry() { - int currentSize; - do { - currentSize = size.get(); - if (currentSize == cacheConfig.maxSize) { - return false; - } - } while (!size.compareAndSet(currentSize, currentSize + 1)); - return true; - } - - private CacheEntry findValidCacheEntry(String token) { - CacheEntry entry = cacheMap.get(token); - if (entry != null) { - long now = now(); - if (isEntryExpired(entry, now)) { - // Entry has expired, remote introspection will be required - entry = null; - cacheMap.remove(token); - size.decrementAndGet(); - } - } - return entry; - } - - private boolean isEntryExpired(CacheEntry entry, long now) { - return entry.createdTime + cacheConfig.timeToLive.toMillis() < now; - } - - private static long now() { - return System.currentTimeMillis(); - } - private static class CacheEntry { volatile TokenIntrospection introspection; volatile UserInfo userInfo; - long createdTime = System.currentTimeMillis(); public CacheEntry(TokenIntrospection introspection) { this.introspection = introspection; @@ -168,4 +85,17 @@ public CacheEntry(UserInfo userInfo) { this.userInfo = userInfo; } } + + public void clearCache() { + cache.clearCache(); + } + + public int getCacheSize() { + return cache.getCacheSize(); + } + + void shutdown(@Observes ShutdownEvent event, Vertx vertx) { + cache.stopTimer(vertx); + } + } diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/DynamicVerificationKeyResolver.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/DynamicVerificationKeyResolver.java new file mode 100644 index 0000000000000..f9b9eb7b2a03e --- /dev/null +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/DynamicVerificationKeyResolver.java @@ -0,0 +1,177 @@ +package io.quarkus.oidc.runtime; + +import java.security.Key; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import jakarta.enterprise.event.Observes; + +import org.jboss.logging.Logger; +import org.jose4j.jws.JsonWebSignature; +import org.jose4j.jwx.HeaderParameterNames; +import org.jose4j.jwx.JsonWebStructure; +import org.jose4j.keys.resolvers.VerificationKeyResolver; +import org.jose4j.lang.UnresolvableKeyException; + +import io.quarkus.oidc.OidcTenantConfig; +import io.quarkus.oidc.common.OidcRequestContextProperties; +import io.quarkus.runtime.ShutdownEvent; +import io.quarkus.security.credential.TokenCredential; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Vertx; +import io.vertx.core.json.JsonObject; + +public class DynamicVerificationKeyResolver { + private static final Logger LOG = Logger.getLogger(DynamicVerificationKeyResolver.class); + + private final OidcProviderClient client; + private final MemoryCache cache; + + public DynamicVerificationKeyResolver(OidcProviderClient client, OidcTenantConfig config) { + this.client = client; + this.cache = new MemoryCache(client.getVertx(), config.jwks.cleanUpTimerInterval, + config.jwks.cacheTimeToLive, config.jwks.cacheSize); + } + + public Uni resolve(TokenCredential tokenCred) { + JsonObject headers = OidcUtils.decodeJwtHeaders(tokenCred.getToken()); + Key key = findKeyInTheCache(headers); + if (key != null) { + return Uni.createFrom().item(new SingleKeyVerificationKeyResolver(key)); + } + + return client.getJsonWebKeySet(new OidcRequestContextProperties( + Map.of(OidcRequestContextProperties.TOKEN, tokenCred.getToken(), + OidcRequestContextProperties.TOKEN_CREDENTIAL, tokenCred))) + .onItem().transformToUni(new Function>() { + + @Override + public Uni apply(JsonWebKeySet jwks) { + Key newKey = null; + // Try 'kid' first + String kid = headers.getString(HeaderParameterNames.KEY_ID); + if (kid != null) { + newKey = getKeyWithId(jwks, kid); + if (newKey == null) { + // if `kid` was set then the key must exist + return Uni.createFrom().failure( + new UnresolvableKeyException(String.format("JWK with kid '%s' is not available", kid))); + } else { + cache.add(kid, newKey); + } + } + + String thumbprint = null; + if (newKey == null) { + thumbprint = headers.getString(HeaderParameterNames.X509_CERTIFICATE_SHA256_THUMBPRINT); + if (thumbprint != null) { + newKey = getKeyWithS256Thumbprint(jwks, thumbprint); + if (newKey == null) { + // if only `x5tS256` was set then the key must exist + return Uni.createFrom().failure( + new UnresolvableKeyException(String.format( + "JWK with the SHA256 certificate thumbprint '%s' is not available", + thumbprint))); + } else { + cache.add(thumbprint, newKey); + } + } + } + + if (newKey == null) { + thumbprint = headers.getString(HeaderParameterNames.X509_CERTIFICATE_THUMBPRINT); + if (thumbprint != null) { + newKey = getKeyWithThumbprint(jwks, thumbprint); + if (newKey == null) { + // if only `x5t` was set then the key must exist + return Uni.createFrom().failure(new UnresolvableKeyException( + String.format("JWK with the certificate thumbprint '%s' is not available", + thumbprint))); + } else { + cache.add(thumbprint, newKey); + } + } + } + + if (newKey == null && kid == null && thumbprint == null) { + newKey = jwks.getKeyWithoutKeyIdAndThumbprint("RSA"); + } + + if (newKey == null) { + return Uni.createFrom().failure(new UnresolvableKeyException( + String.format( + "JWK is not available, neither 'kid' nor 'x5t#S256' nor 'x5t' token headers are set", + kid))); + } else { + return Uni.createFrom().item(new SingleKeyVerificationKeyResolver(newKey)); + } + } + + }); + } + + private static Key getKeyWithId(JsonWebKeySet jwks, String kid) { + if (kid != null) { + return jwks.getKeyWithId(kid); + } else { + LOG.debug("Token 'kid' header is not set"); + return null; + } + } + + private Key getKeyWithThumbprint(JsonWebKeySet jwks, String thumbprint) { + if (thumbprint != null) { + return jwks.getKeyWithThumbprint(thumbprint); + } else { + LOG.debug("Token 'x5t' header is not set"); + return null; + } + } + + private Key getKeyWithS256Thumbprint(JsonWebKeySet jwks, String thumbprint) { + if (thumbprint != null) { + return jwks.getKeyWithS256Thumbprint(thumbprint); + } else { + LOG.debug("Token 'x5tS256' header is not set"); + return null; + } + } + + private Key findKeyInTheCache(JsonObject headers) { + String kid = headers.getString(HeaderParameterNames.KEY_ID); + if (kid != null && cache.containsKey(kid)) { + return cache.get(kid); + } + String thumbprint = headers.getString(HeaderParameterNames.X509_CERTIFICATE_SHA256_THUMBPRINT); + if (thumbprint != null && cache.containsKey(thumbprint)) { + return cache.get(thumbprint); + } + + thumbprint = headers.getString(HeaderParameterNames.X509_CERTIFICATE_THUMBPRINT); + if (thumbprint != null && cache.containsKey(thumbprint)) { + return cache.get(thumbprint); + } + + return null; + } + + static class SingleKeyVerificationKeyResolver implements VerificationKeyResolver { + + private Key key; + + SingleKeyVerificationKeyResolver(Key key) { + this.key = key; + } + + @Override + public Key resolveKey(JsonWebSignature jws, List nestingContext) + throws UnresolvableKeyException { + return key; + } + } + + void shutdown(@Observes ShutdownEvent event, Vertx vertx) { + cache.stopTimer(vertx); + } +} diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/JsonWebKeySet.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/JsonWebKeySet.java index 5e80ddfb94b17..dedfe32bf1156 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/JsonWebKeySet.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/JsonWebKeySet.java @@ -9,11 +9,8 @@ import org.jose4j.jwk.JsonWebKey; import org.jose4j.jwk.PublicJsonWebKey; -import org.jose4j.jws.JsonWebSignature; -import org.jose4j.lang.InvalidAlgorithmException; import org.jose4j.lang.JoseException; -import io.quarkus.logging.Log; import io.quarkus.oidc.OIDCException; public class JsonWebKeySet { @@ -86,13 +83,8 @@ public Key getKeyWithS256Thumbprint(String x5tS256) { return keysWithS256Thumbprints.get(x5tS256); } - public Key getKeyWithoutKeyIdAndThumbprint(JsonWebSignature jws) { - try { - List keys = keysWithoutKeyIdAndThumbprint.get(jws.getKeyType()); - return keys == null || keys.size() != 1 ? null : keys.get(0); - } catch (InvalidAlgorithmException ex) { - Log.debug("Token 'alg'(algorithm) header value is invalid", ex); - return null; - } + public Key getKeyWithoutKeyIdAndThumbprint(String keyType) { + List keys = keysWithoutKeyIdAndThumbprint.get(keyType); + return keys == null || keys.size() != 1 ? null : keys.get(0); } } diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/MemoryCache.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/MemoryCache.java new file mode 100644 index 0000000000000..dd8e4943ef029 --- /dev/null +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/MemoryCache.java @@ -0,0 +1,135 @@ +package io.quarkus.oidc.runtime; + +import java.time.Duration; +import java.util.Iterator; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +import io.vertx.core.Handler; +import io.vertx.core.Vertx; + +public class MemoryCache { + private volatile Long timerId = null; + + private final Map> cacheMap = new ConcurrentHashMap<>(); + private AtomicInteger size = new AtomicInteger(); + private final Duration cacheTimeToLive; + private final int cacheSize; + + public MemoryCache(Vertx vertx, Optional cleanUpTimerInterval, + Duration cacheTimeToLive, int cacheSize) { + this.cacheTimeToLive = cacheTimeToLive; + this.cacheSize = cacheSize; + init(vertx, cleanUpTimerInterval); + } + + private void init(Vertx vertx, Optional cleanUpTimerInterval) { + if (cleanUpTimerInterval.isPresent()) { + timerId = vertx.setPeriodic(cleanUpTimerInterval.get().toMillis(), new Handler() { + @Override + public void handle(Long event) { + // Remove all the entries which have expired + removeInvalidEntries(); + } + }); + } + } + + public void add(String key, T result) { + if (cacheSize > 0) { + if (!prepareSpaceForNewCacheEntry()) { + clearCache(); + } + cacheMap.put(key, new CacheEntry(result)); + } + } + + public T remove(String key) { + CacheEntry entry = removeCacheEntry(key); + return entry == null ? null : entry.result; + } + + public T get(String key) { + CacheEntry entry = cacheMap.get(key); + return entry == null ? null : entry.result; + } + + public boolean containsKey(String key) { + return cacheMap.containsKey(key); + } + + private void removeInvalidEntries() { + long now = now(); + for (Iterator>> it = cacheMap.entrySet().iterator(); it.hasNext();) { + Map.Entry> next = it.next(); + if (next != null) { + if (isEntryExpired(next.getValue(), now)) { + try { + it.remove(); + size.decrementAndGet(); + } catch (IllegalStateException ex) { + // continue + } + } + } + } + } + + private boolean prepareSpaceForNewCacheEntry() { + int currentSize; + do { + currentSize = size.get(); + if (currentSize == cacheSize) { + return false; + } + } while (!size.compareAndSet(currentSize, currentSize + 1)); + return true; + } + + private CacheEntry removeCacheEntry(String token) { + CacheEntry entry = cacheMap.remove(token); + if (entry != null) { + size.decrementAndGet(); + } + return entry; + } + + private boolean isEntryExpired(CacheEntry entry, long now) { + return entry.createdTime + cacheTimeToLive.toMillis() < now; + } + + private static long now() { + return System.currentTimeMillis(); + } + + private static class CacheEntry { + volatile T result; + long createdTime = System.currentTimeMillis(); + + public CacheEntry(T result) { + this.result = result; + } + } + + public int getCacheSize() { + return cacheMap.size(); + } + + public void clearCache() { + cacheMap.clear(); + size.set(0); + } + + public void stopTimer(Vertx vertx) { + if (timerId != null && vertx.cancelTimer(timerId)) { + timerId = null; + } + } + + public boolean isTimerRunning() { + return timerId != null; + } + +} diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcIdentityProvider.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcIdentityProvider.java index 26d053339d97a..4cc968508c72f 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcIdentityProvider.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcIdentityProvider.java @@ -142,7 +142,7 @@ public Uni apply(UserInfo userInfo, Throwable t) { primaryTokenUni = verifySelfSignedTokenUni(resolvedContext, request.getToken().getToken()); } } else { - primaryTokenUni = verifyTokenUni(requestData, resolvedContext, request.getToken().getToken(), + primaryTokenUni = verifyTokenUni(requestData, resolvedContext, request.getToken(), isIdToken(request), null); } @@ -192,7 +192,7 @@ public Uni apply(TokenVerificationResult codeAccessToken, Thro } Uni tokenUni = verifyTokenUni(requestData, resolvedContext, - request.getToken().getToken(), + request.getToken(), false, userInfo); return tokenUni.onItemOrFailure() @@ -421,14 +421,15 @@ private Uni verifyCodeFlowAccessTokenUni(Map verifyTokenUni(Map requestData, TenantConfigContext resolvedContext, - String token, boolean enforceAudienceVerification, UserInfo userInfo) { + TokenCredential tokenCred, boolean enforceAudienceVerification, UserInfo userInfo) { + final String token = tokenCred.getToken(); if (OidcUtils.isOpaqueToken(token)) { if (!resolvedContext.oidcConfig.token.allowOpaqueTokenIntrospection) { LOG.debug("Token is opaque but the opaque token introspection is not allowed"); @@ -452,7 +453,7 @@ private Uni verifyTokenUni(Map requestD // Verify JWT token with the remote introspection LOG.debug("Starting the JWT token introspection"); return introspectTokenUni(resolvedContext, token, false); - } else { + } else if (resolvedContext.oidcConfig.jwks.resolveEarly) { // Verify JWT token with the local JWK keys with a possible remote introspection fallback final String nonce = (String) requestData.get(OidcConstants.NONCE); try { @@ -470,6 +471,10 @@ private Uni verifyTokenUni(Map requestD return Uni.createFrom().failure(t); } } + } else { + final String nonce = (String) requestData.get(OidcConstants.NONCE); + return resolveJwksAndVerifyTokenUni(resolvedContext, tokenCred, enforceAudienceVerification, + resolvedContext.oidcConfig.token.isSubjectRequired(), nonce); } } @@ -488,6 +493,15 @@ private Uni refreshJwksAndVerifyTokenUni(TenantConfigCo .recoverWithUni(f -> introspectTokenUni(resolvedContext, token, true)); } + private Uni resolveJwksAndVerifyTokenUni(TenantConfigContext resolvedContext, + TokenCredential tokenCred, + boolean enforceAudienceVerification, boolean subjectRequired, String nonce) { + return resolvedContext.provider + .getKeyResolverAndVerifyJwtToken(tokenCred, enforceAudienceVerification, subjectRequired, nonce) + .onFailure(f -> fallbackToIntrospectionIfNoMatchingKey(f, resolvedContext)) + .recoverWithUni(f -> introspectTokenUni(resolvedContext, tokenCred.getToken(), true)); + } + private static boolean fallbackToIntrospectionIfNoMatchingKey(Throwable f, TenantConfigContext resolvedContext) { if (!(f.getCause() instanceof UnresolvableKeyException)) { LOG.debug("Local JWT token verification has failed, skipping the token introspection"); diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProvider.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProvider.java index 32bbc806e2d6c..8d26b8aae936c 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProvider.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProvider.java @@ -27,8 +27,10 @@ import org.jose4j.jwx.HeaderParameterNames; import org.jose4j.jwx.JsonWebStructure; import org.jose4j.keys.resolvers.VerificationKeyResolver; +import org.jose4j.lang.InvalidAlgorithmException; import org.jose4j.lang.UnresolvableKeyException; +import io.quarkus.logging.Log; import io.quarkus.oidc.AuthorizationCodeTokens; import io.quarkus.oidc.OIDCException; import io.quarkus.oidc.OidcConfigurationMetadata; @@ -38,6 +40,7 @@ import io.quarkus.oidc.UserInfo; import io.quarkus.oidc.common.runtime.OidcConstants; import io.quarkus.security.AuthenticationFailedException; +import io.quarkus.security.credential.TokenCredential; import io.smallrye.jwt.algorithm.SignatureAlgorithm; import io.smallrye.jwt.util.KeyUtils; import io.smallrye.mutiny.Uni; @@ -64,6 +67,7 @@ public class OidcProvider implements Closeable { final OidcProviderClient client; final RefreshableVerificationKeyResolver asymmetricKeyResolver; + final DynamicVerificationKeyResolver keyResolverProvider; final OidcTenantConfig oidcConfig; final TokenCustomizer tokenCustomizer; final String issuer; @@ -83,7 +87,11 @@ public OidcProvider(OidcProviderClient client, OidcTenantConfig oidcConfig, Json this.tokenCustomizer = tokenCustomizer; this.asymmetricKeyResolver = jwks == null ? null : new JsonWebKeyResolver(jwks, oidcConfig.token.forcedJwkRefreshInterval); - + if (client != null && oidcConfig != null && !oidcConfig.jwks.resolveEarly) { + this.keyResolverProvider = new DynamicVerificationKeyResolver(client, oidcConfig); + } else { + this.keyResolverProvider = null; + } this.issuer = checkIssuerProp(); this.audience = checkAudienceProp(); this.requiredClaims = checkRequiredClaimsProp(); @@ -96,6 +104,7 @@ public OidcProvider(String publicKeyEnc, OidcTenantConfig oidcConfig, Key tokenD this.oidcConfig = oidcConfig; this.tokenCustomizer = TokenCustomizerFinder.find(oidcConfig); this.asymmetricKeyResolver = new LocalPublicKeyResolver(publicKeyEnc); + this.keyResolverProvider = null; this.issuer = checkIssuerProp(); this.audience = checkAudienceProp(); this.requiredClaims = checkRequiredClaimsProp(); @@ -282,6 +291,30 @@ public Uni apply(Void v) { }); } + public Uni getKeyResolverAndVerifyJwtToken(TokenCredential tokenCred, + boolean enforceAudienceVerification, + boolean subjectRequired, String nonce) { + return keyResolverProvider.resolve(tokenCred).onItem() + .transformToUni(new Function>() { + + @Override + public Uni apply(VerificationKeyResolver resolver) { + try { + return Uni.createFrom() + .item(verifyJwtTokenInternal(customizeJwtToken(tokenCred.getToken()), + enforceAudienceVerification, + subjectRequired, nonce, + (requiredAlgorithmConstraints != null ? requiredAlgorithmConstraints + : ASYMMETRIC_ALGORITHM_CONSTRAINTS), + resolver, true)); + } catch (Throwable t) { + return Uni.createFrom().failure(t); + } + } + + }); + } + public Uni introspectToken(String token, boolean fallbackFromJwkMatch) { if (client.getMetadata().getIntrospectionUri() == null) { String errorMessage = String.format("Token issued to client %s " @@ -380,7 +413,7 @@ public Key resolveKey(JsonWebSignature jws, List nestingContex // Try 'kid' first String kid = jws.getKeyIdHeaderValue(); if (kid != null) { - key = getKeyWithId(jws, kid); + key = getKeyWithId(kid); if (key == null) { // if `kid` was set then the key must exist throw new UnresolvableKeyException(String.format("JWK with kid '%s' is not available", kid)); @@ -389,31 +422,35 @@ public Key resolveKey(JsonWebSignature jws, List nestingContex String thumbprint = null; if (key == null) { - thumbprint = jws.getHeader(HeaderParameterNames.X509_CERTIFICATE_THUMBPRINT); + thumbprint = jws.getHeader(HeaderParameterNames.X509_CERTIFICATE_SHA256_THUMBPRINT); if (thumbprint != null) { - key = getKeyWithThumbprint(jws, thumbprint); + key = getKeyWithS256Thumbprint(thumbprint); if (key == null) { - // if only `x5t` was set then the key must exist + // if only `x5tS256` was set then the key must exist throw new UnresolvableKeyException( - String.format("JWK with the certificate thumbprint '%s' is not available", thumbprint)); + String.format("JWK with the SHA256 certificate thumbprint '%s' is not available", thumbprint)); } } } if (key == null) { - thumbprint = jws.getHeader(HeaderParameterNames.X509_CERTIFICATE_SHA256_THUMBPRINT); + thumbprint = jws.getHeader(HeaderParameterNames.X509_CERTIFICATE_THUMBPRINT); if (thumbprint != null) { - key = getKeyWithS256Thumbprint(jws, thumbprint); + key = getKeyWithThumbprint(thumbprint); if (key == null) { - // if only `x5tS256` was set then the key must exist + // if only `x5t` was set then the key must exist throw new UnresolvableKeyException( - String.format("JWK with the SHA256 certificate thumbprint '%s' is not available", thumbprint)); + String.format("JWK with the certificate thumbprint '%s' is not available", thumbprint)); } } } if (key == null && kid == null && thumbprint == null) { - key = jwks.getKeyWithoutKeyIdAndThumbprint(jws); + try { + key = jwks.getKeyWithoutKeyIdAndThumbprint(jws.getKeyType()); + } catch (InvalidAlgorithmException ex) { + Log.debug("Token 'alg'(algorithm) header value is invalid", ex); + } } if (key == null) { @@ -425,7 +462,7 @@ public Key resolveKey(JsonWebSignature jws, List nestingContex } } - private Key getKeyWithId(JsonWebSignature jws, String kid) { + private Key getKeyWithId(String kid) { if (kid != null) { return jwks.getKeyWithId(kid); } else { @@ -434,7 +471,7 @@ private Key getKeyWithId(JsonWebSignature jws, String kid) { } } - private Key getKeyWithThumbprint(JsonWebSignature jws, String thumbprint) { + private Key getKeyWithThumbprint(String thumbprint) { if (thumbprint != null) { return jwks.getKeyWithThumbprint(thumbprint); } else { @@ -443,7 +480,7 @@ private Key getKeyWithThumbprint(JsonWebSignature jws, String thumbprint) { } } - private Key getKeyWithS256Thumbprint(JsonWebSignature jws, String thumbprint) { + private Key getKeyWithS256Thumbprint(String thumbprint) { if (thumbprint != null) { return jwks.getKeyWithS256Thumbprint(thumbprint); } else { @@ -456,15 +493,16 @@ public Uni refresh() { final long now = now(); if (now > lastForcedRefreshTime + forcedJwksRefreshIntervalMilliSecs) { lastForcedRefreshTime = now; - return client.getJsonWebKeySet().onItem().transformToUni(new Function>() { + return client.getJsonWebKeySet(null).onItem() + .transformToUni(new Function>() { - @Override - public Uni apply(JsonWebKeySet t) { - jwks = t; - return Uni.createFrom().voidItem(); - } + @Override + public Uni apply(JsonWebKeySet t) { + jwks = t; + return Uni.createFrom().voidItem(); + } - }); + }); } else { return Uni.createFrom().voidItem(); } diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProviderClient.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProviderClient.java index ed4abe2aafecf..204c38984259c 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProviderClient.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProviderClient.java @@ -15,12 +15,14 @@ import io.quarkus.oidc.OidcTenantConfig; import io.quarkus.oidc.TokenIntrospection; import io.quarkus.oidc.UserInfo; +import io.quarkus.oidc.common.OidcRequestContextProperties; import io.quarkus.oidc.common.OidcRequestFilter; import io.quarkus.oidc.common.runtime.OidcCommonUtils; import io.quarkus.oidc.common.runtime.OidcConstants; import io.quarkus.oidc.common.runtime.OidcEndpointAccessException; import io.smallrye.mutiny.Uni; import io.smallrye.mutiny.groups.UniOnItem; +import io.vertx.core.Vertx; import io.vertx.core.http.HttpHeaders; import io.vertx.core.json.JsonObject; import io.vertx.mutiny.core.MultiMap; @@ -40,6 +42,7 @@ public class OidcProviderClient implements Closeable { private static final String APPLICATION_JSON = "application/json"; private final WebClient client; + private final Vertx vertx; private final OidcConfigurationMetadata metadata; private final OidcTenantConfig oidcConfig; private final String clientSecretBasicAuthScheme; @@ -48,10 +51,12 @@ public class OidcProviderClient implements Closeable { private final List filters; public OidcProviderClient(WebClient client, + Vertx vertx, OidcConfigurationMetadata metadata, OidcTenantConfig oidcConfig, List filters) { this.client = client; + this.vertx = vertx; this.metadata = metadata; this.oidcConfig = oidcConfig; this.clientSecretBasicAuthScheme = OidcCommonUtils.initClientSecretBasicAuth(oidcConfig); @@ -74,14 +79,14 @@ public OidcConfigurationMetadata getMetadata() { return metadata; } - public Uni getJsonWebKeySet() { - return filter(client.getAbs(metadata.getJsonWebKeySetUri()), null).send().onItem() + public Uni getJsonWebKeySet(OidcRequestContextProperties contextProperties) { + return filter(client.getAbs(metadata.getJsonWebKeySetUri()), null, contextProperties).send().onItem() .transform(resp -> getJsonWebKeySet(resp)); } public Uni getUserInfo(String token) { LOG.debugf("Get UserInfo on: %s auth: %s", metadata.getUserInfoUri(), OidcConstants.BEARER_SCHEME + " " + token); - return filter(client.getAbs(metadata.getUserInfoUri()), null) + return filter(client.getAbs(metadata.getUserInfoUri()), null, null) .putHeader(AUTHORIZATION_HEADER, OidcConstants.BEARER_SCHEME + " " + token) .send().onItem().transform(resp -> getUserInfo(resp)); } @@ -163,7 +168,7 @@ private UniOnItem> getHttpResponse(String uri, MultiMap for LOG.debugf("Get token on: %s params: %s headers: %s", metadata.getTokenUri(), formBody, request.headers()); // Retry up to three times with a one-second delay between the retries if the connection is closed. Buffer buffer = OidcCommonUtils.encodeForm(formBody); - Uni> response = filter(request, buffer).sendBuffer(buffer) + Uni> response = filter(request, buffer, null).sendBuffer(buffer) .onFailure(ConnectException.class) .retry() .atMost(oidcConfig.connectionRetryCount).onFailure().transform(t -> t.getCause()); @@ -219,10 +224,15 @@ public Key getClientJwtKey() { return clientJwtKey; } - private HttpRequest filter(HttpRequest request, Buffer body) { + private HttpRequest filter(HttpRequest request, Buffer body, + OidcRequestContextProperties contextProperties) { for (OidcRequestFilter filter : filters) { - filter.filter(request, body, null); + filter.filter(request, body, contextProperties); } return request; } + + public Vertx getVertx() { + return vertx; + } } diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcRecorder.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcRecorder.java index d169d2bcd2079..50264f617dfd5 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcRecorder.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcRecorder.java @@ -352,17 +352,16 @@ protected static Uni createOidcProvider(OidcTenantConfig oidcConfi .transformToUni(new Function>() { @Override public Uni apply(OidcProviderClient client) { - if (client.getMetadata().getJsonWebKeySetUri() != null + if (oidcConfig.jwks.resolveEarly + && client.getMetadata().getJsonWebKeySetUri() != null && !oidcConfig.token.requireJwtIntrospectionOnly) { return getJsonWebSetUni(client, oidcConfig).onItem() .transform(new Function() { - @Override public OidcProvider apply(JsonWebKeySet jwks) { return new OidcProvider(client, oidcConfig, jwks, readTokenDecryptionKey(oidcConfig)); } - }); } else { return Uni.createFrom() @@ -405,7 +404,7 @@ private static Key readTokenDecryptionKey(OidcTenantConfig oidcConfig) { protected static Uni getJsonWebSetUni(OidcProviderClient client, OidcTenantConfig oidcConfig) { if (!oidcConfig.isDiscoveryEnabled().orElse(true)) { final long connectionDelayInMillisecs = OidcCommonUtils.getConnectionDelayInMillis(oidcConfig); - return client.getJsonWebKeySet().onFailure(OidcCommonUtils.oidcEndpointNotAvailable()) + return client.getJsonWebKeySet(null).onFailure(OidcCommonUtils.oidcEndpointNotAvailable()) .retry() .withBackOff(OidcCommonUtils.CONNECTION_BACKOFF_DURATION, OidcCommonUtils.CONNECTION_BACKOFF_DURATION) .expireIn(connectionDelayInMillisecs) @@ -419,7 +418,7 @@ public Throwable apply(Throwable t) { .onFailure() .invoke(client::close); } else { - return client.getJsonWebKeySet(); + return client.getJsonWebKeySet(null); } } @@ -479,7 +478,7 @@ public Uni apply(OidcConfigurationMetadata metadata, Throwab + " Use 'quarkus.oidc.user-info-path' if the discovery is disabled.")); } return Uni.createFrom() - .item(new OidcProviderClient(client, metadata, oidcConfig, clientRequestFilters)); + .item(new OidcProviderClient(client, vertx, metadata, oidcConfig, clientRequestFilters)); } }); diff --git a/extensions/oidc/runtime/src/test/java/io/quarkus/oidc/runtime/MemoryCacheTest.java b/extensions/oidc/runtime/src/test/java/io/quarkus/oidc/runtime/MemoryCacheTest.java new file mode 100644 index 0000000000000..6e0d7f082f6c1 --- /dev/null +++ b/extensions/oidc/runtime/src/test/java/io/quarkus/oidc/runtime/MemoryCacheTest.java @@ -0,0 +1,101 @@ +package io.quarkus.oidc.runtime; + +import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.Callable; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; + +import io.vertx.core.Vertx; + +public class MemoryCacheTest { + + static Vertx vertx = Vertx.vertx(); + + @AfterAll + public static void closeVertxClient() { + if (vertx != null) { + vertx.close().toCompletionStage().toCompletableFuture().join(); + vertx = null; + } + } + + @Test + public void testCache() throws Exception { + + MemoryCache cache = new MemoryCache(vertx, + // timer interval + Optional.of(Duration.ofSeconds(1)), + // entry is valid for 3 seconds + Duration.ofSeconds(2), + // max cache size + 2); + cache.add("1", new Bean("1")); + cache.add("2", new Bean("2")); + assertEquals(2, cache.getCacheSize()); + + assertEquals("1", cache.get("1").name); + assertEquals("2", cache.get("2").name); + + assertEquals("1", cache.remove("1").name); + assertNull(cache.get("1")); + assertEquals("2", cache.get("2").name); + assertEquals(1, cache.getCacheSize()); + + assertTrue(cache.isTimerRunning()); + + await().atMost(Duration.ofSeconds(5)).until(new Callable() { + + @Override + public Boolean call() throws Exception { + return cache.getCacheSize() == 0; + } + + }); + + cache.stopTimer(vertx); + assertFalse(cache.isTimerRunning()); + } + + @Test + public void testAddWhenMaxCacheSizeIsReached() throws Exception { + + MemoryCache cache = new MemoryCache(vertx, + // timer interval + Optional.empty(), + // entry is valid for 3 seconds + Duration.ofSeconds(3), + // max cache size + 2); + assertFalse(cache.isTimerRunning()); + + cache.add("1", new Bean("1")); + cache.add("2", new Bean("2")); + assertEquals(2, cache.getCacheSize()); + + // Currently, if the cache is full and a new entry has to be added, then the whole cache is cleared + // It can be optimized to remove the oldest entry only in the future + + cache.add("3", new Bean("3")); + assertEquals(1, cache.getCacheSize()); + + assertNull(cache.get("1")); + assertNull(cache.get("2")); + assertEquals("3", cache.get("3").name); + } + + static class Bean { + String name; + + Bean(String name) { + this.name = name; + } + } +} diff --git a/integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcRequestCustomizer.java b/integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcRequestCustomizer.java index b26cb0aa8049c..0f76995ecd0ed 100644 --- a/integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcRequestCustomizer.java +++ b/integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcRequestCustomizer.java @@ -3,6 +3,7 @@ import jakarta.enterprise.context.ApplicationScoped; import io.quarkus.arc.Unremovable; +import io.quarkus.oidc.AccessTokenCredential; import io.quarkus.oidc.common.OidcRequestContextProperties; import io.quarkus.oidc.common.OidcRequestFilter; import io.vertx.core.http.HttpMethod; @@ -18,7 +19,22 @@ public void filter(HttpRequest request, Buffer buffer, OidcRequestContex HttpMethod method = request.method(); String uri = request.uri(); if (method == HttpMethod.GET && uri.endsWith("/auth/azure/jwk")) { - request.putHeader("Authorization", "ID token"); + String token = contextProps.getString(OidcRequestContextProperties.TOKEN); + AccessTokenCredential tokenCred = contextProps.get(OidcRequestContextProperties.TOKEN_CREDENTIAL, + AccessTokenCredential.class); + // or + // IdTokenCredential tokenCred = contextProps.get(OidcRequestContextProperties.TOKEN_CREDENTIAL, + // IdTokenCredential.class); + // or + // TokenCredential tokenCred = contextProps.get(OidcRequestContextProperties.TOKEN_CREDENTIAL, + // TokenCredential.class); + // if either access or ID token has to be verified and check is it an instanceof + // AccessTokenCredential or IdTokenCredential + // or simply + // String token = contextProps.getString(OidcRequestContextProperties.TOKEN); + if (token.equals(tokenCred.getToken())) { + request.putHeader("Authorization", "Access token: " + token); + } } } diff --git a/integration-tests/oidc-wiremock/src/main/resources/application.properties b/integration-tests/oidc-wiremock/src/main/resources/application.properties index f806a8948240c..1a0e9556492de 100644 --- a/integration-tests/oidc-wiremock/src/main/resources/application.properties +++ b/integration-tests/oidc-wiremock/src/main/resources/application.properties @@ -130,6 +130,7 @@ quarkus.oidc.bearer-azure.provider=microsoft quarkus.oidc.bearer-azure.application-type=service quarkus.oidc.bearer-azure.discovery-enabled=false quarkus.oidc.bearer-azure.jwks-path=${keycloak.url}/azure/jwk +quarkus.oidc.bearer-azure.jwks.resolve-early=false quarkus.oidc.bearer-azure.token.lifespan-grace=2147483647 quarkus.oidc.bearer-azure.token.customizer-name=azure-access-token-customizer diff --git a/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/BearerTokenAuthorizationTest.java b/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/BearerTokenAuthorizationTest.java index 4c0b332ce82ad..4e31443081776 100644 --- a/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/BearerTokenAuthorizationTest.java +++ b/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/BearerTokenAuthorizationTest.java @@ -50,11 +50,11 @@ public void testSecureAccessSuccessPreferredUsername() { @Test public void testAccessResourceAzure() throws Exception { + String azureToken = readFile("token.txt"); String azureJwk = readFile("jwks.json"); wireMockServer.stubFor(WireMock.get("/auth/azure/jwk") - .withHeader("Authorization", matching("ID token")) + .withHeader("Authorization", matching("Access token: " + azureToken)) .willReturn(WireMock.aResponse().withBody(azureJwk))); - String azureToken = readFile("token.txt"); RestAssured.given().auth().oauth2(azureToken) .when().get("/api/admin/bearer-azure") .then()