Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing a problem with potential dirty read of a token document on token refresh #64031

Merged
merged 5 commits into from
Oct 26, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ protected void doExecute(Task task, DelegatePkiAuthenticationRequest request,
tokenService.createOAuth2Tokens(authentication, delegateeAuthentication, Map.of(), false,
ActionListener.wrap(tuple -> {
final TimeValue expiresIn = tokenService.getExpirationDelay();
listener.onResponse(new DelegatePkiAuthenticationResponse(tuple.v1(), expiresIn, authentication));
listener.onResponse(new DelegatePkiAuthenticationResponse(tuple.v1().v1(), expiresIn, authentication));
}, listener::onFailure));
}, e -> {
logger.debug((Supplier<?>) () -> new ParameterizedMessage("Delegated x509Token [{}] could not be authenticated",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ protected void doExecute(Task task, OpenIdConnectAuthenticateRequest request,
tokenService.createOAuth2Tokens(authentication, originatingAuthentication, tokenMetadata, true,
ActionListener.wrap(tuple -> {
final TimeValue expiresIn = tokenService.getExpirationDelay();
listener.onResponse(new OpenIdConnectAuthenticateResponse(authentication, tuple.v1(), tuple.v2(), expiresIn));
listener.onResponse(new OpenIdConnectAuthenticateResponse(authentication, tuple.v1().v1(), tuple.v1().v2(),
expiresIn));
}, listener::onFailure));
}, e -> {
logger.debug(() -> new ParameterizedMessage("OpenIDConnectToken [{}] could not be authenticated", token), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ protected void doExecute(Task task, SamlAuthenticateRequest request, ActionListe
tokenMeta, true, ActionListener.wrap(tuple -> {
final TimeValue expiresIn = tokenService.getExpirationDelay();
listener.onResponse(
new SamlAuthenticateResponse(authentication, tuple.v1(), tuple.v2(), expiresIn));
new SamlAuthenticateResponse(authentication, tuple.v1().v1(), tuple.v1().v2(), expiresIn));
}, listener::onFailure));
}, e -> {
logger.debug(() -> new ParameterizedMessage("SamlToken [{}] could not be authenticated", saml), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ private void createToken(GrantType grantType, CreateTokenRequest request, Authen
ActionListener.wrap(tuple -> {
final String scope = getResponseScopeValue(request.getScope());
final String base64AuthenticateResponse = (grantType == GrantType.KERBEROS) ? extractOutToken() : null;
final CreateTokenResponse response = new CreateTokenResponse(tuple.v1(), tokenService.getExpirationDelay(), scope,
tuple.v2(), base64AuthenticateResponse, authentication);
final CreateTokenResponse response = new CreateTokenResponse(tuple.v1().v1(), tokenService.getExpirationDelay(), scope,
tuple.v1().v2(), base64AuthenticateResponse, authentication);
listener.onResponse(response);
}, listener::onFailure));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.security.action.token.CreateTokenRequest;
Expand All @@ -33,11 +32,9 @@ public TransportRefreshTokenAction(TransportService transportService, ActionFilt
protected void doExecute(Task task, CreateTokenRequest request, ActionListener<CreateTokenResponse> listener) {
tokenService.refreshToken(request.getRefreshToken(), ActionListener.wrap(tuple -> {
final String scope = getResponseScopeValue(request.getScope());
tokenService.authenticateToken(new SecureString(tuple.v1()), ActionListener.wrap(authentication -> {
listener.onResponse(new CreateTokenResponse(tuple.v1(), tokenService.getExpirationDelay(), scope, tuple.v2(), null,
authentication));
},
listener::onFailure));
final CreateTokenResponse response =
new CreateTokenResponse(tuple.v1().v1(), tokenService.getExpirationDelay(), scope, tuple.v1().v2(), null, tuple.v2());
listener.onResponse(response);
}, listener::onFailure));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ public TokenService(Settings settings, Clock clock, Client client, XPackLicenseS
* {@link #VERSION_TOKENS_INDEX_INTRODUCED} and to a specific security tokens index for later versions.
*/
public void createOAuth2Tokens(Authentication authentication, Authentication originatingClientAuth, Map<String, Object> metadata,
boolean includeRefreshToken, ActionListener<Tuple<String, String>> listener) {
boolean includeRefreshToken, ActionListener<Tuple<Tuple<String, String>, Authentication>> listener) {
BigPandaToo marked this conversation as resolved.
Show resolved Hide resolved
// the created token is compatible with the oldest node version in the cluster
final Version tokenVersion = getTokenVersionCompatibility();
// tokens moved to a separate index in newer versions
Expand All @@ -273,7 +273,7 @@ public void createOAuth2Tokens(Authentication authentication, Authentication ori
//public for testing
public void createOAuth2Tokens(String accessToken, String refreshToken, Authentication authentication,
Authentication originatingClientAuth,
Map<String, Object> metadata, ActionListener<Tuple<String, String>> listener) {
Map<String, Object> metadata, ActionListener<Tuple<Tuple<String, String>, Authentication>> listener) {
// the created token is compatible with the oldest node version in the cluster
final Version tokenVersion = getTokenVersionCompatibility();
// tokens moved to a separate index in newer versions
Expand Down Expand Up @@ -311,7 +311,7 @@ public void createOAuth2Tokens(String accessToken, String refreshToken, Authenti
*/
private void createOAuth2Tokens(String accessToken, String refreshToken, Version tokenVersion, SecurityIndexManager tokensIndex,
BigPandaToo marked this conversation as resolved.
Show resolved Hide resolved
Authentication authentication, Authentication originatingClientAuth, Map<String, Object> metadata,
ActionListener<Tuple<String, String>> listener) {
ActionListener<Tuple<Tuple<String, String>, Authentication>> listener) {
assert accessToken.length() == TOKEN_LENGTH : "We assume token ids have a fixed length for nodes of a certain version."
+ " When changing the token length, be careful that the inferences about its length still hold.";
ensureEnabled();
Expand Down Expand Up @@ -351,12 +351,13 @@ private void createOAuth2Tokens(String accessToken, String refreshToken, Version
final String versionedRefreshToken = refreshToken != null
? prependVersionAndEncodeRefreshToken(tokenVersion, refreshToken)
: null;
listener.onResponse(new Tuple<>(versionedAccessToken, versionedRefreshToken));
listener.onResponse(new Tuple<>(new Tuple<>(versionedAccessToken, versionedRefreshToken),
tokenAuth));
} else {
// prior versions of the refresh token are not version-prepended, as nodes on those
// versions don't expect it.
// Such nodes might exist in a mixed cluster during a rolling upgrade.
listener.onResponse(new Tuple<>(versionedAccessToken, refreshToken));
listener.onResponse(new Tuple<>(new Tuple<>(versionedAccessToken, refreshToken),tokenAuth));
}
} else {
listener.onFailure(traceLog("create token",
Expand Down Expand Up @@ -862,7 +863,7 @@ private void indexInvalidation(Collection<String> tokenIds, SecurityIndexManager
* @param listener The listener to call upon completion with a {@link Tuple} containing the
* serialized access token and serialized refresh token as these will be returned to the client
*/
public void refreshToken(String refreshToken, ActionListener<Tuple<String, String>> listener) {
public void refreshToken(String refreshToken, ActionListener<Tuple<Tuple<String, String>, Authentication>> listener) {
ensureEnabled();
final Instant refreshRequested = clock.instant();
final Iterator<TimeValue> backoff = DEFAULT_BACKOFF.iterator();
Expand Down Expand Up @@ -995,7 +996,7 @@ private void findTokenFromRefreshToken(String refreshToken, SecurityIndexManager
*/
private void innerRefresh(String refreshToken, String tokenDocId, Map<String, Object> source, long seqNo, long primaryTerm,
Authentication clientAuth, Iterator<TimeValue> backoff, Instant refreshRequested,
ActionListener<Tuple<String, String>> listener) {
ActionListener<Tuple<Tuple<String, String>, Authentication>> listener) {
logger.debug("Attempting to refresh token stored in token document [{}]", tokenDocId);
final Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("refresh token", tokenDocId, ex));
final Tuple<RefreshTokenStatus, Optional<ElasticsearchSecurityException>> checkRefreshResult;
Expand Down Expand Up @@ -1126,11 +1127,12 @@ public void onFailure(Exception e) {
* @param refreshTokenStatus The {@link RefreshTokenStatus} containing information about the superseding tokens as retrieved from the
* index
* @param tokensIndex the manager for the index where the tokens are stored
* @param listener The listener to call upon completion with a {@link Tuple} containing the
* serialized access token and serialized refresh token as these will be returned to the client
* @param listener The listener to call upon completion with a {@link Tuple} containing the Tuple of
* serialized access token and serialized refresh token and Authentication object as these will be returned
* to the client
*/
void decryptAndReturnSupersedingTokens(String refreshToken, RefreshTokenStatus refreshTokenStatus, SecurityIndexManager tokensIndex,
ActionListener<Tuple<String, String>> listener) {
ActionListener<Tuple<Tuple<String, String>, Authentication>> listener) {

final byte[] iv = Base64.getDecoder().decode(refreshTokenStatus.getIv());
final byte[] salt = Base64.getDecoder().decode(refreshTokenStatus.getSalt());
Expand Down Expand Up @@ -1166,8 +1168,9 @@ public void onResponse(GetResponse response) {
if (response.isExists()) {
try {
listener.onResponse(
new Tuple<>(prependVersionAndEncodeAccessToken(refreshTokenStatus.getVersion(), decryptedTokens[0]),
prependVersionAndEncodeRefreshToken(refreshTokenStatus.getVersion(), decryptedTokens[1])));
new Tuple<>(new Tuple<>(prependVersionAndEncodeAccessToken(refreshTokenStatus.getVersion(),
decryptedTokens[0]),
prependVersionAndEncodeRefreshToken(refreshTokenStatus.getVersion(), decryptedTokens[1])), null));
} catch (GeneralSecurityException | IOException e) {
logger.warn("Could not format stored superseding token values", e);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,11 @@ public void testLogoutInvalidatesTokens() throws Exception {
final Authentication authentication = new Authentication(user, realmRef, null, null, Authentication.AuthenticationType.REALM,
tokenMetadata);

final PlainActionFuture<Tuple<String, String>> future = new PlainActionFuture<>();
final PlainActionFuture<Tuple<Tuple<String, String>, Authentication>> future = new PlainActionFuture<>();
final String userTokenId = UUIDs.randomBase64UUID();
final String refreshToken = UUIDs.randomBase64UUID();
tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, tokenMetadata, future);
final String accessToken = future.actionGet().v1();
final String accessToken = future.actionGet().v1().v1();
mockGetTokenFromId(tokenService, userTokenId, authentication, false, client);

final OpenIdConnectLogoutRequest request = new OpenIdConnectLogoutRequest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,9 @@ private Tuple<String, String> storeToken(String userTokenId, String refreshToken
Authentication authentication = new Authentication(new User("bob"),
new RealmRef("native", NativeRealmSettings.TYPE, "node01"), null);
final Map<String, Object> metadata = samlRealm.createTokenMetadata(nameId, session);
final PlainActionFuture<Tuple<String, String>> future = new PlainActionFuture<>();
final PlainActionFuture<Tuple<Tuple<String, String>, Authentication>> future = new PlainActionFuture<>();
tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, metadata, future);
return future.actionGet();
return future.actionGet().v1();
}

private Tuple<String, String> storeToken(SamlNameId nameId, String session) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,11 @@ public void testLogoutInvalidatesToken() throws Exception {
tokenMetadata);


final PlainActionFuture<Tuple<String, String>> future = new PlainActionFuture<>();
final PlainActionFuture<Tuple<Tuple<String, String>, Authentication>> future = new PlainActionFuture<>();
final String userTokenId = UUIDs.randomBase64UUID();
final String refreshToken = UUIDs.randomBase64UUID();
tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, tokenMetadata, future);
final String accessToken = future.actionGet().v1();
final String accessToken = future.actionGet().v1().v1();
mockGetTokenFromId(tokenService, userTokenId, authentication, false, client);

final SamlLogoutRequest request = new SamlLogoutRequest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1305,14 +1305,14 @@ public void testAuthenticateWithToken() throws Exception {
User user = new User("_username", "r1");
final AtomicBoolean completed = new AtomicBoolean(false);
final Authentication expected = new Authentication(user, new RealmRef("realm", "custom", "node"), null);
PlainActionFuture<Tuple<String, String>> tokenFuture = new PlainActionFuture<>();
PlainActionFuture<Tuple<Tuple<String, String>, Authentication>> tokenFuture = new PlainActionFuture<>();
final String userTokenId = UUIDs.randomBase64UUID();
final String refreshToken = UUIDs.randomBase64UUID();
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null);
tokenService.createOAuth2Tokens(userTokenId, refreshToken, expected, originatingAuth, Collections.emptyMap(), tokenFuture);
}
String token = tokenFuture.get().v1();
String token = tokenFuture.get().v1().v1();
when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE));
mockGetTokenFromId(tokenService, userTokenId, expected, false, client);
when(securityIndex.freeze()).thenReturn(securityIndex);
Expand Down Expand Up @@ -1393,14 +1393,14 @@ public void testExpiredToken() throws Exception {
when(securityIndex.indexExists()).thenReturn(true);
User user = new User("_username", "r1");
final Authentication expected = new Authentication(user, new RealmRef("realm", "custom", "node"), null);
PlainActionFuture<Tuple<String, String>> tokenFuture = new PlainActionFuture<>();
PlainActionFuture<Tuple<Tuple<String, String>, Authentication>> tokenFuture = new PlainActionFuture<>();
final String userTokenId = UUIDs.randomBase64UUID();
final String refreshToken = UUIDs.randomBase64UUID();
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null);
tokenService.createOAuth2Tokens(userTokenId, refreshToken, expected, originatingAuth, Collections.emptyMap(), tokenFuture);
}
String token = tokenFuture.get().v1();
String token = tokenFuture.get().v1().v1();
mockGetTokenFromId(tokenService, userTokenId, expected, true, client);
doAnswer(invocationOnMock -> {
((Runnable) invocationOnMock.getArguments()[1]).run();
Expand Down
Loading