diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index 507d34b103..ac86212e0e 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Concurrent; using System.Linq; using System.Security; using System.Threading; @@ -15,6 +16,9 @@ namespace Microsoft.Data.SqlClient /// public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider { + private static ConcurrentDictionary s_pcaMap + = new ConcurrentDictionary(); + private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient"; private static readonly string s_defaultScopeSuffix = "/.default"; private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name; private readonly SqlClientLogger _logger = new SqlClientLogger(); @@ -67,10 +71,10 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication) } #if NETSTANDARD - private Func parentActivityOrWindowFunc = null; + private Func _parentActivityOrWindowFunc = null; /// - public void SetParentActivityOrWindowFunc(Func parentActivityOrWindowFunc) => this.parentActivityOrWindowFunc = parentActivityOrWindowFunc; + public void SetParentActivityOrWindowFunc(Func parentActivityOrWindowFunc) => this._parentActivityOrWindowFunc = parentActivityOrWindowFunc; #endif #if NETFRAMEWORK @@ -108,51 +112,24 @@ public override Task AcquireTokenAsync(SqlAuthentication * * https://docs.microsoft.com/en-us/azure/active-directory/develop/scenario-desktop-app-registration#redirect-uris */ - string redirectURI = "https://login.microsoftonline.com/common/oauth2/nativeclient"; + string redirectUri = s_nativeClientRedirectUri; #if NETCOREAPP if (parameters.AuthenticationMethod != SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) { - redirectURI = "http://localhost"; - } -#endif - IPublicClientApplication app; - -#if NETSTANDARD - if (parentActivityOrWindowFunc != null) - { - app = PublicClientApplicationBuilder.Create(_applicationClientId) - .WithAuthority(parameters.Authority) - .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) - .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) - .WithRedirectUri(redirectURI) - .WithParentActivityOrWindow(parentActivityOrWindowFunc) - .Build(); + redirectUri = "http://localhost"; } #endif + PublicClientAppKey pcaKey = new PublicClientAppKey(parameters.Authority, redirectUri, _applicationClientId #if NETFRAMEWORK - if (_iWin32WindowFunc != null) - { - app = PublicClientApplicationBuilder.Create(_applicationClientId) - .WithAuthority(parameters.Authority) - .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) - .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) - .WithRedirectUri(redirectURI) - .WithParentActivityOrWindow(_iWin32WindowFunc) - .Build(); - } + , _iWin32WindowFunc #endif -#if !NETCOREAPP - else +#if NETSTANDARD + , _parentActivityOrWindowFunc #endif - { - app = PublicClientApplicationBuilder.Create(_applicationClientId) - .WithAuthority(parameters.Authority) - .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) - .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) - .WithRedirectUri(redirectURI) - .Build(); - } + ); + + IPublicClientApplication app = GetPublicClientAppInstance(pcaKey); if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated) { @@ -320,5 +297,168 @@ private class CustomWebUi : ICustomWebUi public Task AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken) => _acquireAuthorizationCodeAsyncCallback.Invoke(authorizationUri, redirectUri, cancellationToken); } + + private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey publicClientAppKey) + { + IPublicClientApplication clientApplicationInstance; + + if (s_pcaMap.ContainsKey(publicClientAppKey)) + { + s_pcaMap.TryGetValue(publicClientAppKey, out clientApplicationInstance); + } + else + { + clientApplicationInstance = CreateClientAppInstance(publicClientAppKey); + s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance); + } + return clientApplicationInstance; + } + + private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey) + { + IPublicClientApplication publicClientApplication; + +#if NETSTANDARD + if (_parentActivityOrWindowFunc != null) + { + publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId) + .WithAuthority(publicClientAppKey._authority) + .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) + .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) + .WithRedirectUri(publicClientAppKey._redirectUri) + .WithParentActivityOrWindow(_parentActivityOrWindowFunc) + .Build(); + } +#endif +#if NETFRAMEWORK + if (_iWin32WindowFunc != null) + { + publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId) + .WithAuthority(publicClientAppKey._authority) + .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) + .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) + .WithRedirectUri(publicClientAppKey._redirectUri) + .WithParentActivityOrWindow(_iWin32WindowFunc) + .Build(); + } +#endif +#if !NETCOREAPP + else +#endif + { + publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId) + .WithAuthority(publicClientAppKey._authority) + .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) + .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) + .WithRedirectUri(publicClientAppKey._redirectUri) + .Build(); + } + + return publicClientApplication; + } + + internal class PublicClientAppKey + { + public readonly string _authority; + public readonly string _redirectUri; + public readonly string _applicationClientId; +#if NETFRAMEWORK + public readonly Func _iWin32WindowFunc; +#endif +#if NETSTANDARD + public readonly Func _parentActivityOrWindowFunc; +#endif + private int _hashValue; + + public PublicClientAppKey(string authority, string redirectUri, string applicationClientId +#if NETFRAMEWORK + , Func iWin32WindowFunc +#endif +#if NETSTANDARD + , Func parentActivityOrWindowFunc +#endif + ) + { + _authority = authority; + _redirectUri = redirectUri; + _applicationClientId = applicationClientId; +#if NETFRAMEWORK + _iWin32WindowFunc = iWin32WindowFunc; +#endif +#if NETSTANDARD + _parentActivityOrWindowFunc = parentActivityOrWindowFunc; +#endif + } + + public override bool Equals(object obj) + { + if (obj == null) + { + return false; + } + + PublicClientAppKey pcaKey = obj as PublicClientAppKey; + return (string.CompareOrdinal(_authority, pcaKey._authority) == 0 + && string.CompareOrdinal(_redirectUri, pcaKey._redirectUri) == 0 + && string.CompareOrdinal(_applicationClientId, pcaKey._applicationClientId) == 0 +#if NETFRAMEWORK + && pcaKey._iWin32WindowFunc == _iWin32WindowFunc +#endif +#if NETSTANDARD + && pcaKey._parentActivityOrWindowFunc == _parentActivityOrWindowFunc +#endif + ); + } + + public override int GetHashCode() + { + return _hashValue; + } + + private void CalculateHashCode() + { + _hashValue = base.GetHashCode(); + + if (_authority != null) + { + unchecked + { + _hashValue = _hashValue * 17 + _authority.GetHashCode(); + } + } + if (_redirectUri != null) + { + unchecked + { + _hashValue = _hashValue * 17 + _redirectUri.GetHashCode(); + } + } + if (_applicationClientId != null) + { + unchecked + { + _hashValue = _hashValue * 17 + _applicationClientId.GetHashCode(); + } + } +#if NETFRAMEWORK + if (_iWin32WindowFunc != null) + { + unchecked + { + _hashValue = _hashValue * 17 + _iWin32WindowFunc.GetHashCode(); + } + } +#endif +#if NETSTANDARD + if (_parentActivityOrWindowFunc != null) + { + unchecked + { + _hashValue = _hashValue * 17 + _parentActivityOrWindowFunc.GetHashCode(); + } + } +#endif + } + } } }