diff --git a/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointProvider.cs b/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointProvider.cs index dc1d761ad..1728b143d 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointProvider.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointProvider.cs @@ -95,9 +95,9 @@ public string GetServerEndpoint(string hubName) public IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId) { - if (_accessKey is AadAccessKey aadAccessKey) + if (_accessKey is AccessKeyForMicrosoftEntra key) { - return new AadTokenProvider(aadAccessKey); + return new MicrosoftEntraTokenProvider(key); } else if (_accessKey is not null) { diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/AadTokenProvider.cs b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntraTokenProvider.cs similarity index 54% rename from src/Microsoft.Azure.SignalR.Common/Auth/AadTokenProvider.cs rename to src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntraTokenProvider.cs index 68e42c44a..52f799852 100644 --- a/src/Microsoft.Azure.SignalR.Common/Auth/AadTokenProvider.cs +++ b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntraTokenProvider.cs @@ -6,15 +6,15 @@ namespace Microsoft.Azure.SignalR { - internal class AadTokenProvider : IAccessTokenProvider + internal class MicrosoftEntraTokenProvider : IAccessTokenProvider { - private readonly AadAccessKey _accessKey; + private readonly AccessKeyForMicrosoftEntra _accessKey; - public AadTokenProvider(AadAccessKey accessKey) + public MicrosoftEntraTokenProvider(AccessKeyForMicrosoftEntra accessKey) { _accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey)); } - public Task ProvideAsync() => _accessKey.GenerateAadTokenAsync(); + public Task ProvideAsync() => _accessKey.GetMicrosoftEntraTokenAsync(); } } diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeyForMicrosoftEntra.cs similarity index 73% rename from src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs rename to src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeyForMicrosoftEntra.cs index 411c39a73..16d8cd5f1 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeyForMicrosoftEntra.cs @@ -17,27 +17,27 @@ namespace Microsoft.Azure.SignalR { - internal class AadAccessKey : AccessKey + internal class AccessKeyForMicrosoftEntra : AccessKey { - internal const int AuthorizeIntervalInMinute = 55; + internal const int GetAccessKeyIntervalInMinute = 55; - internal const int AuthorizeMaxRetryTimes = 3; + internal const int GetAccessKeyMaxRetryTimes = 3; - internal const int AuthorizeRetryIntervalInSec = 3; + internal const int GetAccessKeyRetryIntervalInSec = 3; - internal const int GetTokenMaxRetryTimes = 3; + internal const int GetMicrosoftEntraTokenMaxRetryTimes = 3; - internal static readonly TimeSpan AuthorizeTimeout = TimeSpan.FromSeconds(100); + internal static readonly TimeSpan GetAccessKeyTimeout = TimeSpan.FromSeconds(100); private const string DefaultScope = "https://signalr.azure.com/.default"; - private static readonly TimeSpan AuthorizeInterval = TimeSpan.FromMinutes(AuthorizeIntervalInMinute); - private static readonly TokenRequestContext DefaultRequestContext = new TokenRequestContext(new string[] { DefaultScope }); - private static readonly TimeSpan AuthorizeIntervalWhenFailed = TimeSpan.FromMinutes(5); + private static readonly TimeSpan GetAccessKeyInterval = TimeSpan.FromMinutes(GetAccessKeyIntervalInMinute); + + private static readonly TimeSpan GetAccessKeyIntervalWhenUnauthorized = TimeSpan.FromMinutes(5); - private static readonly TimeSpan AuthorizeRetryInterval = TimeSpan.FromSeconds(AuthorizeRetryIntervalInSec); + private static readonly TimeSpan GetAccessKeyRetryInterval = TimeSpan.FromSeconds(GetAccessKeyRetryIntervalInSec); private readonly TaskCompletionSource _initializedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); @@ -58,23 +58,23 @@ private set public TokenCredential TokenCredential { get; } - internal string AuthorizeUrl { get; } + internal string GetAccessKeyUrl { get; } - internal bool HasExpired => DateTime.UtcNow - _lastUpdatedTime > TimeSpan.FromMinutes(AuthorizeIntervalInMinute * 2); + internal bool HasExpired => DateTime.UtcNow - _lastUpdatedTime > TimeSpan.FromMinutes(GetAccessKeyIntervalInMinute * 2); private Task InitializedTask => _initializedTcs.Task; - public AadAccessKey(Uri endpoint, TokenCredential credential, Uri serverEndpoint = null) : base(endpoint) + public AccessKeyForMicrosoftEntra(Uri endpoint, TokenCredential credential, Uri serverEndpoint = null) : base(endpoint) { var authorizeUri = (serverEndpoint ?? endpoint).Append("/api/v1/auth/accessKey"); - AuthorizeUrl = authorizeUri.AbsoluteUri; + GetAccessKeyUrl = authorizeUri.AbsoluteUri; TokenCredential = credential; } - public virtual async Task GenerateAadTokenAsync(CancellationToken ctoken = default) + public virtual async Task GetMicrosoftEntraTokenAsync(CancellationToken ctoken = default) { Exception latest = null; - for (var i = 0; i < GetTokenMaxRetryTimes; i++) + for (var i = 0; i < GetMicrosoftEntraTokenMaxRetryTimes; i++) { try { @@ -125,28 +125,24 @@ internal void UpdateAccessKey(string kid, string accessKey) internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default) { var delta = DateTime.UtcNow - _lastUpdatedTime; - if (Authorized && delta < AuthorizeInterval) + if (Authorized && delta < GetAccessKeyInterval) { return; } - else if (!Authorized && delta < AuthorizeIntervalWhenFailed) + else if (!Authorized && delta < GetAccessKeyIntervalWhenUnauthorized) { return; } - await AuthorizeWithRetryAsync(ctoken); - } - private async Task AuthorizeWithRetryAsync(CancellationToken ctoken = default) - { Exception latest = null; - for (var i = 0; i < AuthorizeMaxRetryTimes; i++) + for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++) { - var source = new CancellationTokenSource(AuthorizeTimeout); + var source = new CancellationTokenSource(GetAccessKeyTimeout); var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(source.Token, ctoken); try { - var token = await GenerateAadTokenAsync(linkedSource.Token); - await AuthorizeWithTokenAsync(token, linkedSource.Token); + var token = await GetMicrosoftEntraTokenAsync(linkedSource.Token); + await GetAccessKeyInternalAsync(token, linkedSource.Token); return; } catch (OperationCanceledException e) @@ -159,7 +155,7 @@ private async Task AuthorizeWithRetryAsync(CancellationToken ctoken = default) latest = e; try { - await Task.Delay(AuthorizeRetryInterval, ctoken); + await Task.Delay(GetAccessKeyRetryInterval, ctoken); } catch (OperationCanceledException) { @@ -172,9 +168,9 @@ private async Task AuthorizeWithRetryAsync(CancellationToken ctoken = default) throw latest; } - private async Task AuthorizeWithTokenAsync(string accessToken, CancellationToken ctoken = default) + private async Task GetAccessKeyInternalAsync(string accessToken, CancellationToken ctoken = default) { - var api = new RestApiEndpoint(AuthorizeUrl, accessToken); + var api = new RestApiEndpoint(GetAccessKeyUrl, accessToken); await new RestClient().SendAsync( api, diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.Log.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.Log.cs index 8e1028b99..edfc930bb 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.Log.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.Log.cs @@ -11,7 +11,7 @@ internal partial class AccessKeySynchronizer private static class Log { private static readonly Action _failedAuthorize = - LoggerMessage.Define(LogLevel.Warning, new EventId(2, "FailedAuthorizeAccessKey"), "Failed in authorizing AccessKey for '{endpoint}', will retry in " + AadAccessKey.AuthorizeRetryIntervalInSec + " seconds"); + LoggerMessage.Define(LogLevel.Warning, new EventId(2, "FailedAuthorizeAccessKey"), "Failed in authorizing AccessKey for '{endpoint}', will retry in " + SignalR.AccessKeyForMicrosoftEntra.GetAccessKeyRetryIntervalInSec + " seconds"); private static readonly Action _succeedAuthorize = LoggerMessage.Define(LogLevel.Information, new EventId(3, "SucceedAuthorizeAccessKey"), "Succeed in authorizing AccessKey for '{endpoint}'"); diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs index 8bedc821e..93d71370b 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs @@ -19,19 +19,16 @@ internal partial class AccessKeySynchronizer : IAccessKeySynchronizer, IDisposab private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(1)); - public AccessKeySynchronizer( - ILoggerFactory loggerFactory - ) : this(loggerFactory, true) + internal IEnumerable AccessKeyForMicrosoftEntraList => _endpoints.Select(e => e.Key.AccessKey).OfType(); + + public AccessKeySynchronizer(ILoggerFactory loggerFactory) : this(loggerFactory, true) { } /// - /// For test only. + /// Test only. /// - internal AccessKeySynchronizer( - ILoggerFactory loggerFactory, - bool start - ) + internal AccessKeySynchronizer(ILoggerFactory loggerFactory, bool start) { if (start) { @@ -42,9 +39,9 @@ bool start public void AddServiceEndpoint(ServiceEndpoint endpoint) { - if (endpoint.AccessKey is AadAccessKey aadKey) + if (endpoint.AccessKey is AccessKeyForMicrosoftEntra key) { - _ = UpdateAccessKeyAsync(aadKey); + _ = UpdateAccessKeyAsync(key); } _endpoints.TryAdd(endpoint, null); } @@ -64,8 +61,6 @@ public void UpdateServiceEndpoints(IEnumerable endpoints) internal int ServiceEndpointsCount() => _endpoints.Count; - internal IEnumerable FilterAadAccessKeys() => _endpoints.Select(e => e.Key.AccessKey).OfType(); - private async Task UpdateAccessKeyAsync() { using (_timer) @@ -74,7 +69,7 @@ private async Task UpdateAccessKeyAsync() while (await _timer) { - foreach (var key in FilterAadAccessKeys()) + foreach (var key in AccessKeyForMicrosoftEntraList) { _ = UpdateAccessKeyAsync(key); } @@ -82,17 +77,17 @@ private async Task UpdateAccessKeyAsync() } } - private async Task UpdateAccessKeyAsync(AadAccessKey key) + private async Task UpdateAccessKeyAsync(AccessKeyForMicrosoftEntra key) { - var logger = _factory.CreateLogger(); + var logger = _factory.CreateLogger(); try { await key.UpdateAccessKeyAsync(); - Log.SucceedToAuthorizeAccessKey(logger, key.AuthorizeUrl); + Log.SucceedToAuthorizeAccessKey(logger, key.GetAccessKeyUrl); } catch (Exception e) { - Log.FailedToAuthorizeAccessKey(logger, key.AuthorizeUrl, e); + Log.FailedToAuthorizeAccessKey(logger, key.GetAccessKeyUrl, e); } } diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs index d9588ecb2..3d8b18a10 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs @@ -86,7 +86,7 @@ internal AccessKey AccessKey { lock (_lock) { - _accessKey ??= new AadAccessKey(_serviceEndpoint, _tokenCredential, ServerEndpoint); + _accessKey ??= new AccessKeyForMicrosoftEntra(_serviceEndpoint, _tokenCredential, ServerEndpoint); } } return _accessKey; diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs index 476f22025..e8068059b 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs @@ -162,10 +162,10 @@ public async Task StartAsync(string target = null) TimerAwaitable syncTimer = null; try { - if (HubEndpoint != null && HubEndpoint.AccessKey is AadAccessKey aadKey) + if (HubEndpoint != null && HubEndpoint.AccessKey is AccessKeyForMicrosoftEntra key) { syncTimer = new TimerAwaitable(TimeSpan.Zero, DefaultSyncAzureIdentityInterval); - _ = UpdateAzureIdentityAsync(aadKey, syncTimer); + _ = UpdateAzureIdentityAsync(key, syncTimer); } await ProcessIncomingAsync(connection); } @@ -330,7 +330,7 @@ private Task OnEventMessageAsync(ServiceEventMessage message) private Task OnAccessKeyMessageAsync(AccessKeyResponseMessage keyMessage) { - if (HubEndpoint.AccessKey is AadAccessKey key) + if (HubEndpoint.AccessKey is AccessKeyForMicrosoftEntra key) { if (string.IsNullOrEmpty(keyMessage.ErrorType)) { @@ -474,7 +474,7 @@ private async Task ReceiveHandshakeResponseAsync(PipeReader input, Cancell } } - private async Task UpdateAzureIdentityAsync(AadAccessKey key, TimerAwaitable timer) + private async Task UpdateAzureIdentityAsync(AccessKeyForMicrosoftEntra key, TimerAwaitable timer) { using (timer) { @@ -486,12 +486,12 @@ private async Task UpdateAzureIdentityAsync(AadAccessKey key, TimerAwaitable tim } } - private async Task SendAccessKeyRequestMessageAsync(AadAccessKey key) + private async Task SendAccessKeyRequestMessageAsync(AccessKeyForMicrosoftEntra key) { try { - var source = new CancellationTokenSource(AadAccessKey.AuthorizeTimeout); - var token = await key.GenerateAadTokenAsync(source.Token); + var source = new CancellationTokenSource(AccessKeyForMicrosoftEntra.GetAccessKeyTimeout); + var token = await key.GetMicrosoftEntraTokenAsync(source.Token); var message = new AccessKeyRequestMessage(token); await SafeWriteAsync(message); } diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs index c9c75f32c..d0bf24b8e 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs @@ -169,11 +169,11 @@ private static AccessKey BuildAzureADAccessKey(Uri uri, Uri serverEndpointUri, D { if (dict.TryGetValue(ClientSecretProperty, out var clientSecret)) { - return new AadAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri); + return new AccessKeyForMicrosoftEntra(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri); } else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath)) { - return new AadAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri); + return new AccessKeyForMicrosoftEntra(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri); } else { @@ -182,12 +182,12 @@ private static AccessKey BuildAzureADAccessKey(Uri uri, Uri serverEndpointUri, D } else { - return new AadAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri); + return new AccessKeyForMicrosoftEntra(uri, new ManagedIdentityCredential(clientId), serverEndpointUri); } } else { - return new AadAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri); + return new AccessKeyForMicrosoftEntra(uri, new ManagedIdentityCredential(), serverEndpointUri); } } @@ -200,7 +200,7 @@ private static AccessKey BuildAccessKey(Uri uri, Dictionary dict private static AccessKey BuildAzureAccessKey(Uri uri, Uri serverEndpointUri, Dictionary dict) { - return new AadAccessKey(uri, new DefaultAzureCredential(), serverEndpointUri); + return new AccessKeyForMicrosoftEntra(uri, new DefaultAzureCredential(), serverEndpointUri); } private static AccessKey BuildAzureAppAccessKey(Uri uri, Uri serverEndpointUri, Dictionary dict) @@ -217,11 +217,11 @@ private static AccessKey BuildAzureAppAccessKey(Uri uri, Uri serverEndpointUri, if (dict.TryGetValue(ClientSecretProperty, out var clientSecret)) { - return new AadAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri); + return new AccessKeyForMicrosoftEntra(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri); } else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath)) { - return new AadAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri); + return new AccessKeyForMicrosoftEntra(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri); } throw new ArgumentException(MissingClientSecretProperty, ClientSecretProperty); } @@ -229,8 +229,8 @@ private static AccessKey BuildAzureAppAccessKey(Uri uri, Uri serverEndpointUri, private static AccessKey BuildAzureMsiAccessKey(Uri uri, Uri serverEndpointUri, Dictionary dict) { return dict.TryGetValue(ClientIdProperty, out var clientId) - ? new AadAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri) - : new AadAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri); + ? new AccessKeyForMicrosoftEntra(uri, new ManagedIdentityCredential(clientId), serverEndpointUri) + : new AccessKeyForMicrosoftEntra(uri, new ManagedIdentityCredential(), serverEndpointUri); } private static Dictionary ToDictionary(string connectionString) diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiAccessTokenGenerator.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiAccessTokenGenerator.cs index 29a411a08..7b90eba94 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiAccessTokenGenerator.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/RestApiAccessTokenGenerator.cs @@ -27,9 +27,9 @@ public RestApiAccessTokenGenerator(AccessKey accessKey, string serverName = null public Task Generate(string audience, TimeSpan? lifetime = null) { - if (_accessKey is AadAccessKey key) + if (_accessKey is AccessKeyForMicrosoftEntra key) { - return key.GenerateAadTokenAsync(); + return key.GetMicrosoftEntraTokenAsync(); } return _accessKey.GenerateAccessTokenAsync( diff --git a/src/Microsoft.Azure.SignalR/EndpointProvider/ServiceEndpointProvider.cs b/src/Microsoft.Azure.SignalR/EndpointProvider/ServiceEndpointProvider.cs index cc0c574d7..9b3b18735 100644 --- a/src/Microsoft.Azure.SignalR/EndpointProvider/ServiceEndpointProvider.cs +++ b/src/Microsoft.Azure.SignalR/EndpointProvider/ServiceEndpointProvider.cs @@ -61,9 +61,9 @@ public string GetClientEndpoint(string hubName, string originalPath, string quer public IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId) { - if (_accessKey is AadAccessKey aadAccessKey) + if (_accessKey is AccessKeyForMicrosoftEntra key) { - return new AadTokenProvider(aadAccessKey); + return new MicrosoftEntraTokenProvider(key); } else if (_accessKey is not null) { diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AadAccessKeyTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeyForMicrosoftEntraTests.cs similarity index 83% rename from test/Microsoft.Azure.SignalR.Common.Tests/Auth/AadAccessKeyTests.cs rename to test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeyForMicrosoftEntraTests.cs index 12a79e553..fb1a6a4f4 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AadAccessKeyTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeyForMicrosoftEntraTests.cs @@ -11,7 +11,7 @@ namespace Microsoft.Azure.SignalR.Common.Tests.Auth { [Collection("Auth")] - public class AadAccessKeyTests + public class AccessKeyForMicrosoftEntraTests { private const string SigningKey = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; @@ -21,8 +21,8 @@ public class AadAccessKeyTests [InlineData("https://a.bc:443", "https://a.bc/api/v1/auth/accessKey")] public void TestConstructor(string endpoint, string expectedAuthorizeUrl) { - var key = new AadAccessKey(new Uri(endpoint), new DefaultAzureCredential()); - Assert.Equal(expectedAuthorizeUrl, key.AuthorizeUrl); + var key = new AccessKeyForMicrosoftEntra(new Uri(endpoint), new DefaultAzureCredential()); + Assert.Equal(expectedAuthorizeUrl, key.GetAccessKeyUrl); } [Fact] @@ -33,7 +33,7 @@ public async Task TestUpdateAccessKey() It.IsAny(), It.IsAny())) .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); - var key = new AadAccessKey(new Uri("http://localhost"), mockCredential.Object); + var key = new AccessKeyForMicrosoftEntra(new Uri("http://localhost"), mockCredential.Object); var audience = "http://localhost/chat"; var claims = Array.Empty(); @@ -66,16 +66,16 @@ public async Task TestUpdateAccessKeyShouldSkip(bool isAuthorized, int timeElaps It.IsAny(), It.IsAny())) .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); - var key = new AadAccessKey(new Uri("http://localhost"), mockCredential.Object); - var isAuthorizedField = typeof(AadAccessKey).GetField("_isAuthorized", BindingFlags.NonPublic | BindingFlags.Instance); + var key = new AccessKeyForMicrosoftEntra(new Uri("http://localhost"), mockCredential.Object); + var isAuthorizedField = typeof(AccessKeyForMicrosoftEntra).GetField("_isAuthorized", BindingFlags.NonPublic | BindingFlags.Instance); isAuthorizedField.SetValue(key, isAuthorized); Assert.Equal(isAuthorized, (bool)isAuthorizedField.GetValue(key)); var lastUpdatedTime = DateTime.UtcNow - TimeSpan.FromMinutes(timeElapsed); - var lastUpdatedTimeField = typeof(AadAccessKey).GetField("_lastUpdatedTime", BindingFlags.NonPublic | BindingFlags.Instance); + var lastUpdatedTimeField = typeof(AccessKeyForMicrosoftEntra).GetField("_lastUpdatedTime", BindingFlags.NonPublic | BindingFlags.Instance); lastUpdatedTimeField.SetValue(key, lastUpdatedTime); - var initializedTcsField = typeof(AadAccessKey).GetField("_initializedTcs", BindingFlags.NonPublic | BindingFlags.Instance); + var initializedTcsField = typeof(AccessKeyForMicrosoftEntra).GetField("_initializedTcs", BindingFlags.NonPublic | BindingFlags.Instance); var initializedTcs = (TaskCompletionSource)initializedTcsField.GetValue(key); var source = new CancellationTokenSource(TimeSpan.FromSeconds(1)); @@ -104,7 +104,7 @@ public async Task TestInitializeFailed() It.IsAny(), It.IsAny())) .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); - var key = new AadAccessKey(new Uri("http://localhost"), mockCredential.Object); + var key = new AccessKeyForMicrosoftEntra(new Uri("http://localhost"), mockCredential.Object); var audience = "http://localhost/chat"; var claims = Array.Empty(); diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs index c42e70f4d..1466ee0b9 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs @@ -3,13 +3,10 @@ using System; using System.Collections; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Security.Claims; -using System.Text; using Azure.Identity; -using Microsoft.IdentityModel.Tokens; using Xunit; @@ -19,7 +16,9 @@ namespace Microsoft.Azure.SignalR.Common.Tests.Auth public class AuthUtilityTests { private const string Audience = "https://localhost/aspnetclient?hub=testhub"; + private const string SigningKey = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + private static readonly TimeSpan DefaultLifetime = TimeSpan.FromHours(1); [Fact] @@ -42,7 +41,7 @@ public class CachingTestData : IEnumerable public IEnumerator GetEnumerator() { yield return new object[] { new AccessKey("http://localhost:443", SigningKey), true }; - var key = new AadAccessKey(new Uri("http://localhost"), new DefaultAzureCredential()); + var key = new AccessKeyForMicrosoftEntra(new Uri("http://localhost"), new DefaultAzureCredential()); key.UpdateAccessKey("foo", SigningKey); yield return new object[] { key, false }; } diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs index 4bb555f13..01a81c8b6 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs @@ -101,7 +101,7 @@ public void TestAzureApplication(string connectionString) { var r = ConnectionStringParser.Parse(connectionString); - var key = Assert.IsType(r.AccessKey); + var key = Assert.IsType(r.AccessKey); Assert.IsType(key.TokenCredential); Assert.Same(r.Endpoint, r.AccessKey.Endpoint); Assert.Null(r.Version); @@ -148,7 +148,7 @@ internal void TestDefaultAzureCredential(string expectedEndpoint, string connect var r = ConnectionStringParser.Parse(connectionString); Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/')); - var key = Assert.IsType(r.AccessKey); + var key = Assert.IsType(r.AccessKey); Assert.IsType(key.TokenCredential); Assert.Same(r.Endpoint, r.AccessKey.Endpoint); } @@ -165,7 +165,7 @@ internal void TestManagedIdentity(string expectedEndpoint, string connectionStri var r = ConnectionStringParser.Parse(connectionString); Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/')); - var key = Assert.IsType(r.AccessKey); + var key = Assert.IsType(r.AccessKey); Assert.IsType(key.TokenCredential); Assert.Same(r.Endpoint, r.AccessKey.Endpoint); Assert.Null(r.ClientEndpoint); @@ -180,8 +180,8 @@ internal void TestManagedIdentity(string expectedEndpoint, string connectionStri internal void TestAzureADWithServerEndpoint(string connectionString, string expectedAuthorizeUrl) { var r = ConnectionStringParser.Parse(connectionString); - var key = Assert.IsType(r.AccessKey); - Assert.Equal(expectedAuthorizeUrl, key.AuthorizeUrl, StringComparer.OrdinalIgnoreCase); + var key = Assert.IsType(r.AccessKey); + Assert.Equal(expectedAuthorizeUrl, key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); } public class ClientEndpointTestData : IEnumerable diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AzureActiveDirectoryTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraApplicationTests.cs similarity index 79% rename from test/Microsoft.Azure.SignalR.Common.Tests/Auth/AzureActiveDirectoryTests.cs rename to test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraApplicationTests.cs index d91c07139..f00b96933 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AzureActiveDirectoryTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraApplicationTests.cs @@ -1,21 +1,18 @@ using System; using System.IdentityModel.Tokens.Jwt; using System.Threading.Tasks; - using Azure.Core; using Azure.Identity; - using Microsoft.IdentityModel.Logging; using Microsoft.IdentityModel.Protocols; using Microsoft.IdentityModel.Protocols.OpenIdConnect; using Microsoft.IdentityModel.Tokens; - using Xunit; namespace Microsoft.Azure.SignalR.Common.Tests.Auth { [Collection("Auth")] - public class AzureActiveDirectoryTests + public class MicrosoftEntraApplicationTests { private const string IssuerEndpoint = "https://sts.windows.net/"; @@ -25,21 +22,21 @@ public class AzureActiveDirectoryTests private static readonly string[] DefaultScopes = new string[] { "https://signalr.azure.com/.default" }; - [Fact(Skip = "Provide valid aad options")] + [Fact(Skip = "Provide valid Microsoft Entra application options")] public async Task TestAcquireAccessToken() { var options = new ClientSecretCredential(TestTenantId, TestClientId, TestClientSecret); - var key = new AadAccessKey(new Uri("https://localhost:8080"), options); - var token = await key.GenerateAadTokenAsync(); + var key = new AccessKeyForMicrosoftEntra(new Uri("https://localhost:8080"), options); + var token = await key.GetMicrosoftEntraTokenAsync(); Assert.NotNull(token); } - [Fact(Skip = "Provide valid aad options")] - public async Task TestGetAzureAdTokenAndAuthenticate() + [Fact(Skip = "Provide valid Microsoft Entra application options")] + public async Task TestGetMicrosoftEntraTokenAndAuthenticate() { var credential = new ClientSecretCredential(TestTenantId, TestClientId, TestClientSecret); - ConfigurationManager configManager = new ConfigurationManager( + var configManager = new ConfigurationManager( "https://login.microsoftonline.com/common/v2.0/.well-known/openid-configuration", new OpenIdConnectConfigurationRetriever() ); @@ -72,11 +69,11 @@ public async Task TestGetAzureAdTokenAndAuthenticate() Assert.NotNull(validToken); } - [Fact(Skip = "Provide valid aad options")] + [Fact(Skip = "Provide valid Microsoft Entra application options")] internal async Task TestAuthenticateAsync() { var options = new ClientSecretCredential(TestTenantId, TestClientId, TestClientSecret); - var key = new AadAccessKey(new Uri("https://localhost:8080"), options); + var key = new AccessKeyForMicrosoftEntra(new Uri("https://localhost:8080"), options); await key.UpdateAccessKeyAsync(); Assert.True(key.Authorized); @@ -84,4 +81,4 @@ internal async Task TestAuthenticateAsync() Assert.NotNull(key.Value); } } -} \ No newline at end of file +} diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Endpoints/AccessKeySynchronizerFacts.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Endpoints/AccessKeySynchronizerFacts.cs index 0370b5342..6e35d1e5a 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Endpoints/AccessKeySynchronizerFacts.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Endpoints/AccessKeySynchronizerFacts.cs @@ -46,16 +46,16 @@ public void FilterAadAccessKeysTest() var endpoint2 = new ServiceEndpoint($"Endpoint=http://endpoint2.net;AuthType=aad;ClientId=foo;ClientSecret=bar;TenantId={tenantId};Version=1.0"); synchronizer.UpdateServiceEndpoints(new List() { endpoint1 }); - Assert.Empty(synchronizer.FilterAadAccessKeys()); + Assert.Empty(synchronizer.AccessKeyForMicrosoftEntraList); synchronizer.UpdateServiceEndpoints(new List() { endpoint1, endpoint2 }); - Assert.Single(synchronizer.FilterAadAccessKeys()); + Assert.Single(synchronizer.AccessKeyForMicrosoftEntraList); synchronizer.UpdateServiceEndpoints(new List() { endpoint2 }); - Assert.Single(synchronizer.FilterAadAccessKeys()); + Assert.Single(synchronizer.AccessKeyForMicrosoftEntraList); synchronizer.UpdateServiceEndpoints(new List() { }); - Assert.Empty(synchronizer.FilterAadAccessKeys()); + Assert.Empty(synchronizer.AccessKeyForMicrosoftEntraList); } } } diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs b/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs index dd89383d1..37446112b 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/ServiceEndpointFacts.cs @@ -117,7 +117,7 @@ public void TestAzureADConstructor(string url, string expectedEndpoint, int port { var uri = new Uri(url); var serviceEndpoint = new ServiceEndpoint(uri, new DefaultAzureCredential()); - Assert.IsType(serviceEndpoint.AccessKey); + Assert.IsType(serviceEndpoint.AccessKey); Assert.Equal(expectedEndpoint, serviceEndpoint.Endpoint); Assert.Equal("", serviceEndpoint.Name); Assert.Equal(port, serviceEndpoint.AccessKey.Endpoint.Port); @@ -150,7 +150,7 @@ public void TestAzureADConstructorWithKey(string key, string name, EndpointType { var uri = new Uri("http://localhost"); var serviceEndpoint = new ServiceEndpoint(key, uri, new DefaultAzureCredential()); - Assert.IsType(serviceEndpoint.AccessKey); + Assert.IsType(serviceEndpoint.AccessKey); Assert.Equal(name, serviceEndpoint.Name); Assert.Equal(type, serviceEndpoint.EndpointType); TestCopyConstructor(serviceEndpoint); @@ -166,22 +166,22 @@ public void TestAzureADConstructorWithServerEndpoint() { ServerEndpoint = serverEndpoint1 }; - var key = Assert.IsType(endpoint.AccessKey); + var key = Assert.IsType(endpoint.AccessKey); Assert.Same(key, endpoint.AccessKey); - Assert.Equal("http://serverEndpoint:123/api/v1/auth/accessKey", key.AuthorizeUrl, StringComparer.OrdinalIgnoreCase); + Assert.Equal("http://serverEndpoint:123/api/v1/auth/accessKey", key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); endpoint = new ServiceEndpoint(new Uri(serviceEndpoint), new DefaultAzureCredential(), serverEndpoint: serverEndpoint2); - key = Assert.IsType(endpoint.AccessKey); + key = Assert.IsType(endpoint.AccessKey); Assert.Same(key, endpoint.AccessKey); - Assert.Equal("http://serverEndpoint:123/path/api/v1/auth/accessKey", key.AuthorizeUrl, StringComparer.OrdinalIgnoreCase); + Assert.Equal("http://serverEndpoint:123/path/api/v1/auth/accessKey", key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); endpoint = new ServiceEndpoint(new Uri(serviceEndpoint), new DefaultAzureCredential(), serverEndpoint: serverEndpoint1) { ServerEndpoint = serverEndpoint2 // property initialize should override constructor param. }; - key = Assert.IsType(endpoint.AccessKey); + key = Assert.IsType(endpoint.AccessKey); Assert.Same(key, endpoint.AccessKey); - Assert.Equal("http://serverEndpoint:123/path/api/v1/auth/accessKey", key.AuthorizeUrl, StringComparer.OrdinalIgnoreCase); + Assert.Equal("http://serverEndpoint:123/path/api/v1/auth/accessKey", key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); } [Theory] diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs index ca29456f5..efc0b2d6f 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs @@ -148,7 +148,7 @@ public async Task TestCloseConnectionMessage() [Theory] [InlineData(typeof(AccessKey))] - [InlineData(typeof(AadAccessKey))] + [InlineData(typeof(AccessKeyForMicrosoftEntra))] public async Task TestAccessKeyRequestMessage(Type keyType) { var endpoint = MockServiceEndpoint(keyType.Name); @@ -173,7 +173,7 @@ public async Task TestAccessKeyRequestMessage(Type keyType) [Theory] [InlineData(typeof(AccessKey))] - [InlineData(typeof(AadAccessKey))] + [InlineData(typeof(AccessKeyForMicrosoftEntra))] public async Task TestAccessKeyResponseMessage(Type keyType) { var endpoint = MockServiceEndpoint(keyType.Name); @@ -222,9 +222,9 @@ public async Task TestAccessKeyResponseMessageWithError(int minutesElapsed, int { var endpoint = new TestHubServiceEndpoint(endpoint: new TestServiceEndpoint(new DefaultAzureCredential())); - if (endpoint.AccessKey is AadAccessKey key) + if (endpoint.AccessKey is AccessKeyForMicrosoftEntra key) { - var field = typeof(AadAccessKey).GetField("_lastUpdatedTime", BindingFlags.NonPublic | BindingFlags.Instance); + var field = typeof(AccessKeyForMicrosoftEntra).GetField("_lastUpdatedTime", BindingFlags.NonPublic | BindingFlags.Instance); field.SetValue(key, DateTime.UtcNow - TimeSpan.FromMinutes(minutesElapsed)); } @@ -313,7 +313,7 @@ private ServiceEndpoint MockServiceEndpoint(string keyTypeName) case nameof(AccessKey): return new ServiceEndpoint(_keyConnectionString); - case nameof(AadAccessKey): + case nameof(AccessKeyForMicrosoftEntra): var endpoint = new ServiceEndpoint(_aadConnectionString); var p = typeof(ServiceEndpoint).GetProperty("AccessKey", BindingFlags.NonPublic | BindingFlags.Instance); p.SetValue(endpoint, new TestAadAccessKey()); @@ -324,7 +324,7 @@ private ServiceEndpoint MockServiceEndpoint(string keyTypeName) } } - private class TestAadAccessKey : AadAccessKey + private class TestAadAccessKey : AccessKeyForMicrosoftEntra { public string Token { get; } = Guid.NewGuid().ToString(); @@ -332,7 +332,7 @@ private class TestAadAccessKey : AadAccessKey { } - public override Task GenerateAadTokenAsync(CancellationToken ctoken = default) + public override Task GetMicrosoftEntraTokenAsync(CancellationToken ctoken = default) { return Task.FromResult(Token); }