Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rebrand AAD to Microsoft Entra #1941

Merged
merged 2 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ public string GetServerEndpoint(string hubName)

public IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId)
{
if (_accessKey is AadAccessKey aadAccessKey)
if (_accessKey is AccessKeyForMicrosoftEntra key)
{
return new AadTokenProvider(aadAccessKey);
return new MicrosoftEntraTokenProvider(key);
}
else if (_accessKey is not null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@

namespace Microsoft.Azure.SignalR
{
internal class AadTokenProvider : IAccessTokenProvider
internal class MicrosoftEntraTokenProvider : IAccessTokenProvider
{
private readonly AadAccessKey _accessKey;
private readonly AccessKeyForMicrosoftEntra _accessKey;

public AadTokenProvider(AadAccessKey accessKey)
public MicrosoftEntraTokenProvider(AccessKeyForMicrosoftEntra accessKey)
{
_accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey));
}

public Task<string> ProvideAsync() => _accessKey.GenerateAadTokenAsync();
public Task<string> ProvideAsync() => _accessKey.GetMicrosoftEntraTokenAsync();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,27 @@

namespace Microsoft.Azure.SignalR
{
internal class AadAccessKey : AccessKey
internal class AccessKeyForMicrosoftEntra : AccessKey
{
internal const int AuthorizeIntervalInMinute = 55;
internal const int GetAccessKeyIntervalInMinute = 55;

internal const int AuthorizeMaxRetryTimes = 3;
internal const int GetAccessKeyMaxRetryTimes = 3;

internal const int AuthorizeRetryIntervalInSec = 3;
internal const int GetAccessKeyRetryIntervalInSec = 3;

internal const int GetTokenMaxRetryTimes = 3;
internal const int GetMicrosoftEntraTokenMaxRetryTimes = 3;

internal static readonly TimeSpan AuthorizeTimeout = TimeSpan.FromSeconds(100);
internal static readonly TimeSpan GetAccessKeyTimeout = TimeSpan.FromSeconds(100);

private const string DefaultScope = "https://signalr.azure.com/.default";

private static readonly TimeSpan AuthorizeInterval = TimeSpan.FromMinutes(AuthorizeIntervalInMinute);

private static readonly TokenRequestContext DefaultRequestContext = new TokenRequestContext(new string[] { DefaultScope });

private static readonly TimeSpan AuthorizeIntervalWhenFailed = TimeSpan.FromMinutes(5);
private static readonly TimeSpan GetAccessKeyInterval = TimeSpan.FromMinutes(GetAccessKeyIntervalInMinute);

private static readonly TimeSpan GetAccessKeyIntervalWhenUnauthorized = TimeSpan.FromMinutes(5);

private static readonly TimeSpan AuthorizeRetryInterval = TimeSpan.FromSeconds(AuthorizeRetryIntervalInSec);
private static readonly TimeSpan GetAccessKeyRetryInterval = TimeSpan.FromSeconds(GetAccessKeyRetryIntervalInSec);

private readonly TaskCompletionSource<object> _initializedTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

Expand All @@ -58,23 +58,23 @@ private set

public TokenCredential TokenCredential { get; }

internal string AuthorizeUrl { get; }
internal string GetAccessKeyUrl { get; }

internal bool HasExpired => DateTime.UtcNow - _lastUpdatedTime > TimeSpan.FromMinutes(AuthorizeIntervalInMinute * 2);
internal bool HasExpired => DateTime.UtcNow - _lastUpdatedTime > TimeSpan.FromMinutes(GetAccessKeyIntervalInMinute * 2);

private Task<object> InitializedTask => _initializedTcs.Task;

public AadAccessKey(Uri endpoint, TokenCredential credential, Uri serverEndpoint = null) : base(endpoint)
public AccessKeyForMicrosoftEntra(Uri endpoint, TokenCredential credential, Uri serverEndpoint = null) : base(endpoint)
{
var authorizeUri = (serverEndpoint ?? endpoint).Append("/api/v1/auth/accessKey");
AuthorizeUrl = authorizeUri.AbsoluteUri;
GetAccessKeyUrl = authorizeUri.AbsoluteUri;
TokenCredential = credential;
}

public virtual async Task<string> GenerateAadTokenAsync(CancellationToken ctoken = default)
public virtual async Task<string> GetMicrosoftEntraTokenAsync(CancellationToken ctoken = default)
{
Exception latest = null;
for (var i = 0; i < GetTokenMaxRetryTimes; i++)
for (var i = 0; i < GetMicrosoftEntraTokenMaxRetryTimes; i++)
{
try
{
Expand Down Expand Up @@ -125,28 +125,24 @@ internal void UpdateAccessKey(string kid, string accessKey)
internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
{
var delta = DateTime.UtcNow - _lastUpdatedTime;
if (Authorized && delta < AuthorizeInterval)
if (Authorized && delta < GetAccessKeyInterval)
{
return;
}
else if (!Authorized && delta < AuthorizeIntervalWhenFailed)
else if (!Authorized && delta < GetAccessKeyIntervalWhenUnauthorized)
{
return;
}
await AuthorizeWithRetryAsync(ctoken);
}

private async Task AuthorizeWithRetryAsync(CancellationToken ctoken = default)
{
Exception latest = null;
for (var i = 0; i < AuthorizeMaxRetryTimes; i++)
for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++)
{
var source = new CancellationTokenSource(AuthorizeTimeout);
var source = new CancellationTokenSource(GetAccessKeyTimeout);
var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(source.Token, ctoken);
try
{
var token = await GenerateAadTokenAsync(linkedSource.Token);
await AuthorizeWithTokenAsync(token, linkedSource.Token);
var token = await GetMicrosoftEntraTokenAsync(linkedSource.Token);
await GetAccessKeyInternalAsync(token, linkedSource.Token);
return;
}
catch (OperationCanceledException e)
Expand All @@ -159,7 +155,7 @@ private async Task AuthorizeWithRetryAsync(CancellationToken ctoken = default)
latest = e;
try
{
await Task.Delay(AuthorizeRetryInterval, ctoken);
await Task.Delay(GetAccessKeyRetryInterval, ctoken);
}
catch (OperationCanceledException)
{
Expand All @@ -172,9 +168,9 @@ private async Task AuthorizeWithRetryAsync(CancellationToken ctoken = default)
throw latest;
}

private async Task AuthorizeWithTokenAsync(string accessToken, CancellationToken ctoken = default)
private async Task GetAccessKeyInternalAsync(string accessToken, CancellationToken ctoken = default)
{
var api = new RestApiEndpoint(AuthorizeUrl, accessToken);
var api = new RestApiEndpoint(GetAccessKeyUrl, accessToken);

await new RestClient().SendAsync(
api,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ internal partial class AccessKeySynchronizer
private static class Log
{
private static readonly Action<ILogger, string, Exception> _failedAuthorize =
LoggerMessage.Define<string>(LogLevel.Warning, new EventId(2, "FailedAuthorizeAccessKey"), "Failed in authorizing AccessKey for '{endpoint}', will retry in " + AadAccessKey.AuthorizeRetryIntervalInSec + " seconds");
LoggerMessage.Define<string>(LogLevel.Warning, new EventId(2, "FailedAuthorizeAccessKey"), "Failed in authorizing AccessKey for '{endpoint}', will retry in " + SignalR.AccessKeyForMicrosoftEntra.GetAccessKeyRetryIntervalInSec + " seconds");

private static readonly Action<ILogger, string, Exception> _succeedAuthorize =
LoggerMessage.Define<string>(LogLevel.Information, new EventId(3, "SucceedAuthorizeAccessKey"), "Succeed in authorizing AccessKey for '{endpoint}'");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,16 @@ internal partial class AccessKeySynchronizer : IAccessKeySynchronizer, IDisposab

private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(1));

public AccessKeySynchronizer(
ILoggerFactory loggerFactory
) : this(loggerFactory, true)
internal IEnumerable<AccessKeyForMicrosoftEntra> AccessKeyForMicrosoftEntraList => _endpoints.Select(e => e.Key.AccessKey).OfType<AccessKeyForMicrosoftEntra>();

public AccessKeySynchronizer(ILoggerFactory loggerFactory) : this(loggerFactory, true)
{
}

/// <summary>
/// For test only.
/// Test only.
/// </summary>
internal AccessKeySynchronizer(
ILoggerFactory loggerFactory,
bool start
)
internal AccessKeySynchronizer(ILoggerFactory loggerFactory, bool start)
{
if (start)
{
Expand All @@ -42,9 +39,9 @@ bool start

public void AddServiceEndpoint(ServiceEndpoint endpoint)
{
if (endpoint.AccessKey is AadAccessKey aadKey)
if (endpoint.AccessKey is AccessKeyForMicrosoftEntra key)
{
_ = UpdateAccessKeyAsync(aadKey);
_ = UpdateAccessKeyAsync(key);
}
_endpoints.TryAdd(endpoint, null);
}
Expand All @@ -64,8 +61,6 @@ public void UpdateServiceEndpoints(IEnumerable<ServiceEndpoint> endpoints)

internal int ServiceEndpointsCount() => _endpoints.Count;

internal IEnumerable<AadAccessKey> FilterAadAccessKeys() => _endpoints.Select(e => e.Key.AccessKey).OfType<AadAccessKey>();

private async Task UpdateAccessKeyAsync()
{
using (_timer)
Expand All @@ -74,25 +69,25 @@ private async Task UpdateAccessKeyAsync()

while (await _timer)
{
foreach (var key in FilterAadAccessKeys())
foreach (var key in AccessKeyForMicrosoftEntraList)
{
_ = UpdateAccessKeyAsync(key);
}
}
}
}

private async Task UpdateAccessKeyAsync(AadAccessKey key)
private async Task UpdateAccessKeyAsync(AccessKeyForMicrosoftEntra key)
{
var logger = _factory.CreateLogger<AadAccessKey>();
var logger = _factory.CreateLogger<AccessKeyForMicrosoftEntra>();
try
{
await key.UpdateAccessKeyAsync();
Log.SucceedToAuthorizeAccessKey(logger, key.AuthorizeUrl);
Log.SucceedToAuthorizeAccessKey(logger, key.GetAccessKeyUrl);
}
catch (Exception e)
{
Log.FailedToAuthorizeAccessKey(logger, key.AuthorizeUrl, e);
Log.FailedToAuthorizeAccessKey(logger, key.GetAccessKeyUrl, e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ internal AccessKey AccessKey
{
lock (_lock)
{
_accessKey ??= new AadAccessKey(_serviceEndpoint, _tokenCredential, ServerEndpoint);
_accessKey ??= new AccessKeyForMicrosoftEntra(_serviceEndpoint, _tokenCredential, ServerEndpoint);
}
}
return _accessKey;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,10 @@ public async Task StartAsync(string target = null)
TimerAwaitable syncTimer = null;
try
{
if (HubEndpoint != null && HubEndpoint.AccessKey is AadAccessKey aadKey)
if (HubEndpoint != null && HubEndpoint.AccessKey is AccessKeyForMicrosoftEntra key)
{
syncTimer = new TimerAwaitable(TimeSpan.Zero, DefaultSyncAzureIdentityInterval);
_ = UpdateAzureIdentityAsync(aadKey, syncTimer);
_ = UpdateAzureIdentityAsync(key, syncTimer);
}
await ProcessIncomingAsync(connection);
}
Expand Down Expand Up @@ -330,7 +330,7 @@ private Task OnEventMessageAsync(ServiceEventMessage message)

private Task OnAccessKeyMessageAsync(AccessKeyResponseMessage keyMessage)
{
if (HubEndpoint.AccessKey is AadAccessKey key)
if (HubEndpoint.AccessKey is AccessKeyForMicrosoftEntra key)
{
if (string.IsNullOrEmpty(keyMessage.ErrorType))
{
Expand Down Expand Up @@ -474,7 +474,7 @@ private async Task<bool> ReceiveHandshakeResponseAsync(PipeReader input, Cancell
}
}

private async Task UpdateAzureIdentityAsync(AadAccessKey key, TimerAwaitable timer)
private async Task UpdateAzureIdentityAsync(AccessKeyForMicrosoftEntra key, TimerAwaitable timer)
{
using (timer)
{
Expand All @@ -486,12 +486,12 @@ private async Task UpdateAzureIdentityAsync(AadAccessKey key, TimerAwaitable tim
}
}

private async Task SendAccessKeyRequestMessageAsync(AadAccessKey key)
private async Task SendAccessKeyRequestMessageAsync(AccessKeyForMicrosoftEntra key)
{
try
{
var source = new CancellationTokenSource(AadAccessKey.AuthorizeTimeout);
var token = await key.GenerateAadTokenAsync(source.Token);
var source = new CancellationTokenSource(AccessKeyForMicrosoftEntra.GetAccessKeyTimeout);
var token = await key.GetMicrosoftEntraTokenAsync(source.Token);
var message = new AccessKeyRequestMessage(token);
await SafeWriteAsync(message);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,11 @@ private static AccessKey BuildAzureADAccessKey(Uri uri, Uri serverEndpointUri, D
{
if (dict.TryGetValue(ClientSecretProperty, out var clientSecret))
{
return new AadAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri);
return new AccessKeyForMicrosoftEntra(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri);
}
else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath))
{
return new AadAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri);
return new AccessKeyForMicrosoftEntra(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri);
}
else
{
Expand All @@ -182,12 +182,12 @@ private static AccessKey BuildAzureADAccessKey(Uri uri, Uri serverEndpointUri, D
}
else
{
return new AadAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri);
return new AccessKeyForMicrosoftEntra(uri, new ManagedIdentityCredential(clientId), serverEndpointUri);
}
}
else
{
return new AadAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri);
return new AccessKeyForMicrosoftEntra(uri, new ManagedIdentityCredential(), serverEndpointUri);
}
}

Expand All @@ -200,7 +200,7 @@ private static AccessKey BuildAccessKey(Uri uri, Dictionary<string, string> dict

private static AccessKey BuildAzureAccessKey(Uri uri, Uri serverEndpointUri, Dictionary<string, string> dict)
{
return new AadAccessKey(uri, new DefaultAzureCredential(), serverEndpointUri);
return new AccessKeyForMicrosoftEntra(uri, new DefaultAzureCredential(), serverEndpointUri);
}

private static AccessKey BuildAzureAppAccessKey(Uri uri, Uri serverEndpointUri, Dictionary<string, string> dict)
Expand All @@ -217,20 +217,20 @@ private static AccessKey BuildAzureAppAccessKey(Uri uri, Uri serverEndpointUri,

if (dict.TryGetValue(ClientSecretProperty, out var clientSecret))
{
return new AadAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri);
return new AccessKeyForMicrosoftEntra(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri);
}
else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath))
{
return new AadAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri);
return new AccessKeyForMicrosoftEntra(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri);
}
throw new ArgumentException(MissingClientSecretProperty, ClientSecretProperty);
}

private static AccessKey BuildAzureMsiAccessKey(Uri uri, Uri serverEndpointUri, Dictionary<string, string> dict)
{
return dict.TryGetValue(ClientIdProperty, out var clientId)
? new AadAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri)
: new AadAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri);
? new AccessKeyForMicrosoftEntra(uri, new ManagedIdentityCredential(clientId), serverEndpointUri)
: new AccessKeyForMicrosoftEntra(uri, new ManagedIdentityCredential(), serverEndpointUri);
}

private static Dictionary<string, string> ToDictionary(string connectionString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ public RestApiAccessTokenGenerator(AccessKey accessKey, string serverName = null

public Task<string> Generate(string audience, TimeSpan? lifetime = null)
{
if (_accessKey is AadAccessKey key)
if (_accessKey is AccessKeyForMicrosoftEntra key)
{
return key.GenerateAadTokenAsync();
return key.GetMicrosoftEntraTokenAsync();
}

return _accessKey.GenerateAccessTokenAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ public string GetClientEndpoint(string hubName, string originalPath, string quer

public IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId)
{
if (_accessKey is AadAccessKey aadAccessKey)
if (_accessKey is AccessKeyForMicrosoftEntra key)
{
return new AadTokenProvider(aadAccessKey);
return new MicrosoftEntraTokenProvider(key);
}
else if (_accessKey is not null)
{
Expand Down
Loading
Loading