diff --git a/Source/SuperLinq.Async/Amb.cs b/Source/SuperLinq.Async/Amb.cs new file mode 100644 index 00000000..793ad653 --- /dev/null +++ b/Source/SuperLinq.Async/Amb.cs @@ -0,0 +1,190 @@ +using System.Runtime.ExceptionServices; + +namespace SuperLinq.Async; + +public static partial class AsyncSuperEnumerable +{ + /// + /// Propagates the async-enumerable sequence that reacts first. + /// + /// The type of the elements of the source sequences + /// The first sequence to merge together + /// The other sequences to merge together + /// An async-enumerable sequence that surfaces whichever sequence returned first. + /// , , or any of + /// the items in is . + /// + /// + /// The implementation of this method is deeply unfair with regards to the ordering of the input sequences. The + /// sequences are initialized in the order in which they are received. This means that earlier sequences will have + /// an opportunity to finish sooner, meaning that all other things being equal, the earlier a sequence is (where + /// precedes any sequence in ), the more likely it will be + /// chosen by this operator. Additionally, the first sequence to return the first element of the sequence + /// synchronously will be chosen. + /// + /// + public static IAsyncEnumerable Amb( + this IAsyncEnumerable source, + params IAsyncEnumerable[] otherSources) + { + Guard.IsNotNull(source); + Guard.IsNotNull(otherSources); + + foreach (var s in otherSources) + Guard.IsNotNull(s, nameof(otherSources)); + + return Amb(otherSources.Prepend(source)); + } + + /// + /// Propagates the async-enumerable sequence that reacts first. + /// + /// The type of the elements of the source sequences + /// The sequence of sequences to merge together + /// A sequence of every element from all source sequences, returned in an order based on how long it takes + /// to iterate each element. + /// or any of the items in is . + /// + /// + /// The implementation of this method is deeply unfair with regards to the ordering of the . The sequences in are initialized in the order in which they are + /// received. This means that earlier sequences will have an opportunity to finish sooner, meaning that all other + /// things being equal, the earlier a sequence is in , the more likely it will be chosen + /// by this operator. Additionally, the first sequence to return the first element of the sequence synchronously + /// will be chosen. + /// + /// + public static IAsyncEnumerable Amb( + this IEnumerable> sources) + { + Guard.IsNotNull(sources); + + return Core(sources); + + static async IAsyncEnumerable Core( + IEnumerable> sources, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var cancellationSources = new List(); + var enumerators = new List>(); + var tasks = new List>(); + + IAsyncEnumerator? e = default; + CancellationTokenSource? eCts = default; + try + { + foreach (var s in sources) + { +#pragma warning disable CA2000 // Dispose objects before losing scope + // these will be disposed later + var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); +#pragma warning restore CA2000 // Dispose objects before losing scope + + var iter = s.GetAsyncEnumerator(cts.Token); + + var firstMove = iter.MoveNextAsync(); + if (firstMove.IsCompleted) + { + // if the sequence returned the first element synchronously, then it is obviously "first", so + // choose it and forget the rest. we do not add this iter to the lists since those only track + // the items that need to be canceled and disposed. + e = iter; + + // if the selected sequence is empty, then the amb sequence is empty as well. + if (!firstMove.Result) + yield break; + + break; + } + + // async; add it to the list + cancellationSources.Add(cts); + enumerators.Add(iter); + tasks.Add(firstMove.AsTask()); + } + + if (e == null) + { + // who finishes first? + var t = await Task.WhenAny(tasks).ConfigureAwait(false); + var moveNext = await t.ConfigureAwait(false); + + // since we built all three lists simultaneously, we can access the same index of each. + // we need the enumerator (to continue to enumerate it) and the cts (to dispose it at the end) + var idx = tasks.IndexOf(t); + e = enumerators[idx]; + eCts = cancellationSources[idx]; + + // remove the selected item from the list of still-running iterators + cancellationSources.RemoveAt(idx); + enumerators.RemoveAt(idx); + tasks.RemoveAt(idx); + + // if the selected sequence is empty, then the amb sequence is empty as well. + if (!moveNext) + yield break; + } + } + finally + { + // give each still-running task a chance to bail early + foreach (var cts in cancellationSources) + cts.Cancel(); + +#pragma warning disable CA1031 // Do not catch general exception types + ExceptionDispatchInfo? edi = null; + try + { + _ = await Task.WhenAll(tasks).ConfigureAwait(false); + } + // because we canceled the cts, we might get OperationCanceledException; we don't actually care about + // these because we're intentionally cancelling them. + catch (Exception ex) when ( + ex is OperationCanceledException + || (ex is AggregateException ae && ae.InnerExceptions.All(e => e is OperationCanceledException))) + { } + // if we're in the normal path, then e != null; in this case, we need to report any exceptions that we + // encounter. + catch (Exception ex) when (e != null) + { + edi = ExceptionDispatchInfo.Capture(ex); + } + // on the other hand, if e == null, then we silently ignore any exceptions, so that the original + // exception can propagate normally. this matches the behavior of await Task.WhenAll which only throws + // the first exception it encounters. + catch { } + + foreach (var en in enumerators) + { + try + { + await en.DisposeAsync().ConfigureAwait(false); + } + // don't worry about any exceptions while disposing - theoretically these should be fast and + // error-free, but just in case... + catch { } + } +#pragma warning restore CA1031 // Do not catch general exception types + + edi?.Throw(); + + // properly dispose of the sources + foreach (var cts in cancellationSources) + cts.Dispose(); + } + + try + { + yield return e.Current; + while (await e.MoveNextAsync().ConfigureAwait(false)) + yield return e.Current; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + eCts?.Dispose(); + } + } + } +} diff --git a/Tests/SuperLinq.Async.Test/AmbTest.cs b/Tests/SuperLinq.Async.Test/AmbTest.cs new file mode 100644 index 00000000..b2db10d5 --- /dev/null +++ b/Tests/SuperLinq.Async.Test/AmbTest.cs @@ -0,0 +1,64 @@ +namespace Test.Async; + +public class AmbTest +{ + [Fact] + public void AmbIsLazy() + { + _ = new AsyncBreakingSequence().Amb(new AsyncBreakingSequence()); + _ = AsyncSuperEnumerable.Amb(new AsyncBreakingSequence(), new AsyncBreakingSequence()); + } + + [Theory] + [InlineData(1)] + [InlineData(2)] + [InlineData(3)] + public async Task AmbSyncReturnsFirst(int sequenceNumber) + { + var sync = AsyncEnumerable.Range(1, 5); + var async = AsyncEnumerable.Range(6, 5) + .SelectAwaitWithCancellation(async (i, ct) => + { + await Task.Delay(10, ct); + return i; + }); + await using var seq1 = (sequenceNumber == 1 ? sync : async).AsTestingSequence(); + await using var seq2 = (sequenceNumber == 2 ? sync : async).AsTestingSequence(); + await using var seq3 = (sequenceNumber == 3 ? sync : async).AsTestingSequence(); + + var ts = new[] { seq1, seq2, seq3, }; + + var result = ts.Amb(); + + await result.AssertSequenceEqual(Enumerable.Range(1, 5)); + } + + [Theory] + [InlineData(1)] + [InlineData(2)] + [InlineData(3)] + public async Task AmbAsyncShortestComesFirst(int sequenceNumber) + { + var shorter = AsyncEnumerable.Range(1, 5) + .SelectAwaitWithCancellation(async (i, ct) => + { + await Task.Delay(10, ct); + return i; + }); + var longer = AsyncEnumerable.Range(6, 5) + .SelectAwaitWithCancellation(async (i, ct) => + { + await Task.Delay(30, ct); + return i; + }); + await using var seq1 = (sequenceNumber == 1 ? shorter : longer).AsTestingSequence(); + await using var seq2 = (sequenceNumber == 2 ? shorter : longer).AsTestingSequence(); + await using var seq3 = (sequenceNumber == 3 ? shorter : longer).AsTestingSequence(); + + var ts = new[] { seq1, seq2, seq3, }; + + var result = ts.Amb(); + + await result.AssertSequenceEqual(Enumerable.Range(1, 5)); + } +}