diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs
index 507d34b103..2421511e62 100644
--- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs
@@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.
using System;
+using System.Collections.Concurrent;
using System.Linq;
using System.Security;
using System.Threading;
@@ -15,6 +16,14 @@ namespace Microsoft.Data.SqlClient
///
public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider
{
+ ///
+ /// This is a static cache instance meant to hold instances of "PublicClientApplication" mapping to information available in PublicClientAppKey.
+ /// The purpose of this cache is to allow re-use of Access Tokens fetched for a user interactively or with any other mode
+ /// to avoid interactive authentication request every-time, within application scope making use of MSAL's userTokenCache.
+ ///
+ private static ConcurrentDictionary s_pcaMap
+ = new ConcurrentDictionary();
+ private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
private static readonly string s_defaultScopeSuffix = "/.default";
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
private readonly SqlClientLogger _logger = new SqlClientLogger();
@@ -67,10 +76,10 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
}
#if NETSTANDARD
- private Func parentActivityOrWindowFunc = null;
+ private Func _parentActivityOrWindowFunc = null;
///
- public void SetParentActivityOrWindowFunc(Func parentActivityOrWindowFunc) => this.parentActivityOrWindowFunc = parentActivityOrWindowFunc;
+ public void SetParentActivityOrWindowFunc(Func parentActivityOrWindowFunc) => this._parentActivityOrWindowFunc = parentActivityOrWindowFunc;
#endif
#if NETFRAMEWORK
@@ -108,51 +117,24 @@ public override Task AcquireTokenAsync(SqlAuthentication
*
* https://docs.microsoft.com/en-us/azure/active-directory/develop/scenario-desktop-app-registration#redirect-uris
*/
- string redirectURI = "https://login.microsoftonline.com/common/oauth2/nativeclient";
+ string redirectUri = s_nativeClientRedirectUri;
#if NETCOREAPP
if (parameters.AuthenticationMethod != SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
{
- redirectURI = "http://localhost";
- }
-#endif
- IPublicClientApplication app;
-
-#if NETSTANDARD
- if (parentActivityOrWindowFunc != null)
- {
- app = PublicClientApplicationBuilder.Create(_applicationClientId)
- .WithAuthority(parameters.Authority)
- .WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
- .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
- .WithRedirectUri(redirectURI)
- .WithParentActivityOrWindow(parentActivityOrWindowFunc)
- .Build();
+ redirectUri = "http://localhost";
}
#endif
+ PublicClientAppKey pcaKey = new PublicClientAppKey(parameters.Authority, redirectUri, _applicationClientId
#if NETFRAMEWORK
- if (_iWin32WindowFunc != null)
- {
- app = PublicClientApplicationBuilder.Create(_applicationClientId)
- .WithAuthority(parameters.Authority)
- .WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
- .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
- .WithRedirectUri(redirectURI)
- .WithParentActivityOrWindow(_iWin32WindowFunc)
- .Build();
- }
+ , _iWin32WindowFunc
#endif
-#if !NETCOREAPP
- else
+#if NETSTANDARD
+ , _parentActivityOrWindowFunc
#endif
- {
- app = PublicClientApplicationBuilder.Create(_applicationClientId)
- .WithAuthority(parameters.Authority)
- .WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
- .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
- .WithRedirectUri(redirectURI)
- .Build();
- }
+ );
+
+ IPublicClientApplication app = GetPublicClientAppInstance(pcaKey);
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
@@ -185,6 +167,7 @@ public override Task AcquireTokenAsync(SqlAuthentication
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
{
+ // Fetch available accounts from 'app' instance
System.Collections.Generic.IEnumerable accounts = await app.GetAccountsAsync();
IAccount account;
if (!string.IsNullOrEmpty(parameters.UserId))
@@ -200,17 +183,23 @@ public override Task AcquireTokenAsync(SqlAuthentication
{
try
{
+ // If 'account' is available in 'app', we use the same to acquire token silently.
+ // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync();
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
catch (MsalUiRequiredException)
{
+ // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
+ // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
+ // or the user needs to perform two factor authentication.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
}
else
{
+ // If no existing 'account' is found, we request user to sign in interactively.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
@@ -320,5 +309,118 @@ private class CustomWebUi : ICustomWebUi
public Task AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken)
=> _acquireAuthorizationCodeAsyncCallback.Invoke(authorizationUri, redirectUri, cancellationToken);
}
+
+ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey publicClientAppKey)
+ {
+ if (!s_pcaMap.TryGetValue(publicClientAppKey, out IPublicClientApplication clientApplicationInstance))
+ {
+ clientApplicationInstance = CreateClientAppInstance(publicClientAppKey);
+ s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance);
+ }
+ return clientApplicationInstance;
+ }
+
+ private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
+ {
+ IPublicClientApplication publicClientApplication;
+
+#if NETSTANDARD
+ if (_parentActivityOrWindowFunc != null)
+ {
+ publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
+ .WithAuthority(publicClientAppKey._authority)
+ .WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
+ .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
+ .WithRedirectUri(publicClientAppKey._redirectUri)
+ .WithParentActivityOrWindow(_parentActivityOrWindowFunc)
+ .Build();
+ }
+#endif
+#if NETFRAMEWORK
+ if (_iWin32WindowFunc != null)
+ {
+ publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
+ .WithAuthority(publicClientAppKey._authority)
+ .WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
+ .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
+ .WithRedirectUri(publicClientAppKey._redirectUri)
+ .WithParentActivityOrWindow(_iWin32WindowFunc)
+ .Build();
+ }
+#endif
+#if !NETCOREAPP
+ else
+#endif
+ {
+ publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
+ .WithAuthority(publicClientAppKey._authority)
+ .WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
+ .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
+ .WithRedirectUri(publicClientAppKey._redirectUri)
+ .Build();
+ }
+
+ return publicClientApplication;
+ }
+
+ internal class PublicClientAppKey
+ {
+ public readonly string _authority;
+ public readonly string _redirectUri;
+ public readonly string _applicationClientId;
+#if NETFRAMEWORK
+ public readonly Func _iWin32WindowFunc;
+#endif
+#if NETSTANDARD
+ public readonly Func _parentActivityOrWindowFunc;
+#endif
+
+ public PublicClientAppKey(string authority, string redirectUri, string applicationClientId
+#if NETFRAMEWORK
+ , Func iWin32WindowFunc
+#endif
+#if NETSTANDARD
+ , Func parentActivityOrWindowFunc
+#endif
+ )
+ {
+ _authority = authority;
+ _redirectUri = redirectUri;
+ _applicationClientId = applicationClientId;
+#if NETFRAMEWORK
+ _iWin32WindowFunc = iWin32WindowFunc;
+#endif
+#if NETSTANDARD
+ _parentActivityOrWindowFunc = parentActivityOrWindowFunc;
+#endif
+ }
+
+ public override bool Equals(object obj)
+ {
+ if (obj != null && obj is PublicClientAppKey pcaKey)
+ {
+ return (string.CompareOrdinal(_authority, pcaKey._authority) == 0
+ && string.CompareOrdinal(_redirectUri, pcaKey._redirectUri) == 0
+ && string.CompareOrdinal(_applicationClientId, pcaKey._applicationClientId) == 0
+#if NETFRAMEWORK
+ && pcaKey._iWin32WindowFunc == _iWin32WindowFunc
+#endif
+#if NETSTANDARD
+ && pcaKey._parentActivityOrWindowFunc == _parentActivityOrWindowFunc
+#endif
+ );
+ }
+ return false;
+ }
+
+ public override int GetHashCode() => Tuple.Create(_authority, _redirectUri, _applicationClientId
+#if NETFRAMEWORK
+ , _iWin32WindowFunc
+#endif
+#if NETSTANDARD
+ , _parentActivityOrWindowFunc
+#endif
+ ).GetHashCode();
+ }
}
}