Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AsyncLazy<T>.SuppressRecursiveFactoryDetection property #1265

Merged
merged 1 commit into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions src/Microsoft.VisualStudio.Threading/AsyncLazy`1.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class AsyncLazy<T>
/// <summary>
/// The unique instance identifier.
/// </summary>
private readonly AsyncLocal<object> recursiveFactoryCheck = new AsyncLocal<object>();
private AsyncLocal<object>? recursiveFactoryCheck;

/// <summary>
/// The function to invoke to produce the task.
Expand Down Expand Up @@ -72,6 +72,31 @@ public AsyncLazy(Func<Task<T>> valueFactory, JoinableTaskFactory? joinableTaskFa
this.jobFactory = joinableTaskFactory;
}

/// <summary>
/// Gets a value indicating whether to suppress detection of a value factory depending on itself.
/// </summary>
/// <value>The default value is <see langword="false" />.</value>
/// <remarks>
/// <para>
/// A value factory that truly depends on itself (e.g. by calling <see cref="GetValueAsync()"/> on the same instance)
/// would deadlock, and by default this class will throw an exception if it detects such a condition.
/// However this detection relies on the .NET ExecutionContext, which can flow to "spin off" contexts that are not awaited
/// by the factory, and thus could legally await the result of the value factory without deadlocking.
/// </para>
/// <para>
/// When this flows improperly, it can cause <see cref="InvalidOperationException"/> to be thrown, but only when the value factory
/// has not already been completed, leading to a difficult to reproduce race condition.
/// Such a case can be resolved by calling <see cref="SuppressRelevance"/> around the non-awaited fork in <see cref="ExecutionContext" />,
/// or the entire instance can be configured to suppress this check by setting this property to <see langword="true"/>.
/// </para>
/// <para>
/// When this property is set to <see langword="true" />, the recursive factory check will not be performed,
/// but <see cref="SuppressRelevance"/> will still call into <see cref="JoinableTaskContext.SuppressRelevance"/>
/// if a <see cref="JoinableTaskFactory"/> was provided to the constructor.
/// </para>
/// </remarks>
public bool SuppressRecursiveFactoryDetection { get; init; }
AArnott marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Gets a value indicating whether the value factory has been invoked.
/// </summary>
Expand Down Expand Up @@ -137,7 +162,7 @@ public bool IsValueFactoryCompleted
/// <exception cref="ObjectDisposedException">Thrown after <see cref="DisposeValue"/> is called.</exception>
public Task<T> GetValueAsync(CancellationToken cancellationToken)
{
if (!((this.value is object && this.value.IsCompleted) || this.recursiveFactoryCheck.Value is null))
if (this.value is not { IsCompleted: true } && this.recursiveFactoryCheck is { Value: not null })
{
// PERF: we check the condition and *then* retrieve the string resource only on failure
// because the string retrieval has shown up as significant on ETL traces.
Expand Down Expand Up @@ -183,7 +208,12 @@ public Task<T> GetValueAsync(CancellationToken cancellationToken)
}
};

this.recursiveFactoryCheck.Value = RecursiveCheckSentinel;
if (!this.SuppressRecursiveFactoryDetection)
{
Assumes.Null(this.recursiveFactoryCheck);
this.recursiveFactoryCheck = new AsyncLocal<object>() { Value = RecursiveCheckSentinel };
}

AArnott marked this conversation as resolved.
Show resolved Hide resolved
try
{
if (this.jobFactory is object)
Expand All @@ -201,7 +231,10 @@ public Task<T> GetValueAsync(CancellationToken cancellationToken)
}
finally
{
this.recursiveFactoryCheck.Value = null;
if (this.recursiveFactoryCheck is not null)
{
this.recursiveFactoryCheck.Value = null;
}
}
}
}
Expand Down Expand Up @@ -451,7 +484,11 @@ internal RevertRelevance(AsyncLazy<T> owner)
Requires.NotNull(owner, nameof(owner));
this.owner = owner;

(this.oldCheckValue, owner.recursiveFactoryCheck.Value) = (owner.recursiveFactoryCheck.Value, null);
if (owner.recursiveFactoryCheck is not null)
{
(this.oldCheckValue, owner.recursiveFactoryCheck.Value) = (owner.recursiveFactoryCheck.Value, null);
}

this.joinableRelevance = owner.jobFactory?.Context.SuppressRelevance();
}

Expand All @@ -460,9 +497,9 @@ internal RevertRelevance(AsyncLazy<T> owner)
/// </summary>
public void Dispose()
{
if (this.owner is object)
if (this.owner?.recursiveFactoryCheck is { } check)
{
this.owner.recursiveFactoryCheck.Value = this.oldCheckValue;
check.Value = this.oldCheckValue;
}

this.joinableRelevance?.Dispose();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy<T>.DisposeValueAsync() -> System.Thre
Microsoft.VisualStudio.Threading.AsyncLazy<T>.IsValueDisposed.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance.Dispose() -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.init -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy<T>.DisposeValueAsync() -> System.Thre
Microsoft.VisualStudio.Threading.AsyncLazy<T>.IsValueDisposed.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance.Dispose() -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.init -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy<T>.DisposeValueAsync() -> System.Thre
Microsoft.VisualStudio.Threading.AsyncLazy<T>.IsValueDisposed.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance.Dispose() -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.init -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy<T>.DisposeValueAsync() -> System.Thre
Microsoft.VisualStudio.Threading.AsyncLazy<T>.IsValueDisposed.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance.Dispose() -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.init -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
102 changes: 102 additions & 0 deletions test/Microsoft.VisualStudio.Threading.Tests/AsyncLazyTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

using Microsoft;
using Microsoft.VisualStudio.Threading;

using Xunit;
using Xunit.Abstractions;

using NamedSyncContext = AwaitExtensionsTests.NamedSyncContext;

public class AsyncLazyTests : TestBase
Expand Down Expand Up @@ -748,6 +751,105 @@ async Task<int> FireAndForgetCodeAsync()
}
}

[Fact]
public async Task SuppressRecursiveFactoryDetection_WithoutJTF()
{
AsyncManualResetEvent allowValueFactoryToFinish = new();
Task<int>? fireAndForgetTask = null;
AsyncLazy<int> asyncLazy = null!;
asyncLazy = new AsyncLazy<int>(
async delegate
{
fireAndForgetTask = FireAndForgetCodeAsync();
await allowValueFactoryToFinish;
return 1;
},
null)
{
SuppressRecursiveFactoryDetection = true,
};

bool fireAndForgetCodeAsyncEntered = false;
Task<int> lazyValue = asyncLazy.GetValueAsync();
Assert.True(fireAndForgetCodeAsyncEntered);
allowValueFactoryToFinish.Set();

// Assert that the value factory was allowed to finish.
Assert.Equal(1, await lazyValue.WithCancellation(this.TimeoutToken));

// Assert that the fire-and-forget task was allowed to finish and did so without throwing.
Assert.Equal(1, await fireAndForgetTask!.WithCancellation(this.TimeoutToken));

async Task<int> FireAndForgetCodeAsync()
{
fireAndForgetCodeAsyncEntered = true;
return await asyncLazy.GetValueAsync();
}
}

[Theory, PairwiseData]
public async Task SuppressRecursiveFactoryDetection_WithJTF(bool suppressWithJTF)
{
JoinableTaskContext? context = this.InitializeJTCAndSC();
SingleThreadedTestSynchronizationContext.IFrame frame = SingleThreadedTestSynchronizationContext.NewFrame();

JoinableTaskFactory? jtf = context.Factory;
AsyncManualResetEvent allowValueFactoryToFinish = new();
Task<int>? fireAndForgetTask = null;
AsyncLazy<int> asyncLazy = null!;
asyncLazy = new AsyncLazy<int>(
async delegate
{
using (suppressWithJTF ? jtf.Context.SuppressRelevance() : default)
using (suppressWithJTF ? default : asyncLazy.SuppressRelevance())
{
fireAndForgetTask = FireAndForgetCodeAsync();
}

await allowValueFactoryToFinish;
return 1;
},
jtf)
{
SuppressRecursiveFactoryDetection = true,
};

bool fireAndForgetCodeAsyncEntered = false;
bool fireAndForgetCodeAsyncReachedUIThread = false;
jtf.Run(async delegate
{
Task<int> lazyValue = asyncLazy.GetValueAsync();
Assert.True(fireAndForgetCodeAsyncEntered);
await Task.Delay(AsyncDelay);
Assert.False(fireAndForgetCodeAsyncReachedUIThread);
allowValueFactoryToFinish.Set();

// Assert that the value factory was allowed to finish.
Assert.Equal(1, await lazyValue.WithCancellation(this.TimeoutToken));
});

// Run a main thread pump so the fire-and-forget task can finish.
SingleThreadedTestSynchronizationContext.PushFrame(SynchronizationContext.Current!, frame);

// Assert that the fire-and-forget task was allowed to finish and did so without throwing.
Assert.Equal(1, await fireAndForgetTask!.WithCancellation(this.TimeoutToken));

async Task<int> FireAndForgetCodeAsync()
{
fireAndForgetCodeAsyncEntered = true;

// Yield the caller's thread.
// Resuming will require the main thread, since the caller was on the main thread.
await Task.Yield();

fireAndForgetCodeAsyncReachedUIThread = true;

int result = await asyncLazy.GetValueAsync();
frame.Continue = false;
return result;
}
}

[Fact]
public async Task Dispose_ValueType_Completed()
{
Expand Down