Skip to content

Commit

Permalink
Lazy load temporary AccessKey
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan committed Nov 22, 2024
1 parent 30d6a43 commit 281c134
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

namespace Microsoft.Azure.SignalR;

internal sealed class AccessKeySynchronizer : IAccessKeySynchronizer, IDisposable
{
private readonly ConcurrentDictionary<ServiceEndpoint, object> _endpoints = new ConcurrentDictionary<ServiceEndpoint, object>(ReferenceEqualityComparer.Instance);
private readonly ConcurrentDictionary<MicrosoftEntraAccessKey, bool> _keyMap = new(ReferenceEqualityComparer.Instance);

private readonly ILoggerFactory _factory;
private readonly ILogger<AccessKeySynchronizer> _logger;

private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(1));

internal IEnumerable<MicrosoftEntraAccessKey> AccessKeyForMicrosoftEntraList => _endpoints.Select(e => e.Key.AccessKey).OfType<MicrosoftEntraAccessKey>();
internal IEnumerable<MicrosoftEntraAccessKey> InitializedKeyList => _keyMap.Where(x => x.Key.Initialized).Select(x => x.Key);

public AccessKeySynchronizer(ILoggerFactory loggerFactory) : this(loggerFactory, true)
{
Expand All @@ -32,65 +33,73 @@ internal AccessKeySynchronizer(ILoggerFactory loggerFactory, bool start)
{
if (start)
{
_ = UpdateAccessKeyAsync();
_ = UpdateAllAccessKeyAsync();
}
_factory = loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory));
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<AccessKeySynchronizer>();
}

public void AddServiceEndpoint(ServiceEndpoint endpoint)
{
if (endpoint.AccessKey is MicrosoftEntraAccessKey key)
{
_ = key.UpdateAccessKeyAsync();
_keyMap.TryAdd(key, true);
}
_endpoints.TryAdd(endpoint, null);
}

public void Dispose() => _timer.Stop();

public void UpdateServiceEndpoints(IEnumerable<ServiceEndpoint> endpoints)
{
_endpoints.Clear();
_keyMap.Clear();
foreach (var endpoint in endpoints)
{
AddServiceEndpoint(endpoint);
}
}

internal bool ContainsServiceEndpoint(ServiceEndpoint e) => _endpoints.ContainsKey(e);
/// <summary>
/// Test only
/// </summary>
/// <param name="e"></param>
/// <returns></returns>
internal bool ContainsKey(ServiceEndpoint e) => _keyMap.ContainsKey(e.AccessKey as MicrosoftEntraAccessKey);

internal int ServiceEndpointsCount() => _endpoints.Count;
/// <summary>
/// Test only
/// </summary>
/// <returns></returns>
internal int Count() => _keyMap.Count;

private async Task UpdateAccessKeyAsync()
private async Task UpdateAllAccessKeyAsync()
{
using (_timer)
{
_timer.Start();

while (await _timer)
{
foreach (var key in AccessKeyForMicrosoftEntraList)
foreach (var key in InitializedKeyList)
{
_ = key.UpdateAccessKeyAsync();
}
}
}
}

private sealed class ReferenceEqualityComparer : IEqualityComparer<ServiceEndpoint>
private sealed class ReferenceEqualityComparer : IEqualityComparer<MicrosoftEntraAccessKey>
{
internal static readonly ReferenceEqualityComparer Instance = new ReferenceEqualityComparer();

private ReferenceEqualityComparer()
{
}

public bool Equals(ServiceEndpoint x, ServiceEndpoint y)
public bool Equals(MicrosoftEntraAccessKey x, MicrosoftEntraAccessKey y)
{
return ReferenceEquals(x, y);
}

public int GetHashCode(ServiceEndpoint obj)
public int GetHashCode(MicrosoftEntraAccessKey obj)
{
return RuntimeHelpers.GetHashCode(obj);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Collections.Generic;
using System.Threading.Tasks;

namespace Microsoft.Azure.SignalR;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ internal class MicrosoftEntraAccessKey : IAccessKey
{
internal static readonly TimeSpan GetAccessKeyTimeout = TimeSpan.FromSeconds(100);

private const int UpdateTaskIdle = 0;

private const int UpdateTaskRunning = 1;

private const int GetAccessKeyMaxRetryTimes = 3;

private const int GetMicrosoftEntraTokenMaxRetryTimes = 3;
Expand All @@ -36,10 +40,12 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private static readonly TimeSpan AccessKeyExpireTime = TimeSpan.FromMinutes(120);

private readonly TaskCompletionSource<object?> _initializedTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly TaskCompletionSource<object?> _initializedTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

private readonly IHttpClientFactory _httpClientFactory;

private volatile int _updateState = 0;

private volatile bool _isAuthorized = false;

private DateTime _updateAt = DateTime.MinValue;
Expand All @@ -48,6 +54,8 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private volatile byte[]? _keyBytes;

public bool Initialized => _initializedTcs.Task.IsCompleted;

public bool Available
{
get => _isAuthorized && DateTime.UtcNow - _updateAt < AccessKeyExpireTime;
Expand Down Expand Up @@ -116,6 +124,11 @@ public async Task<string> GenerateAccessTokenAsync(string audience,
AccessTokenAlgorithm algorithm,
CancellationToken ctoken = default)
{
if (!_initializedTcs.Task.IsCompleted)
{
_ = UpdateAccessKeyAsync();
}

await _initializedTcs.Task.OrCancelAsync(ctoken, "The access key initialization timed out.");

return Available
Expand All @@ -142,14 +155,24 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
return;
}

if (Interlocked.CompareExchange(ref _updateState, UpdateTaskRunning, UpdateTaskIdle) != UpdateTaskIdle)
{
return;
}

for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++)
{
var source = new CancellationTokenSource(GetAccessKeyTimeout);
var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(source.Token, ctoken);
try
{
await UpdateAccessKeyInternalAsync(linkedSource.Token);
return;
var key = await SendAccessKeyRequestAsync(linkedSource.Token);
if (key != null)
{
UpdateAccessKey(key.KeyId, key.AccessKey);
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
return;
}
}
catch (OperationCanceledException e)
{
Expand All @@ -175,6 +198,7 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
// Update the status only when it becomes "not available" due to expiration to refresh updateAt.
Available = false;
}
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
}

private static string GetExceptionMessage(Exception? exception)
Expand Down Expand Up @@ -216,7 +240,7 @@ private static async Task ThrowExceptionOnResponseFailureAsync(HttpRequestMessag
};
}

private async Task UpdateAccessKeyInternalAsync(CancellationToken ctoken)
private async Task<AccessKeyResponse?> SendAccessKeyRequestAsync(CancellationToken ctoken)
{
var accessToken = await GetMicrosoftEntraTokenAsync(ctoken);

Expand All @@ -227,16 +251,18 @@ private async Task UpdateAccessKeyInternalAsync(CancellationToken ctoken)

var response = await httpClient.SendAsync(request, ctoken);

await HandleHttpResponseAsync(response);
var key = await HandleHttpResponseAsync(response);

await ThrowExceptionOnResponseFailureAsync(request, response);

return key;
}

private async Task<bool> HandleHttpResponseAsync(HttpResponseMessage response)
private async Task<AccessKeyResponse?> HandleHttpResponseAsync(HttpResponseMessage response)
{
if (response.StatusCode != HttpStatusCode.OK)
{
return false;
return null;
}

var content = await response.Content.ReadAsStringAsync();
Expand All @@ -250,8 +276,6 @@ private async Task<bool> HandleHttpResponseAsync(HttpResponseMessage response)
{
throw new AzureSignalRException("Missing required <AccessKey> field.");
}

UpdateAccessKey(obj.KeyId, obj.AccessKey);
return true;
return obj;
}
}
3 changes: 2 additions & 1 deletion src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ internal static class Constants

public const string AsrsDefaultScope = "https://signalr.azure.com/.default";


public const int DefaultCloseTimeoutMilliseconds = 10000;

public static class Keys
Expand All @@ -45,6 +44,8 @@ public static class Periods

public const int MaxCustomHandshakeTimeout = 30;

public static readonly TimeSpan DefaultUpdateAccessKeyTimeout = TimeSpan.FromMinutes(2);

public static readonly TimeSpan DefaultAccessTokenLifetime = TimeSpan.FromHours(1);

public static readonly TimeSpan DefaultScaleTimeout = TimeSpan.FromMinutes(5);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using Azure.Identity;
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.Logging.Abstractions;
using Xunit;

namespace Microsoft.Azure.SignalR.Common.Tests.Auth;

public class AccessKeySynchronizerFacts
{
[Fact]
public void AddAndRemoveServiceEndpointsTest()
{
var synchronizer = GetInstanceForTest();

var credential = new DefaultAzureCredential();
var endpoint1 = new TestServiceEndpoint(credential);
var endpoint2 = new TestServiceEndpoint(credential);

Assert.Equal(0, synchronizer.Count());
synchronizer.UpdateServiceEndpoints([endpoint1]);
Assert.Equal(1, synchronizer.Count());
synchronizer.UpdateServiceEndpoints([endpoint1, endpoint2]);
Assert.Empty(synchronizer.InitializedKeyList);

Assert.Equal(2, synchronizer.Count());
Assert.True(synchronizer.ContainsKey(endpoint1));
Assert.True(synchronizer.ContainsKey(endpoint2));

synchronizer.UpdateServiceEndpoints([endpoint2]);
Assert.Equal(1, synchronizer.Count());
synchronizer.UpdateServiceEndpoints([]);
Assert.Equal(0, synchronizer.Count());
Assert.Empty(synchronizer.InitializedKeyList);
}

private static AccessKeySynchronizer GetInstanceForTest()
{
return new AccessKeySynchronizer(NullLoggerFactory.Instance, false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ public async Task TestUpdateAccessKeyFailedThrowsNotAuthorizedException(AzureSig
.ThrowsAsync(e);
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object)
{
GetAccessKeyRetryInterval = TimeSpan.Zero
GetAccessKeyRetryInterval = TimeSpan.Zero,
};

var audience = "http://localhost/chat";
Expand Down Expand Up @@ -210,6 +210,52 @@ public async Task TestUpdateAccessKeySendRequest(string expectedKeyStr)
Assert.Equal(expectedKeyStr, Encoding.UTF8.GetString(key.KeyBytes));
}

[Fact]
public async Task TestLazyLoadAccessKey()
{
var expectedKeyStr = DefaultSigningKey;
var expectedKid = "foo";
var text = "{" + string.Format("\"AccessKey\": \"{0}\", \"KeyId\": \"{1}\"", expectedKeyStr, expectedKid) + "}";
var httpClientFactory = new TestHttpClientFactory(new HttpResponseMessage(HttpStatusCode.OK)
{
Content = TextHttpContent.From(text),
});

var credential = new TestTokenCredential(TokenType.MicrosoftEntra);
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, credential, httpClientFactory: httpClientFactory);

Assert.False(key.Initialized);

var token = await key.GenerateAccessTokenAsync("https://localhost", [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256);
Assert.NotNull(token);

Assert.True(key.Initialized);
}

[Fact]
public async Task TestLazyLoadAccessKeyFailed()
{
var mockCredential = new Mock<TokenCredential>();
mockCredential.Setup(credential => credential.GetTokenAsync(
It.IsAny<TokenRequestContext>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new Exception());
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object)
{
GetAccessKeyRetryInterval = TimeSpan.FromSeconds(1),
};

Assert.False(key.Initialized);

var task1 = key.GenerateAccessTokenAsync("https://localhost", [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256);
var task2 = key.UpdateAccessKeyAsync();
Assert.True(task2.IsCompleted); // another task is in progress.

await Assert.ThrowsAsync<AzureSignalRAccessTokenNotAuthorizedException>(async () => await task1);

Assert.True(key.Initialized);
}

[Theory]
[InlineData(TokenType.Local)]
[InlineData(TokenType.MicrosoftEntra)]
Expand All @@ -226,7 +272,10 @@ public async Task ThrowUnauthorizedExceptionTest(TokenType tokenType)
endpoint,
new TestTokenCredential(tokenType),
httpClientFactory: new TestHttpClientFactory(message)
);
)
{
GetAccessKeyRetryInterval = TimeSpan.Zero
};

await key.UpdateAccessKeyAsync();

Expand Down
Loading

0 comments on commit 281c134

Please sign in to comment.