diff --git a/src/System.IO.Pipelines/src/System/IO/Pipelines/StreamPipeReader.cs b/src/System.IO.Pipelines/src/System/IO/Pipelines/StreamPipeReader.cs index 58e3e5c90f77..51ca4fbca2f1 100644 --- a/src/System.IO.Pipelines/src/System/IO/Pipelines/StreamPipeReader.cs +++ b/src/System.IO.Pipelines/src/System/IO/Pipelines/StreamPipeReader.cs @@ -4,12 +4,14 @@ using System.Buffers; using System.Diagnostics; +using System.Runtime.ExceptionServices; using System.Threading; using System.Threading.Tasks; +using System.Threading.Tasks.Sources; namespace System.IO.Pipelines { - internal class StreamPipeReader : PipeReader + internal class StreamPipeReader : PipeReader, IValueTaskSource { internal const int InitialSegmentPoolSize = 4; // 16K internal const int MaxSegmentPoolSize = 256; // 1MB @@ -18,16 +20,18 @@ internal class StreamPipeReader : PipeReader private readonly int _minimumReadThreshold; private readonly MemoryPool _pool; - private CancellationTokenSource _internalTokenSource; private bool _isReaderCompleted; private bool _isStreamCompleted; + private PipeAwaitable _awaitable; + private Task _streamReadTask = Task.CompletedTask; + private ExceptionDispatchInfo _edi; + private readonly CancellationTokenSource _completeCts = new CancellationTokenSource(); private BufferSegment _readHead; private int _readIndex; private BufferSegment _readTail; private long _bufferedBytes; - private bool _examinedEverything; private object _lock = new object(); // Mutable struct! Don't make this readonly @@ -53,6 +57,7 @@ public StreamPipeReader(Stream readingStream, StreamPipeReaderOptions options) _pool = options.Pool == MemoryPool.Shared ? null : options.Pool; _bufferSize = _pool == null ? options.BufferSize : Math.Min(options.BufferSize, _pool.MaxBufferSize); _leaveOpen = options.LeaveOpen; + _awaitable = new PipeAwaitable(completed: false, useSynchronizationContext: true); } /// @@ -66,21 +71,6 @@ public override void AdvanceTo(SequencePosition consumed) AdvanceTo(consumed, consumed); } - private CancellationTokenSource InternalTokenSource - { - get - { - lock (_lock) - { - if (_internalTokenSource == null) - { - _internalTokenSource = new CancellationTokenSource(); - } - return _internalTokenSource; - } - } - } - /// public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) { @@ -91,7 +81,7 @@ public override void AdvanceTo(SequencePosition consumed, SequencePosition exami private void AdvanceTo(BufferSegment consumedSegment, int consumedIndex, BufferSegment examinedSegment, int examinedIndex) { - if (consumedSegment == null || examinedSegment == null) + if (consumedSegment == null || examinedSegment == null || !_streamReadTask.IsCompleted) { return; } @@ -110,13 +100,13 @@ private void AdvanceTo(BufferSegment consumedSegment, int consumedIndex, BufferS Debug.Assert(_bufferedBytes >= 0); - _examinedEverything = false; + var examinedEverything = false; if (examinedSegment == _readTail) { // If we examined everything, we force ReadAsync to actually read from the underlying stream // instead of returning a ReadResult from TryRead. - _examinedEverything = examinedIndex == _readTail.End; + examinedEverything = examinedIndex == _readTail.End; } // Two cases here: @@ -154,12 +144,23 @@ private void AdvanceTo(BufferSegment consumedSegment, int consumedIndex, BufferS ReturnSegmentUnsynchronized(returnStart); returnStart = next; } + + if (examinedEverything) + { + _awaitable.SetUncompleted(); + } } /// public override void CancelPendingRead() { - InternalTokenSource.Cancel(); + CompletionData completionData; + lock (_lock) + { + _awaitable.Cancel(out completionData); + } + + DispatchCompletion(completionData); } /// @@ -172,6 +173,15 @@ public override void Complete(Exception exception = null) _isReaderCompleted = true; + // Make an attempt to cancel any call to Stream.ReadAsync + _completeCts.Cancel(); + + if (!_leaveOpen) + { + InnerStream.Dispose(); + } + + // Return the memory after potentially disposing the stream BufferSegment segment = _readHead; while (segment != null) { @@ -181,10 +191,7 @@ public override void Complete(Exception exception = null) returnSegment.ResetMemory(); } - if (!_leaveOpen) - { - InnerStream.Dispose(); - } + _completeCts.Dispose(); } /// @@ -193,39 +200,59 @@ public override void OnWriterCompleted(Action callback, objec } /// - public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) + public override ValueTask ReadAsync(CancellationToken cancellationToken = default) { // TODO ReadyAsync needs to throw if there are overlapping reads. ThrowIfCompleted(); - // PERF: store InternalTokenSource locally to avoid querying it twice (which acquires a lock) - CancellationTokenSource tokenSource = InternalTokenSource; - if (TryReadInternal(tokenSource, out ReadResult readResult)) + if (TryReadInternal(out ReadResult readResult)) { - return readResult; + return new ValueTask(readResult); } if (_isStreamCompleted) { - return new ReadResult(buffer: default, isCanceled: false, isCompleted: true); + return new ValueTask(new ReadResult(buffer: default, isCanceled: false, isCompleted: true)); + } + + if (_streamReadTask.IsCompleted) + { + _streamReadTask = ReadStreamAsync(cancellationToken); + + // Completed the stream read inline because it was synchronous and there was no exception thrown + if (_streamReadTask.IsCompleted && _edi == null) + { + return new ValueTask(GetReadResult()); + } } - var reg = new CancellationTokenRegistration(); + return new ValueTask(this, 0); + } + + private async Task ReadStreamAsync(CancellationToken cancellationToken) + { + CancellationTokenSource cts = null; + CancellationToken effectiveToken; + if (cancellationToken.CanBeCanceled) { - reg = cancellationToken.UnsafeRegister(state => ((StreamPipeReader)state).Cancel(), this); + cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _completeCts.Token); + effectiveToken = cts.Token; + } + else + { + effectiveToken = _completeCts.Token; } - using (reg) + using (cts) { - var isCanceled = false; try { AllocateReadTail(); Memory buffer = _readTail.AvailableMemory.Slice(_readTail.End); - int length = await InnerStream.ReadAsync(buffer, tokenSource.Token).ConfigureAwait(false); + int length = await InnerStream.ReadAsync(buffer, effectiveToken).ConfigureAwait(false); Debug.Assert(length + _readTail.End <= _readTail.AvailableMemory.Length); @@ -237,32 +264,20 @@ public override async ValueTask ReadAsync(CancellationToken cancella _isStreamCompleted = true; } } - catch (OperationCanceledException) + catch (Exception ex) { - ClearCancellationToken(); - - if (tokenSource.IsCancellationRequested && !cancellationToken.IsCancellationRequested) - { - // Catch cancellation and translate it into setting isCanceled = true - isCanceled = true; - } - else - { - throw; - } - + _edi = ExceptionDispatchInfo.Capture(ex); } - - return new ReadResult(GetCurrentReadOnlySequence(), isCanceled, _isStreamCompleted); } - } - private void ClearCancellationToken() - { + CompletionData completionData; + lock (_lock) { - _internalTokenSource = null; + _awaitable.Complete(out completionData); } + + DispatchCompletion(completionData); } private void ThrowIfCompleted() @@ -277,22 +292,14 @@ public override bool TryRead(out ReadResult result) { ThrowIfCompleted(); - return TryReadInternal(InternalTokenSource, out result); + return TryReadInternal(out result); } - private bool TryReadInternal(CancellationTokenSource source, out ReadResult result) + private bool TryReadInternal(out ReadResult result) { - bool isCancellationRequested = source.IsCancellationRequested; - if (isCancellationRequested || _bufferedBytes > 0 && (!_examinedEverything || _isStreamCompleted)) + if (_awaitable.IsCompleted || (_bufferedBytes > 0 && (_awaitable.IsCompleted || _isStreamCompleted))) { - if (isCancellationRequested) - { - ClearCancellationToken(); - } - - ReadOnlySequence buffer = _readHead == null ? default : GetCurrentReadOnlySequence(); - - result = new ReadResult(buffer, isCancellationRequested, _isStreamCompleted); + result = GetReadResult(); return true; } @@ -358,9 +365,59 @@ private void ReturnSegmentUnsynchronized(BufferSegment segment) } } - private void Cancel() + private ReadResult GetReadResult() + { + var isCancellationRequested = _awaitable.ObserveCancellation(); + + ReadOnlySequence buffer = _readHead == null ? default : GetCurrentReadOnlySequence(); + + return new ReadResult(buffer, isCancellationRequested, _isStreamCompleted); + } + + public ReadResult GetResult(short token) + { + ExceptionDispatchInfo edi = _edi; + _edi = null; + edi?.Throw(); + + return GetReadResult(); + } + + public ValueTaskSourceStatus GetStatus(short token) { - InternalTokenSource.Cancel(); + if (_awaitable.IsCompleted) + { + if (_edi != null) + { + return ValueTaskSourceStatus.Faulted; + } + + return ValueTaskSourceStatus.Succeeded; + } + return ValueTaskSourceStatus.Pending; + } + + public void OnCompleted(Action continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) + { + CompletionData completionData; + bool doubleCompletion; + + lock (_lock) + { + _awaitable.OnCompleted(continuation, state, flags, out completionData, out doubleCompletion); + } + + DispatchCompletion(completionData); + } + + private static void DispatchCompletion(in CompletionData completionData) + { + if (completionData.Completion is null) + { + return; + } + + PipeScheduler.ThreadPool.UnsafeSchedule(completionData.Completion, completionData.CompletionState); } } } diff --git a/src/System.IO.Pipelines/tests/StreamPipeReaderTests.cs b/src/System.IO.Pipelines/tests/StreamPipeReaderTests.cs index f60680c8fdaf..d1995017e73f 100644 --- a/src/System.IO.Pipelines/tests/StreamPipeReaderTests.cs +++ b/src/System.IO.Pipelines/tests/StreamPipeReaderTests.cs @@ -269,13 +269,16 @@ public async Task ReadCanBeCanceledViaCancelPendingReadWhenReadIsAsync() PipeReader reader = PipeReader.Create(stream); ValueTask task = reader.ReadAsync(); - reader.CancelPendingRead(); + ReadResult readResult = await task; + Assert.True(readResult.IsCanceled); + reader.AdvanceTo(readResult.Buffer.End); stream.WaitForReadTask.TrySetResult(null); - ReadResult readResult = await task; - Assert.True(readResult.IsCanceled); + readResult = await reader.ReadAsync(); + Assert.True(readResult.IsCompleted); + reader.Complete(); }