Skip to content

Commit

Permalink
Add AsyncLazy<T>.SuppressRecursiveFactoryDetection property
Browse files Browse the repository at this point in the history
This allows folks to opt-out of the recursive detection altogether, such that `AsyncLazy<T>.SuppressRelevance()` isn't required (though that call would still carry out the JTF suppression function).
  • Loading branch information
AArnott committed Dec 12, 2023
1 parent 4b1ca81 commit a42d2dd
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 7 deletions.
50 changes: 43 additions & 7 deletions src/Microsoft.VisualStudio.Threading/AsyncLazy`1.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class AsyncLazy<T>
/// <summary>
/// The unique instance identifier.
/// </summary>
private readonly AsyncLocal<object> recursiveFactoryCheck = new AsyncLocal<object>();
private AsyncLocal<object>? recursiveFactoryCheck;

/// <summary>
/// The function to invoke to produce the task.
Expand Down Expand Up @@ -72,6 +72,30 @@ public AsyncLazy(Func<Task<T>> valueFactory, JoinableTaskFactory? joinableTaskFa
this.jobFactory = joinableTaskFactory;
}

/// <summary>
/// Gets a value indicating whether to suppress detection of a value factory depending on itself.
/// </summary>
/// <remarks>
/// <para>
/// A value factory that truly depends on itself (e.g. by calling <see cref="GetValueAsync()"/> on the same instance)
/// would deadlock, and by default this class will throw an exception if it detects such a condition.
/// However this detection relies on the .NET ExecutionContext, which can flow to "spin off" contexts that are not awaited
/// by the factory, and thus could legally await the result of the value factory without deadlocking.
/// </para>
/// <para>
/// When this flows improperly, it can cause <see cref="InvalidOperationException"/> to be thrown, but only when the value factory
/// has not already been completed, leading to a difficult to reproduce race condition.
/// Such a case can be resolved by calling <see cref="SuppressRelevance"/> around the non-awaited fork in <see cref="ExecutionContext" />,
/// or the entire instance can be configured to suppress this check by setting this property to <see langword="true"/>.
/// </para>
/// <para>
/// When this property is set to <see langword="true" />, the recursive factory check will not be performed,
/// but <see cref="SuppressRelevance"/> will still call into <see cref="JoinableTaskContext.SuppressRelevance"/>
/// if a <see cref="JoinableTaskFactory"/> was provided to the constructor.
/// </para>
/// </remarks>
public bool SuppressRecursiveFactoryDetection { get; init; }

/// <summary>
/// Gets a value indicating whether the value factory has been invoked.
/// </summary>
Expand Down Expand Up @@ -137,7 +161,7 @@ public bool IsValueFactoryCompleted
/// <exception cref="ObjectDisposedException">Thrown after <see cref="DisposeValue"/> is called.</exception>
public Task<T> GetValueAsync(CancellationToken cancellationToken)
{
if (!((this.value is object && this.value.IsCompleted) || this.recursiveFactoryCheck.Value is null))
if (this.value is not { IsCompleted: true } && this.recursiveFactoryCheck is { Value: not null })
{
// PERF: we check the condition and *then* retrieve the string resource only on failure
// because the string retrieval has shown up as significant on ETL traces.
Expand Down Expand Up @@ -183,7 +207,12 @@ public Task<T> GetValueAsync(CancellationToken cancellationToken)
}
};

this.recursiveFactoryCheck.Value = RecursiveCheckSentinel;
if (!this.SuppressRecursiveFactoryDetection)
{
Assumes.Null(this.recursiveFactoryCheck);
this.recursiveFactoryCheck = new AsyncLocal<object>() { Value = RecursiveCheckSentinel };
}

try
{
if (this.jobFactory is object)
Expand All @@ -201,7 +230,10 @@ public Task<T> GetValueAsync(CancellationToken cancellationToken)
}
finally
{
this.recursiveFactoryCheck.Value = null;
if (this.recursiveFactoryCheck is not null)
{
this.recursiveFactoryCheck.Value = null;
}
}
}
}
Expand Down Expand Up @@ -451,7 +483,11 @@ internal RevertRelevance(AsyncLazy<T> owner)
Requires.NotNull(owner, nameof(owner));
this.owner = owner;

(this.oldCheckValue, owner.recursiveFactoryCheck.Value) = (owner.recursiveFactoryCheck.Value, null);
if (owner.recursiveFactoryCheck is not null)
{
(this.oldCheckValue, owner.recursiveFactoryCheck.Value) = (owner.recursiveFactoryCheck.Value, null);
}

this.joinableRelevance = owner.jobFactory?.Context.SuppressRelevance();
}

Expand All @@ -460,9 +496,9 @@ internal RevertRelevance(AsyncLazy<T> owner)
/// </summary>
public void Dispose()
{
if (this.owner is object)
if (this.owner?.recursiveFactoryCheck is { } check)
{
this.owner.recursiveFactoryCheck.Value = this.oldCheckValue;
check.Value = this.oldCheckValue;
}

this.joinableRelevance?.Dispose();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy<T>.DisposeValueAsync() -> System.Thre
Microsoft.VisualStudio.Threading.AsyncLazy<T>.IsValueDisposed.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance.Dispose() -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.init -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy<T>.DisposeValueAsync() -> System.Thre
Microsoft.VisualStudio.Threading.AsyncLazy<T>.IsValueDisposed.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance.Dispose() -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.init -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy<T>.DisposeValueAsync() -> System.Thre
Microsoft.VisualStudio.Threading.AsyncLazy<T>.IsValueDisposed.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance.Dispose() -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.init -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy<T>.DisposeValueAsync() -> System.Thre
Microsoft.VisualStudio.Threading.AsyncLazy<T>.IsValueDisposed.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance.Dispose() -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.init -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
102 changes: 102 additions & 0 deletions test/Microsoft.VisualStudio.Threading.Tests/AsyncLazyTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

using Microsoft;
using Microsoft.VisualStudio.Threading;

using Xunit;
using Xunit.Abstractions;

using NamedSyncContext = AwaitExtensionsTests.NamedSyncContext;

public class AsyncLazyTests : TestBase
Expand Down Expand Up @@ -748,6 +751,105 @@ async Task<int> FireAndForgetCodeAsync()
}
}

[Fact]
public async Task SuppressRecursiveFactoryDetection_WithoutJTF()
{
AsyncManualResetEvent allowValueFactoryToFinish = new();
Task<int>? fireAndForgetTask = null;
AsyncLazy<int> asyncLazy = null!;
asyncLazy = new AsyncLazy<int>(
async delegate
{
fireAndForgetTask = FireAndForgetCodeAsync();
await allowValueFactoryToFinish;
return 1;
},
null)
{
SuppressRecursiveFactoryDetection = true,
};

bool fireAndForgetCodeAsyncEntered = false;
Task<int> lazyValue = asyncLazy.GetValueAsync();
Assert.True(fireAndForgetCodeAsyncEntered);
allowValueFactoryToFinish.Set();

// Assert that the value factory was allowed to finish.
Assert.Equal(1, await lazyValue.WithCancellation(this.TimeoutToken));

// Assert that the fire-and-forget task was allowed to finish and did so without throwing.
Assert.Equal(1, await fireAndForgetTask!.WithCancellation(this.TimeoutToken));

async Task<int> FireAndForgetCodeAsync()
{
fireAndForgetCodeAsyncEntered = true;
return await asyncLazy.GetValueAsync();
}
}

[Theory, PairwiseData]
public async Task SuppressRecursiveFactoryDetection_WithJTF(bool suppressWithJTF)
{
JoinableTaskContext? context = this.InitializeJTCAndSC();
SingleThreadedTestSynchronizationContext.IFrame frame = SingleThreadedTestSynchronizationContext.NewFrame();

JoinableTaskFactory? jtf = context.Factory;
AsyncManualResetEvent allowValueFactoryToFinish = new();
Task<int>? fireAndForgetTask = null;
AsyncLazy<int> asyncLazy = null!;
asyncLazy = new AsyncLazy<int>(
async delegate
{
using (suppressWithJTF ? jtf.Context.SuppressRelevance() : default)
using (suppressWithJTF ? default : asyncLazy.SuppressRelevance())
{
fireAndForgetTask = FireAndForgetCodeAsync();
}

await allowValueFactoryToFinish;
return 1;
},
jtf)
{
SuppressRecursiveFactoryDetection = true,
};

bool fireAndForgetCodeAsyncEntered = false;
bool fireAndForgetCodeAsyncReachedUIThread = false;
jtf.Run(async delegate
{
Task<int> lazyValue = asyncLazy.GetValueAsync();
Assert.True(fireAndForgetCodeAsyncEntered);
await Task.Delay(AsyncDelay);
Assert.False(fireAndForgetCodeAsyncReachedUIThread);
allowValueFactoryToFinish.Set();

// Assert that the value factory was allowed to finish.
Assert.Equal(1, await lazyValue.WithCancellation(this.TimeoutToken));
});

// Run a main thread pump so the fire-and-forget task can finish.
SingleThreadedTestSynchronizationContext.PushFrame(SynchronizationContext.Current!, frame);

// Assert that the fire-and-forget task was allowed to finish and did so without throwing.
Assert.Equal(1, await fireAndForgetTask!.WithCancellation(this.TimeoutToken));

async Task<int> FireAndForgetCodeAsync()
{
fireAndForgetCodeAsyncEntered = true;

// Yield the caller's thread.
// Resuming will require the main thread, since the caller was on the main thread.
await Task.Yield();

fireAndForgetCodeAsyncReachedUIThread = true;

int result = await asyncLazy.GetValueAsync();
frame.Continue = false;
return result;
}
}

[Fact]
public async Task Dispose_ValueType_Completed()
{
Expand Down

0 comments on commit a42d2dd

Please sign in to comment.