Skip to content

Commit

Permalink
Add CancellationToken support during refresh operations
Browse files Browse the repository at this point in the history
  • Loading branch information
avanigupta committed Oct 18, 2021
1 parent 3c8b125 commit 61322dc
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ public override void Load()
_isInitialLoadComplete = true;
}

public async Task RefreshAsync()
public async Task RefreshAsync(CancellationToken cancellationToken)
{
// Ensure that concurrent threads do not simultaneously execute refresh operation.
if (Interlocked.Exchange(ref _networkOperationsInProgress, 1) == 0)
Expand All @@ -174,15 +174,15 @@ public async Task RefreshAsync()
if (InitializationCacheExpires < DateTimeOffset.UtcNow)
{
InitializationCacheExpires = DateTimeOffset.UtcNow.Add(MinCacheExpirationInterval);
await LoadAll(ignoreFailures: false).ConfigureAwait(false);
await LoadAll(ignoreFailures: false, cancellationToken).ConfigureAwait(false);
}

return;
}

await RefreshIndividualKeyValues().ConfigureAwait(false);
await RefreshKeyValueCollections().ConfigureAwait(false);
await RefreshKeyValueAdapters().ConfigureAwait(false);
await RefreshIndividualKeyValues(cancellationToken).ConfigureAwait(false);
await RefreshKeyValueCollections(cancellationToken).ConfigureAwait(false);
await RefreshKeyValueAdapters(cancellationToken).ConfigureAwait(false);
}
finally
{
Expand All @@ -191,11 +191,11 @@ public async Task RefreshAsync()
}
}

public async Task<bool> TryRefreshAsync()
public async Task<bool> TryRefreshAsync(CancellationToken cancellationToken)
{
try
{
await RefreshAsync().ConfigureAwait(false);
await RefreshAsync(cancellationToken).ConfigureAwait(false);
}
catch (RequestFailedException e)
{
Expand Down Expand Up @@ -252,7 +252,7 @@ public void SetDirty(TimeSpan? maxDelay)
}
}

private async Task LoadAll(bool ignoreFailures)
private async Task LoadAll(bool ignoreFailures, CancellationToken cancellationToken = default)
{
IDictionary<string, ConfigurationSetting> data = null;
string cachedData = null;
Expand All @@ -275,7 +275,7 @@ private async Task LoadAll(bool ignoreFailures)

await CallWithRequestTracing(async () =>
{
await foreach (ConfigurationSetting setting in _client.GetConfigurationSettingsAsync(selector, CancellationToken.None).ConfigureAwait(false))
await foreach (ConfigurationSetting setting in _client.GetConfigurationSettingsAsync(selector, cancellationToken).ConfigureAwait(false))
{
serverData[setting.Key] = setting;
}
Expand Down Expand Up @@ -303,15 +303,15 @@ await CallWithRequestTracing(async () =>

await CallWithRequestTracing(async () =>
{
await foreach (ConfigurationSetting setting in _client.GetConfigurationSettingsAsync(selector, CancellationToken.None).ConfigureAwait(false))
await foreach (ConfigurationSetting setting in _client.GetConfigurationSettingsAsync(selector, cancellationToken).ConfigureAwait(false))
{
serverData[setting.Key] = setting;
}
}).ConfigureAwait(false);
}

// Block current thread for the initial load of key-values registered for refresh that are not already loaded
await Task.Run(() => LoadKeyValuesRegisteredForRefresh(serverData).ConfigureAwait(false).GetAwaiter().GetResult()).ConfigureAwait(false);
await Task.Run(() => LoadKeyValuesRegisteredForRefresh(serverData, cancellationToken).ConfigureAwait(false).GetAwaiter().GetResult()).ConfigureAwait(false);
data = serverData;
}
catch (Exception exception) when (exception is RequestFailedException ||
Expand Down Expand Up @@ -344,7 +344,7 @@ await CallWithRequestTracing(async () =>
adapter.InvalidateCache();
}

await SetData(data, ignoreFailures).ConfigureAwait(false);
await SetData(data, ignoreFailures, cancellationToken).ConfigureAwait(false);

// Set the cache expiration time for all refresh registered settings
var initialLoadTime = DateTimeOffset.UtcNow;
Expand All @@ -366,7 +366,7 @@ await CallWithRequestTracing(async () =>
}
}

private async Task LoadKeyValuesRegisteredForRefresh(IDictionary<string, ConfigurationSetting> data)
private async Task LoadKeyValuesRegisteredForRefresh(IDictionary<string, ConfigurationSetting> data, CancellationToken cancellationToken = default)
{
_watchedSettings.Clear();

Expand All @@ -388,7 +388,7 @@ private async Task LoadKeyValuesRegisteredForRefresh(IDictionary<string, Configu
ConfigurationSetting watchedKv = null;
try
{
await CallWithRequestTracing(async () => watchedKv = await _client.GetConfigurationSettingAsync(watchedKey, watchedLabel, CancellationToken.None)).ConfigureAwait(false);
await CallWithRequestTracing(async () => watchedKv = await _client.GetConfigurationSettingAsync(watchedKey, watchedLabel, cancellationToken)).ConfigureAwait(false);
}
catch (RequestFailedException e) when (e.Status == (int)HttpStatusCode.NotFound)
{
Expand All @@ -404,7 +404,7 @@ private async Task LoadKeyValuesRegisteredForRefresh(IDictionary<string, Configu
}
}

private async Task RefreshIndividualKeyValues()
private async Task RefreshIndividualKeyValues(CancellationToken cancellationToken = default)
{
bool shouldRefreshAll = false;

Expand All @@ -426,7 +426,7 @@ private async Task RefreshIndividualKeyValues()
{
KeyValueChange keyValueChange = default;
await TracingUtils.CallWithRequestTracing(_requestTracingEnabled, RequestType.Watch, _requestTracingOptions,
async () => keyValueChange = await _client.GetKeyValueChange(watchedKv, CancellationToken.None).ConfigureAwait(false)).ConfigureAwait(false);
async () => keyValueChange = await _client.GetKeyValueChange(watchedKv, cancellationToken).ConfigureAwait(false)).ConfigureAwait(false);

changeWatcher.CacheExpires = DateTimeOffset.UtcNow.Add(changeWatcher.CacheExpirationInterval);

Expand Down Expand Up @@ -459,7 +459,7 @@ await TracingUtils.CallWithRequestTracing(_requestTracingEnabled, RequestType.Wa

try
{
await CallWithRequestTracing(async () => watchedKv = await _client.GetConfigurationSettingAsync(watchedKey, watchedLabel, CancellationToken.None).ConfigureAwait(false)).ConfigureAwait(false);
await CallWithRequestTracing(async () => watchedKv = await _client.GetConfigurationSettingAsync(watchedKey, watchedLabel, cancellationToken).ConfigureAwait(false)).ConfigureAwait(false);
}
catch (RequestFailedException e) when (e.Status == (int)HttpStatusCode.NotFound)
{
Expand All @@ -479,40 +479,40 @@ await TracingUtils.CallWithRequestTracing(_requestTracingEnabled, RequestType.Wa

hasChanged = true;

// Add the key-value if it is not loaded, or update it if it was loaded with a different label
_applicationSettings[watchedKey] = watchedKv;
_watchedSettings[watchedKeyLabel] = watchedKv;
// Add the key-value if it is not loaded, or update it if it was loaded with a different label
_applicationSettings[watchedKey] = watchedKv;
_watchedSettings[watchedKeyLabel] = watchedKv;

// Invalidate the cached Key Vault secret (if any) for this ConfigurationSetting
foreach (IKeyValueAdapter adapter in _options.Adapters)
{
adapter.InvalidateCache(watchedKv);
}
// Invalidate the cached Key Vault secret (if any) for this ConfigurationSetting
foreach (IKeyValueAdapter adapter in _options.Adapters)
{
adapter.InvalidateCache(watchedKv);
}
}
}

if (hasChanged)
{
await SetData(_applicationSettings).ConfigureAwait(false);
await SetData(_applicationSettings, false, cancellationToken).ConfigureAwait(false);
}
}

// Trigger a single refresh-all operation if a change was detected in one or more key-values with refreshAll: true
if (shouldRefreshAll)
{
await LoadAll(ignoreFailures: false).ConfigureAwait(false);
await LoadAll(ignoreFailures: false, cancellationToken).ConfigureAwait(false);
}
}

private async Task RefreshKeyValueAdapters()
private async Task RefreshKeyValueAdapters(CancellationToken cancellationToken = default)
{
if (_options.Adapters.Any(adapter => adapter.NeedsRefresh()))
{
SetData(_applicationSettings);
SetData(_applicationSettings, false, cancellationToken);
}
}

private async Task RefreshKeyValueCollections()
private async Task RefreshKeyValueCollections(CancellationToken cancellationToken = default)
{
foreach (KeyValueWatcher changeWatcher in _options.MultiKeyWatchers)
{
Expand Down Expand Up @@ -541,21 +541,24 @@ private async Task RefreshKeyValueCollections()
});
}

IEnumerable<KeyValueChange> keyValueChanges = await _client.GetKeyValueChangeCollection(currentKeyValues, new GetKeyValueChangeCollectionOptions
{
KeyFilter = changeWatcher.Key,
Label = changeWatcher.Label.NormalizeNull(),
RequestTracingEnabled = _requestTracingEnabled,
RequestTracingOptions = _requestTracingOptions
}).ConfigureAwait(false);
IEnumerable<KeyValueChange> keyValueChanges = await _client.GetKeyValueChangeCollection(
currentKeyValues,
new GetKeyValueChangeCollectionOptions
{
KeyFilter = changeWatcher.Key,
Label = changeWatcher.Label.NormalizeNull(),
RequestTracingEnabled = _requestTracingEnabled,
RequestTracingOptions = _requestTracingOptions
},
cancellationToken).ConfigureAwait(false);

changeWatcher.CacheExpires = DateTimeOffset.UtcNow.Add(changeWatcher.CacheExpirationInterval);

if (keyValueChanges?.Any() == true)
{
ProcessChanges(keyValueChanges);

await SetData(_applicationSettings).ConfigureAwait(false);
await SetData(_applicationSettings, false, cancellationToken).ConfigureAwait(false);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//
using Microsoft.Extensions.Logging;
using System;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Extensions.Configuration.AzureAppConfiguration
Expand Down Expand Up @@ -32,20 +33,20 @@ public void SetProvider(AzureAppConfigurationProvider provider)
AppConfigurationEndpoint = _provider.AppConfigurationEndpoint;
}

public async Task RefreshAsync()
public async Task RefreshAsync(CancellationToken cancellationToken)
{
ThrowIfNullProvider(nameof(RefreshAsync));
await _provider.RefreshAsync().ConfigureAwait(false);
await _provider.RefreshAsync(cancellationToken).ConfigureAwait(false);
}

public async Task<bool> TryRefreshAsync()
public async Task<bool> TryRefreshAsync(CancellationToken cancellationToken)
{
if (_provider == null)
{
return false;
}

return await _provider.TryRefreshAsync().ConfigureAwait(false);
return await _provider.TryRefreshAsync(cancellationToken).ConfigureAwait(false);
}

public void SetDirty(TimeSpan? maxDelay)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public static async Task<KeyValueChange> GetKeyValueChange(this ConfigurationCli
};
}

public static async Task<IEnumerable<KeyValueChange>> GetKeyValueChangeCollection(this ConfigurationClient client, IEnumerable<ConfigurationSetting> keyValues, GetKeyValueChangeCollectionOptions options)
public static async Task<IEnumerable<KeyValueChange>> GetKeyValueChangeCollection(this ConfigurationClient client, IEnumerable<ConfigurationSetting> keyValues, GetKeyValueChangeCollectionOptions options, CancellationToken cancellationToken = default)
{
if (options == null)
{
Expand Down Expand Up @@ -107,7 +107,7 @@ public static async Task<IEnumerable<KeyValueChange>> GetKeyValueChangeCollectio
await TracingUtils.CallWithRequestTracing(options.RequestTracingEnabled, RequestType.Watch, options.RequestTracingOptions,
async () =>
{
await foreach(ConfigurationSetting setting in client.GetConfigurationSettingsAsync(selector).ConfigureAwait(false))
await foreach(ConfigurationSetting setting in client.GetConfigurationSettingsAsync(selector, cancellationToken).ConfigureAwait(false))
{
if (!eTagMap.TryGetValue(setting.Key, out ETag etag) || !etag.Equals(setting.ETag))
{
Expand Down Expand Up @@ -140,7 +140,7 @@ await TracingUtils.CallWithRequestTracing(options.RequestTracingEnabled, Request
await TracingUtils.CallWithRequestTracing(options.RequestTracingEnabled, RequestType.Watch, options.RequestTracingOptions,
async () =>
{
await foreach (ConfigurationSetting setting in client.GetConfigurationSettingsAsync(selector).ConfigureAwait(false))
await foreach (ConfigurationSetting setting in client.GetConfigurationSettingsAsync(selector, cancellationToken).ConfigureAwait(false))
{
if (!eTagMap.TryGetValue(setting.Key, out ETag etag) || !etag.Equals(setting.ETag))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Azure;
using Microsoft.Extensions.Logging;
using System;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Extensions.Configuration.AzureAppConfiguration
Expand All @@ -26,18 +27,20 @@ public interface IConfigurationRefresher
/// <summary>
/// Refreshes the data from App Configuration asynchronously.
/// </summary>
/// <param name="cancellationToken">The cancellation token to cancel the operation.</param>
/// <exception cref="KeyVaultReferenceException">An error occurred when resolving a reference to an Azure Key Vault resource.</exception>
/// <exception cref="RequestFailedException">The request failed with an error code from the server.</exception>
/// <exception cref="AggregateException">
/// The refresh operation failed with one or more errors. Check <see cref="AggregateException.InnerExceptions"/> for more details.
/// </exception>
/// <exception cref="InvalidOperationException">The refresh operation was invoked before Azure App Configuration Provider was initialized.</exception>
Task RefreshAsync();
Task RefreshAsync(CancellationToken cancellationToken = default);

/// <summary>
/// Refreshes the data from App Configuration asynchronously. A return value indicates whether the operation succeeded.
/// </summary>
Task<bool> TryRefreshAsync();
/// <param name="cancellationToken">The cancellation token to cancel the operation.</param>
Task<bool> TryRefreshAsync(CancellationToken cancellationToken = default);

/// <summary>
/// Sets the cached value for key-values registered for refresh as dirty.
Expand Down
44 changes: 44 additions & 0 deletions tests/Tests.AzureAppConfiguration/RefreshTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,40 @@ public void RefreshTests_ConfigureRefreshThrowsOnNoRegistration()
Assert.Throws<ArgumentException>(action);
}

[Fact]
public void RefreshTests_RefreshIsCancelled()
{
IConfigurationRefresher refresher = null;
var mockClient = GetMockConfigurationClient();

var config = new ConfigurationBuilder()
.AddAzureAppConfiguration(options =>
{
options.Client = mockClient.Object;
options.ConfigureRefresh(refreshOptions =>
{
refreshOptions.Register("TestKey1", "label")
.SetCacheExpiration(TimeSpan.FromSeconds(1));
});

refresher = options.GetRefresher();
})
.Build();

Assert.Equal("TestValue1", config["TestKey1"]);
FirstKeyValue.Value = "newValue1";

// Wait for the cache to expire
Thread.Sleep(1500);

using var cancellationSource = new CancellationTokenSource();
cancellationSource.CancelAfter(TimeSpan.Zero);
Action action = () => refresher.RefreshAsync(cancellationSource.Token).Wait();
var exception = Assert.Throws<AggregateException>(action);
Assert.IsType<TaskCanceledException>(exception.InnerException);
Assert.Equal("TestValue1", config["TestKey1"]);
}

private void WaitAndRefresh(IConfigurationRefresher refresher, int millisecondsDelay)
{
Task.Delay(millisecondsDelay).Wait();
Expand All @@ -1145,11 +1179,21 @@ private Mock<ConfigurationClient> GetMockConfigurationClient()

Response<ConfigurationSetting> GetTestKey(string key, string label, CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
{
cancellationToken.ThrowIfCancellationRequested();
}

return Response.FromValue(TestHelpers.CloneSetting(_kvCollection.FirstOrDefault(s => s.Key == key && s.Label == label)), mockResponse.Object);
}

Response<ConfigurationSetting> GetIfChanged(ConfigurationSetting setting, bool onlyIfChanged, CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
{
cancellationToken.ThrowIfCancellationRequested();
}

var newSetting = _kvCollection.FirstOrDefault(s => (s.Key == setting.Key && s.Label == setting.Label));
var unchanged = (newSetting.Key == setting.Key && newSetting.Label == setting.Label && newSetting.Value == setting.Value);
var response = new MockResponse(unchanged ? 304 : 200);
Expand Down

0 comments on commit 61322dc

Please sign in to comment.