Skip to content

Commit

Permalink
Fix #14408: [BUG] ChainedTokenCredential throws an exception when Vis…
Browse files Browse the repository at this point in the history
…ualStudioCodeCredential is present (#14882)

* Tests added

* Fix #14408: [BUG] ChainedTokenCredential throws an exception when VisualStudioCodeCredential is present
  • Loading branch information
AlexanderSher authored Sep 4, 2020
1 parent cb4f62a commit c8a4acb
Show file tree
Hide file tree
Showing 24 changed files with 3,685 additions and 196 deletions.
22 changes: 20 additions & 2 deletions sdk/identity/Azure.Identity/src/AuthenticationFailedException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,27 @@ public AuthenticationFailedException(string message, Exception innerException)
{
}

internal static AuthenticationFailedException CreateAggregateException(string message, IList<Exception> innerExceptions)
internal static AuthenticationFailedException CreateAggregateException(string message, IList<Exception> exceptions)
{
return new AuthenticationFailedException(message, new AggregateException("Multiple exceptions were encountered while attempting to authenticate.", innerExceptions.ToArray()));
// Build the credential unavailable message, this code is only reachable if all credentials throw AuthenticationFailedException
StringBuilder errorMsg = new StringBuilder(message);

bool allCredentialUnavailableException = true;
foreach (var exception in exceptions)
{
allCredentialUnavailableException &= exception is CredentialUnavailableException;
errorMsg.Append(Environment.NewLine).Append("- ").Append(exception.Message);
}

var innerException = exceptions.Count == 1
? exceptions[0]
: new AggregateException("Multiple exceptions were encountered while attempting to authenticate.", exceptions);

// If all credentials have thrown CredentialUnavailableException, throw CredentialUnavailableException,
// otherwise throw AuthenticationFailedException
return allCredentialUnavailableException
? new CredentialUnavailableException(errorMsg.ToString(), innerException)
: new AuthenticationFailedException(errorMsg.ToString(), innerException);
}
}
}
80 changes: 40 additions & 40 deletions sdk/identity/Azure.Identity/src/ChainedTokenCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
using Azure.Core;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core.Pipeline;

namespace Azure.Identity
{
Expand All @@ -22,6 +22,14 @@ public class ChainedTokenCredential : TokenCredential

private readonly TokenCredential[] _sources;

/// <summary>
/// Constructor for instrumenting in tests
/// </summary>
internal ChainedTokenCredential()
{
_sources = Array.Empty<TokenCredential>();
}

/// <summary>
/// Creates an instance with the specified <see cref="TokenCredential"/> sources.
/// </summary>
Expand Down Expand Up @@ -53,29 +61,7 @@ public ChainedTokenCredential(params TokenCredential[] sources)
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
/// <returns>The first <see cref="AccessToken"/> returned by the specified sources. Any credential which raises a <see cref="CredentialUnavailableException"/> will be skipped.</returns>
public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
{
List<Exception> exceptions = new List<Exception>();

for (int i = 0; i < _sources.Length; i++)
{
try
{
return _sources[i].GetToken(requestContext, cancellationToken);
}
catch (CredentialUnavailableException e)
{
exceptions.Add(e);
}
catch (Exception e) when (!(e is OperationCanceledException))
{
exceptions.Add(e);

throw AuthenticationFailedException.CreateAggregateException(AggregateCredentialFailedErrorMessage + e.Message, exceptions);
}
}

throw AuthenticationFailedException.CreateAggregateException(AggregateAllUnavailableErrorMessage, exceptions);
}
=> GetTokenImplAsync(false, requestContext, cancellationToken).EnsureCompleted();

/// <summary>
/// Sequentially calls <see cref="TokenCredential.GetToken"/> on all the specified sources, returning the first successfully obtained <see cref="AccessToken"/>. This method is called by Azure SDK clients. It isn't intended for use in application code.
Expand All @@ -84,28 +70,42 @@ public override AccessToken GetToken(TokenRequestContext requestContext, Cancell
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
/// <returns>The first <see cref="AccessToken"/> returned by the specified sources. Any credential which raises a <see cref="CredentialUnavailableException"/> will be skipped.</returns>
public override async ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
{
List<Exception> exceptions = new List<Exception>();
=> await GetTokenImplAsync(true, requestContext, cancellationToken).ConfigureAwait(false);

for (int i = 0; i < _sources.Length; i++)
private async ValueTask<AccessToken> GetTokenImplAsync(bool async, TokenRequestContext requestContext, CancellationToken cancellationToken)
{
var groupScopeHandler = new ScopeGroupHandler(default);
try
{
try
List<Exception> exceptions = new List<Exception>();
foreach (TokenCredential source in _sources)
{
return await _sources[i].GetTokenAsync(requestContext, cancellationToken).ConfigureAwait(false);
try
{
AccessToken token = async
? await source.GetTokenAsync(requestContext, cancellationToken).ConfigureAwait(false)
: source.GetToken(requestContext, cancellationToken);
groupScopeHandler.Dispose(default, default);
return token;
}
catch (AuthenticationFailedException e)
{
exceptions.Add(e);
}
catch (Exception e) when (!(e is OperationCanceledException))
{
exceptions.Add(e);
throw AuthenticationFailedException.CreateAggregateException(AggregateCredentialFailedErrorMessage + e.Message, exceptions);
}
}
catch (CredentialUnavailableException e)
{
exceptions.Add(e);
}
catch (Exception e) when (!(e is OperationCanceledException))
{
exceptions.Add(e);

throw AuthenticationFailedException.CreateAggregateException(AggregateCredentialFailedErrorMessage + e.Message, exceptions);
}
throw AuthenticationFailedException.CreateAggregateException(AggregateAllUnavailableErrorMessage, exceptions);
}
catch (Exception exception)
{
groupScopeHandler.Fail(default, default, exception);
throw;
}

throw AuthenticationFailedException.CreateAggregateException(AggregateAllUnavailableErrorMessage, exceptions);
}
}
}
4 changes: 2 additions & 2 deletions sdk/identity/Azure.Identity/src/CredentialDiagnosticScope.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ namespace Azure.Identity
private readonly TokenRequestContext _context;
private readonly IScopeHandler _scopeHandler;

public CredentialDiagnosticScope(string name, TokenRequestContext context, IScopeHandler scopeHandler)
public CredentialDiagnosticScope(ClientDiagnostics diagnostics, string name, TokenRequestContext context, IScopeHandler scopeHandler)
{
_name = name;
_scope = scopeHandler.CreateScope(name);
_scope = scopeHandler.CreateScope(diagnostics, name);
_context = context;
_scopeHandler = scopeHandler;
}
Expand Down
133 changes: 7 additions & 126 deletions sdk/identity/Azure.Identity/src/CredentialPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,18 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Azure.Core;
using Azure.Core.Diagnostics;
using Azure.Core.Pipeline;
using Microsoft.Identity.Client;

namespace Azure.Identity
{
internal class CredentialPipeline
internal class CredentialPipeline
{
private static readonly Lazy<CredentialPipeline> s_singleton = new Lazy<CredentialPipeline>(() => new CredentialPipeline(new TokenCredentialOptions()));

private readonly IScopeHandler _defaultScopeHandler;
private IScopeHandler _groupScopeHandler;
private static readonly IScopeHandler _defaultScopeHandler = new ScopeHandler();

private CredentialPipeline(TokenCredentialOptions options)
{
Expand All @@ -26,8 +22,6 @@ private CredentialPipeline(TokenCredentialOptions options)
HttpPipeline = HttpPipelineBuilder.Build(options, Array.Empty<HttpPipelinePolicy>(), Array.Empty<HttpPipelinePolicy>(), new CredentialResponseClassifier());

Diagnostics = new ClientDiagnostics(options);

_defaultScopeHandler = new ScopeHandler(Diagnostics);
}

public static CredentialPipeline GetInstance(TokenCredentialOptions options)
Expand All @@ -48,18 +42,18 @@ public IConfidentialClientApplication CreateMsalConfidentialClient(string tenant

public CredentialDiagnosticScope StartGetTokenScope(string fullyQualifiedMethod, TokenRequestContext context)
{
IScopeHandler scopeHandler = _groupScopeHandler ?? _defaultScopeHandler;
IScopeHandler scopeHandler = ScopeGroupHandler.Current ?? _defaultScopeHandler;

CredentialDiagnosticScope scope = new CredentialDiagnosticScope(fullyQualifiedMethod, context, scopeHandler);
CredentialDiagnosticScope scope = new CredentialDiagnosticScope(Diagnostics, fullyQualifiedMethod, context, scopeHandler);
scope.Start();
return scope;
}

public CredentialDiagnosticScope StartGetTokenScopeGroup(string fullyQualifiedMethod, TokenRequestContext context)
{
var scopeHandler = new ScopeGroupHandler(this, fullyQualifiedMethod);
var scopeHandler = new ScopeGroupHandler(fullyQualifiedMethod);

CredentialDiagnosticScope scope = new CredentialDiagnosticScope(fullyQualifiedMethod, context, scopeHandler);
CredentialDiagnosticScope scope = new CredentialDiagnosticScope(Diagnostics, fullyQualifiedMethod, context, scopeHandler);
scope.Start();
return scope;
}
Expand All @@ -74,123 +68,10 @@ public override bool IsRetriableResponse(HttpMessage message)

private class ScopeHandler : IScopeHandler
{
private readonly ClientDiagnostics _diagnostics;

public ScopeHandler(ClientDiagnostics diagnostics)
{
_diagnostics = diagnostics;
}

public DiagnosticScope CreateScope(string name) => _diagnostics.CreateScope(name);
public DiagnosticScope CreateScope(ClientDiagnostics diagnostics, string name) => diagnostics.CreateScope(name);
public void Start(string name, in DiagnosticScope scope) => scope.Start();
public void Dispose(string name, in DiagnosticScope scope) => scope.Dispose();
public void Fail(string name, in DiagnosticScope scope, Exception exception) => scope.Failed(exception);
}

private class ScopeGroupHandler : IScopeHandler
{
private readonly CredentialPipeline _pipeline;
private readonly string _groupName;
private Dictionary<string, (DateTime StartDateTime, Exception Exception)> _childScopes;

public ScopeGroupHandler(CredentialPipeline pipeline, string groupName)
{
_pipeline = pipeline;
_groupName = groupName;
}

public DiagnosticScope CreateScope(string name)
{
if (IsGroup(name))
{
_pipeline._groupScopeHandler = this;
return _pipeline.Diagnostics.CreateScope(name);
}

_childScopes ??= new Dictionary<string, (DateTime startDateTime, Exception exception)>();
_childScopes[name] = default;
return default;
}

public void Start(string name, in DiagnosticScope scope)
{
if (IsGroup(name))
{
scope.Start();
}
else
{
_childScopes[name] = (DateTime.UtcNow, default);
}
}

public void Dispose(string name, in DiagnosticScope scope)
{
if (!IsGroup(name))
{
return;
}

if (_childScopes != null)
{
var succeededScope = _childScopes.LastOrDefault(kvp => kvp.Value.Exception == default);
if (succeededScope.Key != default)
{
SucceedChildScope(succeededScope.Key, succeededScope.Value.StartDateTime);
}
}

scope.Dispose();
_pipeline._groupScopeHandler = default;
}

public void Fail(string name, in DiagnosticScope scope, Exception exception)
{
if (_childScopes == default)
{
scope.Failed(exception);
return;
}

if (IsGroup(name))
{
if (exception is OperationCanceledException)
{
var canceledScope = _childScopes.Last(kvp => kvp.Value.Exception == exception);
FailChildScope(canceledScope.Key, canceledScope.Value.StartDateTime, canceledScope.Value.Exception);
}
else
{
foreach (var childScope in _childScopes)
{
FailChildScope(childScope.Key, childScope.Value.StartDateTime, childScope.Value.Exception);
}
}

scope.Failed(exception);
}
else
{
_childScopes[name] = (_childScopes[name].StartDateTime, exception);
}
}

private void SucceedChildScope(string name, DateTime dateTime)
{
using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope(name);
scope.SetStartTime(dateTime);
scope.Start();
}

private void FailChildScope(string name, DateTime dateTime, Exception exception)
{
using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope(name);
scope.SetStartTime(dateTime);
scope.Start();
scope.Failed(exception);
}

private bool IsGroup(string name) => string.Equals(name, _groupName, StringComparison.Ordinal);
}
}
}
25 changes: 8 additions & 17 deletions sdk/identity/Azure.Identity/src/DefaultAzureCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace Azure.Identity
public class DefaultAzureCredential : TokenCredential
{
private const string DefaultExceptionMessage = "DefaultAzureCredential failed to retrieve a token from the included credentials.";
private const string UnhandledExceptionMessage = "DefaultAzureCredential authentication failed.";
private const string UnhandledExceptionMessage = "DefaultAzureCredential authentication failed due to an unhandled exception: ";
private static readonly TokenCredential[] s_defaultCredentialChain = GetDefaultAzureCredentialChain(new DefaultAzureCredentialFactory(null), new DefaultAzureCredentialOptions());

private readonly CredentialPipeline _pipeline;
Expand Down Expand Up @@ -143,7 +143,7 @@ private static async ValueTask<AccessToken> GetTokenFromCredentialAsync(TokenCre

private static async ValueTask<(AccessToken, TokenCredential)> GetTokenFromSourcesAsync(TokenCredential[] sources, TokenRequestContext requestContext, bool async, CancellationToken cancellationToken)
{
List<AuthenticationFailedException> exceptions = new List<AuthenticationFailedException>();
List<Exception> exceptions = new List<Exception>();

for (var i = 0; i < sources.Length && sources[i] != null; i++)
{
Expand All @@ -159,23 +159,14 @@ private static async ValueTask<AccessToken> GetTokenFromCredentialAsync(TokenCre
{
exceptions.Add(e);
}
catch (Exception e) when (!(e is OperationCanceledException))
{
exceptions.Add(e);
throw AuthenticationFailedException.CreateAggregateException(UnhandledExceptionMessage + e.Message, exceptions);
}
}

// Build the credential unavailable message, this code is only reachable if all credentials throw AuthenticationFailedException
StringBuilder errorMsg = new StringBuilder(DefaultExceptionMessage);

bool allCredentialUnavailableException = true;
foreach (AuthenticationFailedException ex in exceptions)
{
allCredentialUnavailableException &= ex is CredentialUnavailableException;
errorMsg.Append(Environment.NewLine).Append("- ").Append(ex.Message);
}

// If all credentials have thrown CredentialUnavailableException, throw CredentialUnavailableException,
// otherwise throw AuthenticationFailedException
throw allCredentialUnavailableException
? new CredentialUnavailableException(errorMsg.ToString())
: new AuthenticationFailedException(errorMsg.ToString());
throw AuthenticationFailedException.CreateAggregateException(DefaultExceptionMessage, exceptions);
}

private static TokenCredential[] GetDefaultAzureCredentialChain(DefaultAzureCredentialFactory factory, DefaultAzureCredentialOptions options)
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/Azure.Identity/src/IScopeHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Azure.Identity
{
internal interface IScopeHandler
{
DiagnosticScope CreateScope(string name);
DiagnosticScope CreateScope(ClientDiagnostics diagnostics, string name);
void Start(string name, in DiagnosticScope scope);
void Dispose(string name, in DiagnosticScope scope);
void Fail(string name, in DiagnosticScope scope, Exception exception);
Expand Down
Loading

0 comments on commit c8a4acb

Please sign in to comment.