diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index 507d34b103..2421511e62 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Concurrent; using System.Linq; using System.Security; using System.Threading; @@ -15,6 +16,14 @@ namespace Microsoft.Data.SqlClient /// public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider { + /// + /// This is a static cache instance meant to hold instances of "PublicClientApplication" mapping to information available in PublicClientAppKey. + /// The purpose of this cache is to allow re-use of Access Tokens fetched for a user interactively or with any other mode + /// to avoid interactive authentication request every-time, within application scope making use of MSAL's userTokenCache. + /// + private static ConcurrentDictionary s_pcaMap + = new ConcurrentDictionary(); + private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient"; private static readonly string s_defaultScopeSuffix = "/.default"; private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name; private readonly SqlClientLogger _logger = new SqlClientLogger(); @@ -67,10 +76,10 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication) } #if NETSTANDARD - private Func parentActivityOrWindowFunc = null; + private Func _parentActivityOrWindowFunc = null; /// - public void SetParentActivityOrWindowFunc(Func parentActivityOrWindowFunc) => this.parentActivityOrWindowFunc = parentActivityOrWindowFunc; + public void SetParentActivityOrWindowFunc(Func parentActivityOrWindowFunc) => this._parentActivityOrWindowFunc = parentActivityOrWindowFunc; #endif #if NETFRAMEWORK @@ -108,51 +117,24 @@ public override Task AcquireTokenAsync(SqlAuthentication * * https://docs.microsoft.com/en-us/azure/active-directory/develop/scenario-desktop-app-registration#redirect-uris */ - string redirectURI = "https://login.microsoftonline.com/common/oauth2/nativeclient"; + string redirectUri = s_nativeClientRedirectUri; #if NETCOREAPP if (parameters.AuthenticationMethod != SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) { - redirectURI = "http://localhost"; - } -#endif - IPublicClientApplication app; - -#if NETSTANDARD - if (parentActivityOrWindowFunc != null) - { - app = PublicClientApplicationBuilder.Create(_applicationClientId) - .WithAuthority(parameters.Authority) - .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) - .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) - .WithRedirectUri(redirectURI) - .WithParentActivityOrWindow(parentActivityOrWindowFunc) - .Build(); + redirectUri = "http://localhost"; } #endif + PublicClientAppKey pcaKey = new PublicClientAppKey(parameters.Authority, redirectUri, _applicationClientId #if NETFRAMEWORK - if (_iWin32WindowFunc != null) - { - app = PublicClientApplicationBuilder.Create(_applicationClientId) - .WithAuthority(parameters.Authority) - .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) - .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) - .WithRedirectUri(redirectURI) - .WithParentActivityOrWindow(_iWin32WindowFunc) - .Build(); - } + , _iWin32WindowFunc #endif -#if !NETCOREAPP - else +#if NETSTANDARD + , _parentActivityOrWindowFunc #endif - { - app = PublicClientApplicationBuilder.Create(_applicationClientId) - .WithAuthority(parameters.Authority) - .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) - .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) - .WithRedirectUri(redirectURI) - .Build(); - } + ); + + IPublicClientApplication app = GetPublicClientAppInstance(pcaKey); if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated) { @@ -185,6 +167,7 @@ public override Task AcquireTokenAsync(SqlAuthentication else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) { + // Fetch available accounts from 'app' instance System.Collections.Generic.IEnumerable accounts = await app.GetAccountsAsync(); IAccount account; if (!string.IsNullOrEmpty(parameters.UserId)) @@ -200,17 +183,23 @@ public override Task AcquireTokenAsync(SqlAuthentication { try { + // If 'account' is available in 'app', we use the same to acquire token silently. + // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn); } catch (MsalUiRequiredException) { + // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, + // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), + // or the user needs to perform two factor authentication. result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn); } } else { + // If no existing 'account' is found, we request user to sign in interactively. result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn); } @@ -320,5 +309,118 @@ private class CustomWebUi : ICustomWebUi public Task AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken) => _acquireAuthorizationCodeAsyncCallback.Invoke(authorizationUri, redirectUri, cancellationToken); } + + private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey publicClientAppKey) + { + if (!s_pcaMap.TryGetValue(publicClientAppKey, out IPublicClientApplication clientApplicationInstance)) + { + clientApplicationInstance = CreateClientAppInstance(publicClientAppKey); + s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance); + } + return clientApplicationInstance; + } + + private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey) + { + IPublicClientApplication publicClientApplication; + +#if NETSTANDARD + if (_parentActivityOrWindowFunc != null) + { + publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId) + .WithAuthority(publicClientAppKey._authority) + .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) + .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) + .WithRedirectUri(publicClientAppKey._redirectUri) + .WithParentActivityOrWindow(_parentActivityOrWindowFunc) + .Build(); + } +#endif +#if NETFRAMEWORK + if (_iWin32WindowFunc != null) + { + publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId) + .WithAuthority(publicClientAppKey._authority) + .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) + .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) + .WithRedirectUri(publicClientAppKey._redirectUri) + .WithParentActivityOrWindow(_iWin32WindowFunc) + .Build(); + } +#endif +#if !NETCOREAPP + else +#endif + { + publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId) + .WithAuthority(publicClientAppKey._authority) + .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) + .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) + .WithRedirectUri(publicClientAppKey._redirectUri) + .Build(); + } + + return publicClientApplication; + } + + internal class PublicClientAppKey + { + public readonly string _authority; + public readonly string _redirectUri; + public readonly string _applicationClientId; +#if NETFRAMEWORK + public readonly Func _iWin32WindowFunc; +#endif +#if NETSTANDARD + public readonly Func _parentActivityOrWindowFunc; +#endif + + public PublicClientAppKey(string authority, string redirectUri, string applicationClientId +#if NETFRAMEWORK + , Func iWin32WindowFunc +#endif +#if NETSTANDARD + , Func parentActivityOrWindowFunc +#endif + ) + { + _authority = authority; + _redirectUri = redirectUri; + _applicationClientId = applicationClientId; +#if NETFRAMEWORK + _iWin32WindowFunc = iWin32WindowFunc; +#endif +#if NETSTANDARD + _parentActivityOrWindowFunc = parentActivityOrWindowFunc; +#endif + } + + public override bool Equals(object obj) + { + if (obj != null && obj is PublicClientAppKey pcaKey) + { + return (string.CompareOrdinal(_authority, pcaKey._authority) == 0 + && string.CompareOrdinal(_redirectUri, pcaKey._redirectUri) == 0 + && string.CompareOrdinal(_applicationClientId, pcaKey._applicationClientId) == 0 +#if NETFRAMEWORK + && pcaKey._iWin32WindowFunc == _iWin32WindowFunc +#endif +#if NETSTANDARD + && pcaKey._parentActivityOrWindowFunc == _parentActivityOrWindowFunc +#endif + ); + } + return false; + } + + public override int GetHashCode() => Tuple.Create(_authority, _redirectUri, _applicationClientId +#if NETFRAMEWORK + , _iWin32WindowFunc +#endif +#if NETSTANDARD + , _parentActivityOrWindowFunc +#endif + ).GetHashCode(); + } } }