Skip to content

Commit

Permalink
Fix Azure#14408: [BUG] ChainedTokenCredential throws an exception whe…
Browse files Browse the repository at this point in the history
…n VisualStudioCodeCredential is present
  • Loading branch information
AlexanderSher committed Sep 3, 2020
1 parent bd6412f commit 80f0272
Show file tree
Hide file tree
Showing 24 changed files with 3,479 additions and 271 deletions.
22 changes: 20 additions & 2 deletions sdk/identity/Azure.Identity/src/AuthenticationFailedException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,27 @@ public AuthenticationFailedException(string message, Exception innerException)
{
}

internal static AuthenticationFailedException CreateAggregateException(string message, IList<Exception> innerExceptions)
internal static AuthenticationFailedException CreateAggregateException(string message, IList<Exception> exceptions)
{
return new AuthenticationFailedException(message, new AggregateException("Multiple exceptions were encountered while attempting to authenticate.", innerExceptions.ToArray()));
// Build the credential unavailable message, this code is only reachable if all credentials throw AuthenticationFailedException
StringBuilder errorMsg = new StringBuilder(message);

bool allCredentialUnavailableException = true;
foreach (var exception in exceptions)
{
allCredentialUnavailableException &= exception is CredentialUnavailableException;
errorMsg.Append(Environment.NewLine).Append("- ").Append(exception.Message);
}

var innerException = exceptions.Count == 1
? exceptions[0]
: new AggregateException("Multiple exceptions were encountered while attempting to authenticate.", exceptions);

// If all credentials have thrown CredentialUnavailableException, throw CredentialUnavailableException,
// otherwise throw AuthenticationFailedException
return allCredentialUnavailableException
? new CredentialUnavailableException(errorMsg.ToString(), innerException)
: new AuthenticationFailedException(errorMsg.ToString(), innerException);
}
}
}
94 changes: 30 additions & 64 deletions sdk/identity/Azure.Identity/src/ChainedTokenCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
using Azure.Core;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core.Pipeline;

namespace Azure.Identity
{
Expand All @@ -22,6 +22,9 @@ public class ChainedTokenCredential : TokenCredential

private readonly TokenCredential[] _sources;

/// <summary>
/// Constructor for instrumenting in tests
/// </summary>
internal ChainedTokenCredential()
{
_sources = Array.Empty<TokenCredential>();
Expand Down Expand Up @@ -58,29 +61,7 @@ public ChainedTokenCredential(params TokenCredential[] sources)
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
/// <returns>The first <see cref="AccessToken"/> returned by the specified sources. Any credential which raises a <see cref="CredentialUnavailableException"/> will be skipped.</returns>
public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
{
List<Exception> exceptions = new List<Exception>();

for (int i = 0; i < _sources.Length; i++)
{
try
{
return _sources[i].GetToken(requestContext, cancellationToken);
}
catch (CredentialUnavailableException e)
{
exceptions.Add(e);
}
catch (Exception e) when (!(e is OperationCanceledException))
{
exceptions.Add(e);

throw AuthenticationFailedException.CreateAggregateException(AggregateCredentialFailedErrorMessage + e.Message, exceptions);
}
}

throw AuthenticationFailedException.CreateAggregateException(AggregateAllUnavailableErrorMessage, exceptions);
}
=> GetTokenImplAsync(false, requestContext, cancellationToken).EnsureCompleted();

/// <summary>
/// Sequentially calls <see cref="TokenCredential.GetToken"/> on all the specified sources, returning the first successfully obtained <see cref="AccessToken"/>. This method is called by Azure SDK clients. It isn't intended for use in application code.
Expand All @@ -89,56 +70,41 @@ public override AccessToken GetToken(TokenRequestContext requestContext, Cancell
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
/// <returns>The first <see cref="AccessToken"/> returned by the specified sources. Any credential which raises a <see cref="CredentialUnavailableException"/> will be skipped.</returns>
public override async ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
{
List<Exception> exceptions = new List<Exception>();

for (int i = 0; i < _sources.Length; i++)
{
try
{
return await _sources[i].GetTokenAsync(requestContext, cancellationToken).ConfigureAwait(false);
}
catch (CredentialUnavailableException e)
{
exceptions.Add(e);
}
catch (Exception e) when (!(e is OperationCanceledException))
{
exceptions.Add(e);

throw AuthenticationFailedException.CreateAggregateException(AggregateCredentialFailedErrorMessage + e.Message, exceptions);
}
}

throw AuthenticationFailedException.CreateAggregateException(AggregateAllUnavailableErrorMessage, exceptions);
}
=> await GetTokenImplAsync(true, requestContext, cancellationToken).ConfigureAwait(false);

private async ValueTask<AccessToken> GetTokenImplAsync(bool async, TokenRequestContext requestContext, CancellationToken cancellationToken)
{
using CredentialDiagnosticScope scope = _pipeline.StartGetTokenScopeGroup("DefaultAzureCredential.GetToken", requestContext);

var groupScopeHandler = new ScopeGroupHandler(default);
try
{
using var asyncLock = await _credentialLock.GetLockOrValueAsync(async, cancellationToken).ConfigureAwait(false);

AccessToken token;
if (asyncLock.HasValue)
{
token = await GetTokenFromCredentialAsync(asyncLock.Value, requestContext, async, cancellationToken).ConfigureAwait(false);
}
else
List<Exception> exceptions = new List<Exception>();
foreach (TokenCredential source in _sources)
{
TokenCredential credential;
(token, credential) = await GetTokenFromSourcesAsync(_sources, requestContext, async, cancellationToken).ConfigureAwait(false);
_sources = default;
asyncLock.SetValue(credential);
try
{
AccessToken token = async
? await source.GetTokenAsync(requestContext, cancellationToken).ConfigureAwait(false)
: source.GetToken(requestContext, cancellationToken);
groupScopeHandler.Dispose(default, default);
return token;
}
catch (AuthenticationFailedException e)
{
exceptions.Add(e);
}
catch (Exception e) when (!(e is OperationCanceledException))
{
exceptions.Add(e);
throw AuthenticationFailedException.CreateAggregateException(AggregateCredentialFailedErrorMessage + e.Message, exceptions);
}
}

return scope.Succeeded(token);
throw AuthenticationFailedException.CreateAggregateException(AggregateAllUnavailableErrorMessage, exceptions);
}
catch (Exception e)
catch (Exception exception)
{
throw scope.FailWrapAndThrow(e);
groupScopeHandler.Fail(default, default, exception);
throw;
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions sdk/identity/Azure.Identity/src/CredentialDiagnosticScope.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ namespace Azure.Identity
private readonly TokenRequestContext _context;
private readonly IScopeHandler _scopeHandler;

public CredentialDiagnosticScope(string name, TokenRequestContext context, IScopeHandler scopeHandler)
public CredentialDiagnosticScope(ClientDiagnostics diagnostics, string name, TokenRequestContext context, IScopeHandler scopeHandler)
{
_name = name;
_scope = scopeHandler.CreateScope(name);
_scope = scopeHandler.CreateScope(diagnostics, name);
_context = context;
_scopeHandler = scopeHandler;
}
Expand Down
133 changes: 7 additions & 126 deletions sdk/identity/Azure.Identity/src/CredentialPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,18 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Azure.Core;
using Azure.Core.Diagnostics;
using Azure.Core.Pipeline;
using Microsoft.Identity.Client;

namespace Azure.Identity
{
internal class CredentialPipeline
internal class CredentialPipeline
{
private static readonly Lazy<CredentialPipeline> s_singleton = new Lazy<CredentialPipeline>(() => new CredentialPipeline(new TokenCredentialOptions()));

private readonly IScopeHandler _defaultScopeHandler;
private IScopeHandler _groupScopeHandler;
private static readonly IScopeHandler _defaultScopeHandler = new ScopeHandler();

private CredentialPipeline(TokenCredentialOptions options)
{
Expand All @@ -26,8 +22,6 @@ private CredentialPipeline(TokenCredentialOptions options)
HttpPipeline = HttpPipelineBuilder.Build(options, Array.Empty<HttpPipelinePolicy>(), Array.Empty<HttpPipelinePolicy>(), new CredentialResponseClassifier());

Diagnostics = new ClientDiagnostics(options);

_defaultScopeHandler = new ScopeHandler(Diagnostics);
}

public static CredentialPipeline GetInstance(TokenCredentialOptions options)
Expand All @@ -48,18 +42,18 @@ public IConfidentialClientApplication CreateMsalConfidentialClient(string tenant

public CredentialDiagnosticScope StartGetTokenScope(string fullyQualifiedMethod, TokenRequestContext context)
{
IScopeHandler scopeHandler = _groupScopeHandler ?? _defaultScopeHandler;
IScopeHandler scopeHandler = ScopeGroupHandler.Current ?? _defaultScopeHandler;

CredentialDiagnosticScope scope = new CredentialDiagnosticScope(fullyQualifiedMethod, context, scopeHandler);
CredentialDiagnosticScope scope = new CredentialDiagnosticScope(Diagnostics, fullyQualifiedMethod, context, scopeHandler);
scope.Start();
return scope;
}

public CredentialDiagnosticScope StartGetTokenScopeGroup(string fullyQualifiedMethod, TokenRequestContext context)
{
var scopeHandler = new ScopeGroupHandler(this, fullyQualifiedMethod);
var scopeHandler = new ScopeGroupHandler(fullyQualifiedMethod);

CredentialDiagnosticScope scope = new CredentialDiagnosticScope(fullyQualifiedMethod, context, scopeHandler);
CredentialDiagnosticScope scope = new CredentialDiagnosticScope(Diagnostics, fullyQualifiedMethod, context, scopeHandler);
scope.Start();
return scope;
}
Expand All @@ -74,123 +68,10 @@ public override bool IsRetriableResponse(HttpMessage message)

private class ScopeHandler : IScopeHandler
{
private readonly ClientDiagnostics _diagnostics;

public ScopeHandler(ClientDiagnostics diagnostics)
{
_diagnostics = diagnostics;
}

public DiagnosticScope CreateScope(string name) => _diagnostics.CreateScope(name);
public DiagnosticScope CreateScope(ClientDiagnostics diagnostics, string name) => diagnostics.CreateScope(name);
public void Start(string name, in DiagnosticScope scope) => scope.Start();
public void Dispose(string name, in DiagnosticScope scope) => scope.Dispose();
public void Fail(string name, in DiagnosticScope scope, Exception exception) => scope.Failed(exception);
}

private class ScopeGroupHandler : IScopeHandler
{
private readonly CredentialPipeline _pipeline;
private readonly string _groupName;
private Dictionary<string, (DateTime StartDateTime, Exception Exception)> _childScopes;

public ScopeGroupHandler(CredentialPipeline pipeline, string groupName)
{
_pipeline = pipeline;
_groupName = groupName;
}

public DiagnosticScope CreateScope(string name)
{
if (IsGroup(name))
{
_pipeline._groupScopeHandler = this;
return _pipeline.Diagnostics.CreateScope(name);
}

_childScopes ??= new Dictionary<string, (DateTime startDateTime, Exception exception)>();
_childScopes[name] = default;
return default;
}

public void Start(string name, in DiagnosticScope scope)
{
if (IsGroup(name))
{
scope.Start();
}
else
{
_childScopes[name] = (DateTime.UtcNow, default);
}
}

public void Dispose(string name, in DiagnosticScope scope)
{
if (!IsGroup(name))
{
return;
}

if (_childScopes != null)
{
var succeededScope = _childScopes.LastOrDefault(kvp => kvp.Value.Exception == default);
if (succeededScope.Key != default)
{
SucceedChildScope(succeededScope.Key, succeededScope.Value.StartDateTime);
}
}

scope.Dispose();
_pipeline._groupScopeHandler = default;
}

public void Fail(string name, in DiagnosticScope scope, Exception exception)
{
if (_childScopes == default)
{
scope.Failed(exception);
return;
}

if (IsGroup(name))
{
if (exception is OperationCanceledException)
{
var canceledScope = _childScopes.Last(kvp => kvp.Value.Exception == exception);
FailChildScope(canceledScope.Key, canceledScope.Value.StartDateTime, canceledScope.Value.Exception);
}
else
{
foreach (var childScope in _childScopes)
{
FailChildScope(childScope.Key, childScope.Value.StartDateTime, childScope.Value.Exception);
}
}

scope.Failed(exception);
}
else
{
_childScopes[name] = (_childScopes[name].StartDateTime, exception);
}
}

private void SucceedChildScope(string name, DateTime dateTime)
{
using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope(name);
scope.SetStartTime(dateTime);
scope.Start();
}

private void FailChildScope(string name, DateTime dateTime, Exception exception)
{
using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope(name);
scope.SetStartTime(dateTime);
scope.Start();
scope.Failed(exception);
}

private bool IsGroup(string name) => string.Equals(name, _groupName, StringComparison.Ordinal);
}
}
}
Loading

0 comments on commit 80f0272

Please sign in to comment.