diff --git a/Source/SuperLinq.Async/Publish.cs b/Source/SuperLinq.Async/Publish.cs new file mode 100644 index 00000000..3ff43c0a --- /dev/null +++ b/Source/SuperLinq.Async/Publish.cs @@ -0,0 +1,260 @@ +using System.Collections; +using System.Runtime.ExceptionServices; + +namespace SuperLinq.Async; + +public static partial class AsyncSuperEnumerable +{ + /// + /// Creates a buffer with a view over the source sequence, causing each enumerator to obtain access to the remainder + /// of the sequence from the current index in the buffer. + /// + /// Source sequence element type. + /// Source sequence. + /// + /// Buffer enabling each enumerator to retrieve elements from the shared source sequence, starting from the index at + /// the point of obtaining the enumerator. + /// + /// is . + public static IAsyncBuffer Publish(this IAsyncEnumerable source) + { + Guard.IsNotNull(source); + + return new PublishBuffer(source); + } + + /// + /// Publishes the source sequence within a selector function where each enumerator can obtain a view over a tail of + /// the source sequence. + /// + /// Source sequence element type. + /// Result sequence element type. + /// Source sequence. + /// Selector function with published access to the source sequence for each + /// enumerator. + /// Sequence resulting from applying the selector function to the published view over the source + /// sequence. + /// or is . + public static IAsyncEnumerable Publish( + this IAsyncEnumerable source, + Func, IAsyncEnumerable> selector) + { + Guard.IsNotNull(source); + Guard.IsNotNull(selector); + + return Core(source, selector); + + static async IAsyncEnumerable Core( + IAsyncEnumerable source, + Func, IAsyncEnumerable> selector, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await using var buffer = source.Publish(); + await foreach (var i in selector(buffer).WithCancellation(cancellationToken).ConfigureAwait(false)) + yield return i; + } + } + + private sealed class PublishBuffer : IAsyncBuffer + { + private readonly SemaphoreSlim _lock = new(initialCount: 1); + + private IAsyncEnumerable? _source; + + private IAsyncEnumerator? _enumerator; + private List>? _buffers; + private bool _initialized; + private int _version; + + private ExceptionDispatchInfo? _exception; + private bool? _exceptionOnGetEnumerator; + + private bool _disposed; + + public PublishBuffer(IAsyncEnumerable source) + { + _source = source; + } + + public int Count => _buffers?.Count > 0 ? _buffers.Max(x => x.Count) : 0; + + public async ValueTask Reset(CancellationToken cancellationToken = default) + { + if (_disposed) + ThrowHelper.ThrowObjectDisposedException(nameof(IAsyncBuffer)); + + await _lock.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + if (_disposed) + ThrowHelper.ThrowObjectDisposedException(nameof(IAsyncBuffer)); + + _initialized = false; + _version++; + + _buffers = null; + + if (_enumerator != null) + await _enumerator.DisposeAsync(); + _enumerator = null; + _exception = null; + _exceptionOnGetEnumerator = null; + } + finally + { + _ = _lock.Release(); + } + } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + var (buffer, version) = InitializeEnumerator(cancellationToken); + return GetEnumeratorImpl(buffer, version, cancellationToken); + } + + private (Queue buffer, int version) InitializeEnumerator(CancellationToken cancellationToken) + { + if (_disposed) + ThrowHelper.ThrowObjectDisposedException(nameof(IAsyncBuffer)); + + _lock.Wait(cancellationToken); + try + { + if (_disposed) + ThrowHelper.ThrowObjectDisposedException(nameof(IAsyncBuffer)); + + Assert.NotNull(_source); + + if (_exceptionOnGetEnumerator == true) + { + Assert.NotNull(_exception); + _exception.Throw(); + } + + if (!_initialized) + { + try + { + _enumerator = _source.GetAsyncEnumerator(cancellationToken); + _buffers = new(); + _initialized = true; + } + catch (Exception ex) + { + _exception = ExceptionDispatchInfo.Capture(ex); + _exceptionOnGetEnumerator = true; + throw; + } + } + + Assert.NotNull(_buffers); + + var queue = new Queue(); + _buffers.Add(queue); + return (queue, _version); + } + finally + { + _ = _lock.Release(); + } + } + + private async IAsyncEnumerator GetEnumeratorImpl(Queue buffer, int version, CancellationToken cancellationToken) + { + try + { + while (true) + { + T? element; + + if (_disposed) + ThrowHelper.ThrowObjectDisposedException(nameof(IAsyncBuffer)); + + await _lock.WaitAsync(cancellationToken); + try + { + if (_disposed) + ThrowHelper.ThrowObjectDisposedException(nameof(IBuffer)); + if (!_initialized + || version != _version) + { + ThrowHelper.ThrowInvalidOperationException("Buffer reset during iteration."); + } + + if (buffer.Count == 0) + { + _exception?.Throw(); + + if (_enumerator == null) + break; + + var moved = false; + try + { + moved = await _enumerator.MoveNextAsync(cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + _exception = ExceptionDispatchInfo.Capture(ex); + _exceptionOnGetEnumerator = false; + await _enumerator.DisposeAsync().ConfigureAwait(false); + _enumerator = null; + throw; + } + + if (!moved) + { + await _enumerator.DisposeAsync().ConfigureAwait(false); + _enumerator = null; + break; + } + + Assert.NotNull(_buffers); + + var current = _enumerator.Current; + foreach (var q in _buffers) + q.Enqueue(current); + } + + element = buffer.Dequeue(); + } + finally + { + _ = _lock.Release(); + } + + yield return element; + } + } + finally + { + _ = _buffers?.Remove(buffer); + } + } + + public async ValueTask DisposeAsync() + { + if (_disposed) + return; + + await _lock.WaitAsync().ConfigureAwait(false); + try + { + _disposed = true; + + _buffers = null; + + if (_enumerator != null) + await _enumerator.DisposeAsync().ConfigureAwait(false); + _enumerator = null; + _source = null; + } + finally + { + _ = _lock.Release(); + _lock.Dispose(); + } + } + } +} diff --git a/Tests/SuperLinq.Async.Test/PublishTest.cs b/Tests/SuperLinq.Async.Test/PublishTest.cs new file mode 100644 index 00000000..067279bc --- /dev/null +++ b/Tests/SuperLinq.Async.Test/PublishTest.cs @@ -0,0 +1,307 @@ +using CommunityToolkit.Diagnostics; + +namespace Test.Async; + +public class PublishTest +{ + [Fact] + public void PublishIsLazy() + { + _ = new AsyncBreakingSequence().Publish(); + } + + [Fact] + public async Task PublishWithSingleConsumer() + { + await using var seq = Enumerable.Range(1, 10).AsTestingSequence(); + + await using var result = seq.Publish(); + await result.AssertSequenceEqual(Enumerable.Range(1, 10)); + } + + [Fact] + public async Task PublishWithMultipleConsumers() + { + await using var seq = Enumerable.Range(1, 10).AsTestingSequence(); + + await using var result = seq.Publish(); + + await using var r1 = result.Read(); + Assert.Equal(1, await r1.Read()); + Assert.Equal(2, await r1.Read()); + Assert.Equal(0, result.Count); + + await using var r2 = result.Read(); + Assert.Equal(3, await r1.Read()); + Assert.Equal(3, await r2.Read()); + Assert.Equal(0, result.Count); + + Assert.Equal(4, await r1.Read()); + Assert.Equal(5, await r1.Read()); + Assert.Equal(6, await r1.Read()); + Assert.Equal(7, await r1.Read()); + Assert.Equal(4, result.Count); + + Assert.Equal(4, await r2.Read()); + Assert.Equal(5, await r2.Read()); + Assert.Equal(2, result.Count); + + Assert.Equal(8, await r1.Read()); + Assert.Equal(9, await r1.Read()); + Assert.Equal(4, result.Count); + + await using var r3 = result.Read(); + Assert.Equal(10, await r3.Read()); + await r3.ReadEnd(); + Assert.Equal(5, result.Count); + + Assert.Equal(6, await r2.Read()); + Assert.Equal(7, await r2.Read()); + Assert.Equal(3, result.Count); + + Assert.Equal(10, await r1.Read()); + await r1.ReadEnd(); + + Assert.Equal(8, await r2.Read()); + Assert.Equal(9, await r2.Read()); + Assert.Equal(10, await r2.Read()); + await r2.ReadEnd(); + Assert.Equal(0, result.Count); + } + + [Fact] + public async Task PublishWithInnerConsumer() + { + await using var seq = Enumerable.Range(1, 6).AsTestingSequence(); + + await using var result = seq.Publish(); + + await using var r1 = result.Read(); + Assert.Equal(1, await r1.Read()); + Assert.Equal(2, await r1.Read()); + + await using (var r2 = result.Read()) + { + Assert.Equal(3, await r2.Read()); + Assert.Equal(4, await r2.Read()); + Assert.Equal(2, result.Count); + } + + Assert.Equal(3, await r1.Read()); + Assert.Equal(4, await r1.Read()); + Assert.Equal(5, await r1.Read()); + Assert.Equal(6, await r1.Read()); + + await r1.ReadEnd(); + Assert.Equal(0, result.Count); + } + + [Fact] + public async Task PublishWithSequentialPartialConsumers() + { + await using var seq = Enumerable.Range(1, 10).AsTestingSequence(); + + await using var result = seq.Publish(); + + await using (var r1 = result.Read()) + { + Assert.Equal(1, await r1.Read()); + Assert.Equal(2, await r1.Read()); + Assert.Equal(3, await r1.Read()); + Assert.Equal(4, await r1.Read()); + Assert.Equal(5, await r1.Read()); + Assert.Equal(0, result.Count); + } + + await using (var r2 = result.Read()) + { + Assert.Equal(6, await r2.Read()); + Assert.Equal(7, await r2.Read()); + Assert.Equal(8, await r2.Read()); + Assert.Equal(9, await r2.Read()); + Assert.Equal(10, await r2.Read()); + await r2.ReadEnd(); + Assert.Equal(0, result.Count); + } + + await using var r3 = result.Read(); + await r3.ReadEnd(); + Assert.Equal(0, result.Count); + } + + [Fact] + public async Task PublishDisposesAfterSourceIsIteratedEntirely() + { + await using var seq = Enumerable.Range(0, 10).AsTestingSequence(); + + await using var buffer = seq.Publish(); + await buffer.Consume(); + + Assert.True(seq.IsDisposed); + } + + [Fact] + public async Task PublishDisposesWithPartialEnumeration() + { + await using var seq = Enumerable.Range(0, 10).AsTestingSequence(); + + await using var buffer = seq.Publish(); + + await using (buffer) + await buffer.Take(5).Consume(); + + Assert.True(seq.IsDisposed); + } + + [Fact] + public async Task PublishRestartsAfterReset() + { + var starts = 0; + + IEnumerable TestSequence() + { + starts++; + yield return 1; + yield return 2; + } + + await using var seq = TestSequence().AsTestingSequence(maxEnumerations: 2); + await using var buffer = seq.Publish(); + + await buffer.Take(1).Consume(); + Assert.Equal(1, starts); + + await buffer.Reset(); + await buffer.Take(1).Consume(); + Assert.Equal(2, starts); + } + + [Fact] + public async Task PublishThrowsWhenCacheDisposedDuringIteration() + { + await using var seq = Enumerable.Range(0, 10).AsTestingSequence(); + await using var buffer = seq.Publish(); + + await using var reader = buffer.Read(); + + Assert.Equal(0, await reader.Read()); + await buffer.DisposeAsync(); + + _ = await Assert.ThrowsAsync( + async () => await reader.Read()); + } + + [Fact] + public async Task PublishThrowsWhenResetDuringIteration() + { + await using var seq = Enumerable.Range(0, 10).AsTestingSequence(); + await using var buffer = seq.Publish(); + + await using var reader = buffer.Read(); + + Assert.Equal(0, await reader.Read()); + await buffer.Reset(); + + var ex = await Assert.ThrowsAsync( + async () => await reader.Read()); + Assert.Equal("Buffer reset during iteration.", ex.Message); + } + + [Fact] + public async Task PublishThrowsWhenGettingIteratorAfterDispose() + { + await using var seq = Enumerable.Range(0, 10).AsTestingSequence(); + await using var buffer = seq.Publish(); + await buffer.Consume(); + await buffer.DisposeAsync(); + + _ = await Assert.ThrowsAsync( + async () => await buffer.Consume()); + } + + [Fact] + public async Task PublishThrowsWhenResettingAfterDispose() + { + await using var seq = Enumerable.Range(0, 10).AsTestingSequence(); + await using var buffer = seq.Publish(); + await buffer.Consume(); + await buffer.DisposeAsync(); + + _ = await Assert.ThrowsAsync( + async () => await buffer.Reset()); + } + + [Fact] + public async Task PublishRethrowsErrorDuringIterationToAllIteratorsUntilReset() + { + await using var xs = AsyncSeqExceptionAt(2).AsTestingSequence(maxEnumerations: 2); + + await using var buffer = xs.Publish(); + + await using (var r1 = buffer.Read()) + await using (var r2 = buffer.Read()) + { + Assert.Equal(1, await r1.Read()); + Assert.Equal(1, await r2.Read()); + + _ = await Assert.ThrowsAsync(async () => await r1.Read()); + _ = await Assert.ThrowsAsync(async () => await r2.Read()); + Guard.IsTrue(xs.IsDisposed); + } + + await using (var r3 = buffer.Read()) + _ = await Assert.ThrowsAsync(async () => await r3.Read()); + + await buffer.Reset(); + + await using var r4 = buffer.Read(); + Assert.Equal(1, await r4.Read()); + } + + [Fact] + public async Task PublishRethrowsErrorDuringFirstIterationStartToAllIterationsUntilReset() + { + await using var seq = new FailingEnumerable().AsTestingSequence(maxEnumerations: 2); + + await using var buffer = seq.Publish(); + + for (var i = 0; i < 2; i++) + { + _ = await Assert.ThrowsAsync( + async () => await buffer.FirstAsync()); + } + + await buffer.Reset(); + await buffer.AssertSequenceEqual(1); + } + + private class FailingEnumerable : IAsyncEnumerable + { + private bool _started; + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + if (!_started) + { + _started = true; + throw new TestException(); + } + return AsyncEnumerable.Range(1, 1).GetAsyncEnumerator(cancellationToken); + } + } + + [Fact] + public void PublishLambdaIsLazy() + { + _ = new AsyncBreakingSequence().Publish(BreakingFunc.Of, IAsyncEnumerable>()); + } + + [Fact] + public async Task PublishLambdaSimple() + { + await using var seq = Enumerable.Range(1, 10).AsTestingSequence(); + + var result = seq.Publish(xs => xs.Zip(xs, (l, r) => l + r).Take(4)); + await result.AssertSequenceEqual(2, 4, 6, 8); + } +}