Skip to content

Commit

Permalink
Add Task.WhenAny(task, task) overload (#34288)
Browse files Browse the repository at this point in the history
Currently internal and used as an implementation detail under Task.WhenAny(params Task[]) as well as from SemaphoreSlim.  Once API reviewed, it can be made public.
  • Loading branch information
stephentoub authored Apr 9, 2020
1 parent b964c14 commit f034c05
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ private async Task<bool> WaitUntilCountOrTimeoutAsync(TaskNode asyncWaiter, int
// cancel, and we chain the caller's supplied token into it.
using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
{
if (asyncWaiter == await TaskFactory.CommonCWAnyLogic(new Task[] { asyncWaiter, Task.Delay(millisecondsTimeout, cts.Token) }).ConfigureAwait(false))
if (asyncWaiter == await Task.WhenAny(asyncWaiter, Task.Delay(millisecondsTimeout, cts.Token)).ConfigureAwait(false))
{
cts.Cancel(); // ensure that the Task.Delay task is cleaned up
return true; // successfully acquired
Expand All @@ -731,7 +731,7 @@ private async Task<bool> WaitUntilCountOrTimeoutAsync(TaskNode asyncWaiter, int
var cancellationTask = new Task(null, TaskCreationOptions.RunContinuationsAsynchronously, promiseStyle: true);
using (cancellationToken.UnsafeRegister(s => ((Task)s!).TrySetResult(), cancellationTask))
{
if (asyncWaiter == await TaskFactory.CommonCWAnyLogic(new Task[] { asyncWaiter, cancellationTask }).ConfigureAwait(false))
if (asyncWaiter == await Task.WhenAny(asyncWaiter, cancellationTask).ConfigureAwait(false))
{
return true; // successfully acquired
}
Expand Down
143 changes: 133 additions & 10 deletions src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4295,14 +4295,7 @@ internal void ContinueWithCore(Task continuationTask,
// Adds a lightweight completion action to a task. This is similar to a continuation
// task except that it is stored as an action, and thus does not require the allocation/
// execution resources of a continuation task.
//
// Used internally by ContinueWhenAll() and ContinueWhenAny().
internal void AddCompletionAction(ITaskCompletionAction action)
{
AddCompletionAction(action, addBeforeOthers: false);
}

internal void AddCompletionAction(ITaskCompletionAction action, bool addBeforeOthers)
internal void AddCompletionAction(ITaskCompletionAction action, bool addBeforeOthers = false)
{
if (!AddTaskContinuation(action, addBeforeOthers))
action.Invoke(this); // run the action directly if we failed to queue the continuation (i.e., the task completed)
Expand Down Expand Up @@ -5956,7 +5949,16 @@ public void Invoke(Task ignored)
/// </exception>
public static Task<Task> WhenAny(params Task[] tasks)
{
if (tasks == null) ThrowHelper.ThrowArgumentNullException(ExceptionArgument.tasks);
if (tasks == null)
{
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.tasks);
}

if (tasks.Length == 2)
{
return WhenAny(tasks[0], tasks[1]);
}

if (tasks.Length == 0)
{
ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_EmptyTaskList, ExceptionArgument.tasks);
Expand All @@ -5977,6 +5979,104 @@ public static Task<Task> WhenAny(params Task[] tasks)
return TaskFactory.CommonCWAnyLogic(tasksCopy);
}

// TODO https://github.com/dotnet/runtime/issues/23021: Make this public.
/// <summary>Creates a task that will complete when either of the supplied tasks have completed.</summary>
/// <param name="task1">The first task to wait on for completion.</param>
/// <param name="task2">The second task to wait on for completion.</param>
/// <returns>A task that represents the completion of one of the supplied tasks. The return Task's Result is the task that completed.</returns>
/// <remarks>
/// The returned task will complete when any of the supplied tasks has completed. The returned task will always end in the RanToCompletion state
/// with its Result set to the first task to complete. This is true even if the first task to complete ended in the Canceled or Faulted state.
/// </remarks>
/// <exception cref="System.ArgumentNullException">
/// The <paramref name="task1"/> or <paramref name="task2"/> argument was null.
/// </exception>
internal static Task<Task> WhenAny(Task task1, Task task2) =>
(task1 is null) || (task2 is null) ? throw new ArgumentNullException(task1 is null ? nameof(task1) : nameof(task2)) :
task1.IsCompleted ? FromResult(task1) :
task2.IsCompleted ? FromResult(task2) :
new TwoTaskWhenAnyPromise<Task>(task1, task2);

/// <summary>A promise type used by WhenAny to wait on exactly two tasks.</summary>
/// <typeparam name="TTask">Specifies the type of the task.</typeparam>
/// <remarks>
/// This has essentially the same logic as <see cref="TaskFactory.CompleteOnInvokePromise"/>, but optimized
/// for two tasks rather than any number. Exactly two tasks has shown to be the most common use-case by far.
/// </remarks>
private sealed class TwoTaskWhenAnyPromise<TTask> : Task<TTask>, ITaskCompletionAction where TTask : Task
{
private TTask? _task1, _task2;

/// <summary>Instantiate the promise and register it with both tasks as a completion action.</summary>
public TwoTaskWhenAnyPromise(TTask task1, TTask task2)
{
Debug.Assert(task1 != null && task2 != null);
_task1 = task1;
_task2 = task2;

if (AsyncCausalityTracer.LoggingOn)
{
AsyncCausalityTracer.TraceOperationCreation(this, "Task.WhenAny");
}

if (s_asyncDebuggingEnabled)
{
AddToActiveTasks(this);
}

task1.AddCompletionAction(this);

task2.AddCompletionAction(this);
if (task1.IsCompleted)
{
// If task1 has already completed, Invoke may have tried to remove the continuation from
// each task before task2 added the continuation, in which case it's now referencing the
// already completed continuation. To deal with that race condition, explicitly check
// and remove the continuation here.
task2.RemoveContinuation(this);
}
}

/// <summary>Completes this task when one of the constituent tasks completes.</summary>
public void Invoke(Task completingTask)
{
Task? task1;
if ((task1 = Interlocked.Exchange(ref _task1, null)) != null)
{
Task? task2 = _task2;
_task2 = null;

Debug.Assert(task1 != null && task2 != null);
Debug.Assert(task1.IsCompleted || task2.IsCompleted);

if (AsyncCausalityTracer.LoggingOn)
{
AsyncCausalityTracer.TraceOperationRelation(this, CausalityRelation.Choice);
AsyncCausalityTracer.TraceOperationCompletion(this, AsyncCausalityStatus.Completed);
}

if (s_asyncDebuggingEnabled)
{
RemoveFromActiveTasks(this);
}

if (!task1.IsCompleted)
{
task1.RemoveContinuation(this);
}
else
{
task2.RemoveContinuation(this);
}

bool success = TrySetResult((TTask)completingTask);
Debug.Assert(success, "Only one task should have gotten to this point, and thus this must be successful.");
}
}

public bool InvokeMayRunArbitraryCode => true;
}

/// <summary>
/// Creates a task that will complete when any of the supplied tasks have completed.
/// </summary>
Expand Down Expand Up @@ -6035,14 +6135,37 @@ public static Task<Task<TResult>> WhenAny<TResult>(params Task<TResult>[] tasks)
// return (Task<Task<TResult>>) WhenAny( (Task[]) tasks);
// but classes are not covariant to enable casting Task<TResult> to Task<Task<TResult>>.

if (tasks != null && tasks.Length == 2)
{
return WhenAny(tasks[0], tasks[1]);
}

// Call WhenAny(Task[]) for basic functionality
Task<Task> intermediate = WhenAny((Task[])tasks);
Task<Task> intermediate = WhenAny((Task[])tasks!);

// Return a continuation task with the correct result type
return intermediate.ContinueWith(Task<TResult>.TaskWhenAnyCast.Value, default,
TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.DenyChildAttach, TaskScheduler.Default);
}

// TODO https://github.com/dotnet/runtime/issues/23021: Make this public.
/// <summary>Creates a task that will complete when either of the supplied tasks have completed.</summary>
/// <param name="task1">The first task to wait on for completion.</param>
/// <param name="task2">The second task to wait on for completion.</param>
/// <returns>A task that represents the completion of one of the supplied tasks. The return Task's Result is the task that completed.</returns>
/// <remarks>
/// The returned task will complete when any of the supplied tasks has completed. The returned task will always end in the RanToCompletion state
/// with its Result set to the first task to complete. This is true even if the first task to complete ended in the Canceled or Faulted state.
/// </remarks>
/// <exception cref="System.ArgumentNullException">
/// The <paramref name="task1"/> or <paramref name="task2"/> argument was null.
/// </exception>
internal static Task<Task<TResult>> WhenAny<TResult>(Task<TResult> task1, Task<TResult> task2) =>
(task1 is null) || (task2 is null) ? throw new ArgumentNullException(task1 is null ? nameof(task1) : nameof(task2)) :
task1.IsCompleted ? FromResult(task1) :
task2.IsCompleted ? FromResult(task2) :
new TwoTaskWhenAnyPromise<Task<TResult>>(task1, task2);

/// <summary>
/// Creates a task that will complete when any of the supplied tasks have completed.
/// </summary>
Expand Down

0 comments on commit f034c05

Please sign in to comment.