From d6e36e4ecb0d004de54eb7bdf0ed6cdcb042e3c9 Mon Sep 17 00:00:00 2001 From: Max Kerr Date: Thu, 11 Apr 2019 11:23:45 -0700 Subject: [PATCH] HTTP/2 Request Cancellation (#35118) HTTP/2 cancellation support, plus improvements to outgoing write buffering. --- .../src/System.Net.Http.csproj | 1 + .../Http/SocketsHttpHandler/ArrayBuffer.cs | 2 + .../Http/SocketsHttpHandler/CreditManager.cs | 26 +- .../SocketsHttpHandler/Http2Connection.cs | 273 +++++++++++------- .../Http/SocketsHttpHandler/Http2Stream.cs | 39 ++- .../SocketsHttpHandler/HttpConnectionPool.cs | 23 -- .../TaskCompletionSourceWithCancellation.cs | 32 ++ .../HttpClientHandlerTest.Cancellation.cs | 40 +++ .../HttpClientHandlerTest.Http2.cs | 154 ++++++++++ .../FunctionalTests/HttpClientHandlerTest.cs | 29 -- .../FunctionalTests/SocketsHttpHandlerTest.cs | 7 + 11 files changed, 451 insertions(+), 175 deletions(-) create mode 100644 src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/TaskCompletionSourceWithCancellation.cs diff --git a/src/System.Net.Http/src/System.Net.Http.csproj b/src/System.Net.Http/src/System.Net.Http.csproj index ca8b8217fce0..12c44f1d31fb 100644 --- a/src/System.Net.Http/src/System.Net.Http.csproj +++ b/src/System.Net.Http/src/System.Net.Http.csproj @@ -157,6 +157,7 @@ + Common\CoreLib\System\Collections\Concurrent\ConcurrentQueueSegment.cs diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ArrayBuffer.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ArrayBuffer.cs index 7d82ab476c19..74eee0da03c3 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ArrayBuffer.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ArrayBuffer.cs @@ -59,6 +59,8 @@ public void Dispose() public Memory ActiveMemory => new Memory(_bytes, _activeStart, _availableStart - _activeStart); public Memory AvailableMemory => new Memory(_bytes, _availableStart, _bytes.Length - _availableStart); + public int Capacity => _bytes.Length; + public void Discard(int byteCount) { Debug.Assert(byteCount <= ActiveSpan.Length, $"Expected {byteCount} <= {ActiveSpan.Length}"); diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/CreditManager.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/CreditManager.cs index 6b5f9228b103..3fd0da14c8e7 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/CreditManager.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/CreditManager.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; namespace System.Net.Http @@ -13,7 +14,7 @@ internal sealed class CreditManager : IDisposable private struct Waiter { public int Amount; - public TaskCompletionSource TaskCompletionSource; + public TaskCompletionSourceWithCancellation TaskCompletionSource; } private int _current; @@ -37,7 +38,7 @@ private object SyncObject } } - public ValueTask RequestCreditAsync(int amount) + public ValueTask RequestCreditAsync(int amount, CancellationToken cancellationToken) { lock (SyncObject) { @@ -55,16 +56,21 @@ public ValueTask RequestCreditAsync(int amount) return new ValueTask(granted); } - var tcs = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); + // Uses RunContinuationsAsynchronously internally. + var tcs = new TaskCompletionSourceWithCancellation(); if (_waiters == null) { _waiters = new Queue(); } - _waiters.Enqueue(new Waiter { Amount = amount, TaskCompletionSource = tcs }); + Waiter waiter = new Waiter { Amount = amount, TaskCompletionSource = tcs }; - return new ValueTask(tcs.Task); + _waiters.Enqueue(waiter); + + return new ValueTask(cancellationToken.CanBeCanceled ? + tcs.WaitWithCancellationAsync(cancellationToken) : + tcs.Task); } } @@ -92,8 +98,12 @@ public void AdjustCredit(int amount) while (_current > 0 && _waiters.TryDequeue(out Waiter waiter)) { int granted = Math.Min(waiter.Amount, _current); - _current -= granted; - waiter.TaskCompletionSource.SetResult(granted); + + // Ensure that we grant credit only if the task has not been canceled. + if (waiter.TaskCompletionSource.TrySetResult(granted)) + { + _current -= granted; + } } } } @@ -114,7 +124,7 @@ public void Dispose() { while (_waiters.TryDequeue(out Waiter waiter)) { - waiter.TaskCompletionSource.SetException(new ObjectDisposedException(nameof(CreditManager))); + waiter.TaskCompletionSource.TrySetException(new ObjectDisposedException(nameof(CreditManager))); } } } diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index c19ac0378308..b8e46def68fb 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -31,6 +31,7 @@ internal sealed partial class Http2Connection : HttpConnectionBase, IDisposable private readonly Dictionary _httpStreams; private readonly SemaphoreSlim _writerLock; + private readonly SemaphoreSlim _headerSerializationLock; private readonly CreditManager _connectionWindow; private readonly CreditManager _concurrentStreams; @@ -41,9 +42,16 @@ internal sealed partial class Http2Connection : HttpConnectionBase, IDisposable private int _maxConcurrentStreams; private int _pendingWindowUpdate; private int _idleSinceTickCount; + private int _pendingWriters; private bool _disposed; + // If an in-progress write is canceled we need to be able to immediately + // report a cancellation to the user, but also block the connection until + // the write completes. We avoid actually canceling the write, as we would + // then have to close the whole connection. + private Task _inProgressWrite = null; + private const int MaxStreamId = int.MaxValue; private static readonly byte[] s_http2ConnectionPreface = Encoding.ASCII.GetBytes("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"); @@ -65,6 +73,11 @@ internal sealed partial class Http2Connection : HttpConnectionBase, IDisposable // rather than just increase the threshold. private const int ConnectionWindowThreshold = ConnectionWindowSize / 8; + // When buffering outgoing writes, we will automatically buffer up to this number of bytes. + // Single writes that are larger than the buffer can cause the buffer to expand beyond + // this value, so this is not a hard maximum size. + private const int UnflushedOutgoingBufferSize = 32 * 1024; + public Http2Connection(HttpConnectionPool pool, SslStream stream) { _pool = pool; @@ -78,6 +91,7 @@ public Http2Connection(HttpConnectionPool pool, SslStream stream) _httpStreams = new Dictionary(); _writerLock = new SemaphoreSlim(1, 1); + _headerSerializationLock = new SemaphoreSlim(1, 1); _connectionWindow = new CreditManager(DefaultInitialWindowSize); _concurrentStreams = new CreditManager(int.MaxValue); @@ -458,7 +472,7 @@ private void ProcessSettingsFrame(FrameHeader frameHeader) // Send acknowledgement // Don't wait for completion, which could happen asynchronously. - ValueTask ignored = SendSettingsAckAsync(); + Task ignored = SendSettingsAckAsync(); } } @@ -527,7 +541,7 @@ private void ProcessPingFrame(FrameHeader frameHeader) // Send PING ACK // Don't wait for completion, which could happen asynchronously. - ValueTask ignored = SendPingAckAsync(_incomingBuffer.ActiveMemory.Slice(0, FrameHeader.PingLength)); + Task ignored = SendPingAckAsync(_incomingBuffer.ActiveMemory.Slice(0, FrameHeader.PingLength)); _incomingBuffer.Discard(frameHeader.Length); } @@ -623,80 +637,116 @@ private void ProcessGoAwayFrame(FrameHeader frameHeader) _incomingBuffer.Discard(frameHeader.Length); } - private async ValueTask AcquireWriteLockAsync() + private async Task StartWriteAsync(int writeBytes, CancellationToken cancellationToken = default) { - await _writerLock.WaitAsync().ConfigureAwait(false); + await AcquireWriteLockAsync(cancellationToken).ConfigureAwait(false); - // If the connection has been aborted, then fail now instead of trying to send more data. - if (IsAborted()) + try { - throw new IOException(SR.net_http_invalid_response); - } - } + // If there is a pending write that was canceled while in progress, wait for it to complete. + if (_inProgressWrite != null) + { + await _inProgressWrite.ConfigureAwait(false); + _inProgressWrite = null; + } - private void ReleaseWriteLock() - { - // Currently, we always flush the write buffer before releasing the lock. - // If we change this in the future, we will need to revisit this assert. - Debug.Assert(_outgoingBuffer.ActiveMemory.IsEmpty); + int totalBufferLength = _outgoingBuffer.Capacity; + int activeBufferLength = _outgoingBuffer.ActiveSpan.Length; - _writerLock.Release(); + if (totalBufferLength >= UnflushedOutgoingBufferSize && + writeBytes >= totalBufferLength - activeBufferLength && + activeBufferLength > 0) + { + // If the buffer has already grown to 32k, does not have room for the next request, + // and is non-empty, flush the current contents to the wire. + await FlushOutgoingBytesAsync().ConfigureAwait(false); + } + + _outgoingBuffer.EnsureAvailableSpace(writeBytes); + } + catch + { + _writerLock.Release(); + throw; + } } - private async ValueTask SendSettingsAckAsync() + // This method handles flushing bytes to the wire. Writes here need to be atomic, so as to avoid + // killing the whole connection. Callers must hold the write lock, but can specify whether or not + // they want to release it. + private void FinishWrite(bool mustFlush) { - await AcquireWriteLockAsync().ConfigureAwait(false); + // We can't validate that we hold the semaphore, but we can at least validate that someone is + // holding it. + Debug.Assert(_writerLock.CurrentCount == 0); + try { - _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size); - WriteFrameHeader(new FrameHeader(0, FrameType.Settings, FrameFlags.Ack, 0)); + // We must flush if the caller requires it, or if there are no other pending writes. + if (mustFlush || _pendingWriters == 0) + { + Debug.Assert(_inProgressWrite == null); - await FlushOutgoingBytesAsync().ConfigureAwait(false); + _inProgressWrite = FlushOutgoingBytesAsync(); + } } finally { - ReleaseWriteLock(); + _writerLock.Release(); } } - private async ValueTask SendPingAckAsync(ReadOnlyMemory pingContent) + private async Task AcquireWriteLockAsync(CancellationToken cancellationToken) { - Debug.Assert(pingContent.Length == FrameHeader.PingLength); - - await AcquireWriteLockAsync().ConfigureAwait(false); - try + Task acquireLockTask = _writerLock.WaitAsync(cancellationToken); + if (!acquireLockTask.IsCompletedSuccessfully) { - _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size + FrameHeader.PingLength); - WriteFrameHeader(new FrameHeader(FrameHeader.PingLength, FrameType.Ping, FrameFlags.Ack, 0)); - pingContent.CopyTo(_outgoingBuffer.AvailableMemory); - _outgoingBuffer.Commit(FrameHeader.PingLength); - - await FlushOutgoingBytesAsync().ConfigureAwait(false); + Interlocked.Increment(ref _pendingWriters); + try + { + await acquireLockTask.ConfigureAwait(false); + } + finally + { + Interlocked.Decrement(ref _pendingWriters); + } } - finally + + // If the connection has been aborted, then fail now instead of trying to send more data. + if (IsAborted()) { - ReleaseWriteLock(); + throw new IOException(SR.net_http_invalid_response); } } - private async Task SendRstStreamAsync(int streamId, Http2ProtocolErrorCode errorCode) + private async Task SendSettingsAckAsync() { - await AcquireWriteLockAsync().ConfigureAwait(false); - try - { - _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size + FrameHeader.RstStreamLength); - WriteFrameHeader(new FrameHeader(FrameHeader.RstStreamLength, FrameType.RstStream, FrameFlags.None, streamId)); + await StartWriteAsync(FrameHeader.Size).ConfigureAwait(false); + WriteFrameHeader(new FrameHeader(0, FrameType.Settings, FrameFlags.Ack, 0)); - BinaryPrimitives.WriteInt32BigEndian(_outgoingBuffer.AvailableSpan, (int)errorCode); + FinishWrite(mustFlush: true); + } - _outgoingBuffer.Commit(FrameHeader.RstStreamLength); + private async Task SendPingAckAsync(ReadOnlyMemory pingContent) + { + Debug.Assert(pingContent.Length == FrameHeader.PingLength); - await FlushOutgoingBytesAsync().ConfigureAwait(false); - } - finally - { - ReleaseWriteLock(); - } + await StartWriteAsync(FrameHeader.Size + FrameHeader.PingLength).ConfigureAwait(false); + WriteFrameHeader(new FrameHeader(FrameHeader.PingLength, FrameType.Ping, FrameFlags.Ack, 0)); + pingContent.CopyTo(_outgoingBuffer.AvailableMemory); + _outgoingBuffer.Commit(FrameHeader.PingLength); + + FinishWrite(mustFlush: false); + } + + private async Task SendRstStreamAsync(int streamId, Http2ProtocolErrorCode errorCode) + { + await StartWriteAsync(FrameHeader.Size + FrameHeader.RstStreamLength).ConfigureAwait(false); + WriteFrameHeader(new FrameHeader(FrameHeader.RstStreamLength, FrameType.RstStream, FrameFlags.None, streamId)); + BinaryPrimitives.WriteInt32BigEndian(_outgoingBuffer.AvailableSpan, (int)errorCode); + _outgoingBuffer.Commit(FrameHeader.RstStreamLength); + + FinishWrite(mustFlush: true); } private static (ReadOnlyMemory first, ReadOnlyMemory rest) SplitBuffer(ReadOnlyMemory buffer, int maxSize) => @@ -902,28 +952,35 @@ private void WriteHeaders(HttpRequestMessage request) } } - private async ValueTask SendHeadersAsync(HttpRequestMessage request) + private async ValueTask SendHeadersAsync(HttpRequestMessage request, CancellationToken cancellationToken) { // Ensure we don't exceed the max concurrent streams setting. - await _concurrentStreams.RequestCreditAsync(1).ConfigureAwait(false); + await _concurrentStreams.RequestCreditAsync(1, cancellationToken).ConfigureAwait(false); - // Note, HEADERS and CONTINUATION frames must be together, so hold the writer lock across sending all of them. - // We also serialize usage of the header encoder and the header buffer this way. - // (If necessary, we could have a separate semaphore just for creating and encoding header blocks, - // and defer taking the actual _writerLock until we're ready to do the write below.) - await _writerLock.WaitAsync().ConfigureAwait(false); + // We serialize usage of the header encoder and the header buffer separately from the + // write lock + await _headerSerializationLock.WaitAsync(cancellationToken).ConfigureAwait(false); - Http2Stream http2Stream = AddStream(request); - int streamId = http2Stream.StreamId; + Http2Stream http2Stream = null; try { + http2Stream = AddStream(request); + int streamId = http2Stream.StreamId; + + http2Stream = AddStream(request); + streamId = http2Stream.StreamId; + // Generate the entire header block, without framing, into the connection header buffer. WriteHeaders(request); ReadOnlyMemory remaining = _headerBuffer.ActiveMemory; Debug.Assert(remaining.Length > 0); + // Calculate the total number of bytes we're going to use (content + headers). + int totalSize = remaining.Length + (remaining.Length / FrameHeader.MaxLength) * FrameHeader.Size + + (remaining.Length % FrameHeader.MaxLength == 0 ? FrameHeader.Size : 0); + // Split into frames and send. ReadOnlyMemory current; (current, remaining) = SplitBuffer(remaining, FrameHeader.MaxLength); @@ -932,42 +989,43 @@ private async ValueTask SendHeadersAsync(HttpRequestMessage request (remaining.Length == 0 ? FrameFlags.EndHeaders : FrameFlags.None) | (request.Content == null ? FrameFlags.EndStream : FrameFlags.None); - _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size + current.Length); + // Note, HEADERS and CONTINUATION frames must be together, so hold the writer lock across sending all of them. + await StartWriteAsync(totalSize).ConfigureAwait(false); + WriteFrameHeader(new FrameHeader(current.Length, FrameType.Headers, flags, streamId)); current.CopyTo(_outgoingBuffer.AvailableMemory); _outgoingBuffer.Commit(current.Length); - await FlushOutgoingBytesAsync().ConfigureAwait(false); - while (remaining.Length > 0) { (current, remaining) = SplitBuffer(remaining, FrameHeader.MaxLength); flags = (remaining.Length == 0 ? FrameFlags.EndHeaders : FrameFlags.None); - _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size + current.Length); WriteFrameHeader(new FrameHeader(current.Length, FrameType.Continuation, flags, streamId)); current.CopyTo(_outgoingBuffer.AvailableMemory); _outgoingBuffer.Commit(current.Length); - - await FlushOutgoingBytesAsync().ConfigureAwait(false); } + + // If this is not the end of the stream, we can put off flushing the buffer + // since we know that there are going to be data frames following. + FinishWrite(mustFlush: (flags & FrameFlags.EndStream) != 0); } catch { - http2Stream.Dispose(); + http2Stream?.Dispose(); throw; } finally { _headerBuffer.Discard(_headerBuffer.ActiveMemory.Length); - _writerLock.Release(); + _headerSerializationLock.Release(); } return http2Stream; } - private async ValueTask SendStreamDataAsync(int streamId, ReadOnlyMemory buffer) + private async Task SendStreamDataAsync(int streamId, ReadOnlyMemory buffer, CancellationToken cancellationToken) { ReadOnlyMemory remaining = buffer; @@ -975,63 +1033,53 @@ private async ValueTask SendStreamDataAsync(int streamId, ReadOnlyMemory b { int frameSize = Math.Min(remaining.Length, FrameHeader.MaxLength); - frameSize = await _connectionWindow.RequestCreditAsync(frameSize).ConfigureAwait(false); + // Once credit had been granted, we want to actually consume those bytes. + frameSize = await _connectionWindow.RequestCreditAsync(frameSize, cancellationToken).ConfigureAwait(false); ReadOnlyMemory current; (current, remaining) = SplitBuffer(remaining, frameSize); - await AcquireWriteLockAsync().ConfigureAwait(false); + // It's possible that a cancellation will occur while we wait for the write lock. In that case, we need to + // return the credit that we have acquired and don't plan to use. try { - _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size + current.Length); - WriteFrameHeader(new FrameHeader(current.Length, FrameType.Data, FrameFlags.None, streamId)); - current.CopyTo(_outgoingBuffer.AvailableMemory); - _outgoingBuffer.Commit(current.Length); - - await FlushOutgoingBytesAsync().ConfigureAwait(false); + await StartWriteAsync(FrameHeader.Size + current.Length, cancellationToken).ConfigureAwait(false); } - finally + catch (OperationCanceledException) { - ReleaseWriteLock(); + _connectionWindow.AdjustCredit(frameSize); + throw; } + + WriteFrameHeader(new FrameHeader(current.Length, FrameType.Data, FrameFlags.None, streamId)); + current.CopyTo(_outgoingBuffer.AvailableMemory); + _outgoingBuffer.Commit(current.Length); + + FinishWrite(mustFlush: false); } } - private async ValueTask SendEndStreamAsync(int streamId) + private async Task SendEndStreamAsync(int streamId) { - await AcquireWriteLockAsync().ConfigureAwait(false); - try - { - _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size); - WriteFrameHeader(new FrameHeader(0, FrameType.Data, FrameFlags.EndStream, streamId)); + await StartWriteAsync(FrameHeader.Size).ConfigureAwait(false); - await FlushOutgoingBytesAsync().ConfigureAwait(false); - } - finally - { - ReleaseWriteLock(); - } + WriteFrameHeader(new FrameHeader(0, FrameType.Data, FrameFlags.EndStream, streamId)); + + FinishWrite(mustFlush: true); } - private async ValueTask SendWindowUpdateAsync(int streamId, int amount) + private async Task SendWindowUpdateAsync(int streamId, int amount) { Debug.Assert(amount > 0); - await _writerLock.WaitAsync().ConfigureAwait(false); - try - { - _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size + FrameHeader.WindowUpdateLength); + // We update both the connection-level and stream-level windows at the same time + await StartWriteAsync(FrameHeader.Size + FrameHeader.WindowUpdateLength).ConfigureAwait(false); - WriteFrameHeader(new FrameHeader(FrameHeader.WindowUpdateLength, FrameType.WindowUpdate, FrameFlags.None, streamId)); - BinaryPrimitives.WriteInt32BigEndian(_outgoingBuffer.AvailableSpan, amount); - _outgoingBuffer.Commit(FrameHeader.WindowUpdateLength); + WriteFrameHeader(new FrameHeader(FrameHeader.WindowUpdateLength, FrameType.WindowUpdate, FrameFlags.None, streamId)); + BinaryPrimitives.WriteInt32BigEndian(_outgoingBuffer.AvailableSpan, amount); + _outgoingBuffer.Commit(FrameHeader.WindowUpdateLength); - await FlushOutgoingBytesAsync().ConfigureAwait(false); - } - finally - { - _writerLock.Release(); - } + FinishWrite(mustFlush: true); } private void ExtendWindow(int amount) @@ -1053,7 +1101,7 @@ private void ExtendWindow(int amount) _pendingWindowUpdate = 0; } - ValueTask ignored = SendWindowUpdateAsync(0, windowUpdateSize); + Task ignored = SendWindowUpdateAsync(0, windowUpdateSize); } private void WriteFrameHeader(FrameHeader frameHeader) @@ -1295,10 +1343,10 @@ public sealed override async Task SendAsync(HttpRequestMess try { // Send headers - http2Stream = await SendHeadersAsync(request).ConfigureAwait(false); + http2Stream = await SendHeadersAsync(request, cancellationToken).ConfigureAwait(false); // Send request body, if any - await http2Stream.SendRequestBodyAsync().ConfigureAwait(false); + await http2Stream.SendRequestBodyAsync(cancellationToken).ConfigureAwait(false); // Wait for response headers to be read. await http2Stream.ReadResponseHeadersAsync().ConfigureAwait(false); @@ -1320,6 +1368,21 @@ public sealed override async Task SendAsync(HttpRequestMess // ISSUE 31315: Determine if/how to expose HTTP2 error codes throw new HttpRequestException(SR.net_http_client_execution_error, e); } + else if (e is OperationCanceledException oce) + { + // If the operation has been canceled after the stream was allocated an ID, send a RST_STREAM. + if (http2Stream != null && http2Stream.StreamId != 0) + { + http2Stream.Cancel(); + } + + if (oce.CancellationToken == cancellationToken) + { + throw; + } + + throw new OperationCanceledException(cancellationToken); + } else { throw; diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs index a93acb9e5e2b..c4923814f331 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs @@ -88,7 +88,7 @@ public Http2Stream(HttpRequestMessage request, Http2Connection connection, int s public HttpRequestMessage Request => _request; public HttpResponseMessage Response => _response; - public async Task SendRequestBodyAsync() + public async Task SendRequestBodyAsync(CancellationToken cancellationToken) { // TODO: ISSUE 31312: Expect: 100-continue and early response handling // Note that in an "early response" scenario, where we get a response before we've finished sending the request body @@ -100,8 +100,11 @@ public async Task SendRequestBodyAsync() { using (Http2WriteStream writeStream = new Http2WriteStream(this)) { - await _request.Content.CopyToAsync(writeStream).ConfigureAwait(false); + await _request.Content.CopyToAsync(writeStream, null, cancellationToken).ConfigureAwait(false); } + + // Don't wait for completion, which could happen asynchronously. + Task ignored = _connection.SendEndStreamAsync(_streamId); } } @@ -364,7 +367,7 @@ private void ExtendWindow(int amount) int windowUpdateSize = _pendingWindowUpdate; _pendingWindowUpdate = 0; - ValueTask ignored = _connection.SendWindowUpdateAsync(_streamId, windowUpdateSize); + Task ignored = _connection.SendWindowUpdateAsync(_streamId, windowUpdateSize); } private (bool wait, int bytesRead) TryReadFromBuffer(Span buffer) @@ -429,18 +432,18 @@ public async ValueTask ReadDataAsync(Memory buffer, CancellationToken return bytesRead; } - private async ValueTask SendDataAsync(ReadOnlyMemory buffer) + private async ValueTask SendDataAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) { ReadOnlyMemory remaining = buffer; while (remaining.Length > 0) { - int sendSize = await _streamWindow.RequestCreditAsync(remaining.Length).ConfigureAwait(false); + int sendSize = await _streamWindow.RequestCreditAsync(remaining.Length, cancellationToken).ConfigureAwait(false); ReadOnlyMemory current; (current, remaining) = SplitBuffer(remaining, sendSize); - await _connection.SendStreamDataAsync(_streamId, current).ConfigureAwait(false); + await _connection.SendStreamDataAsync(_streamId, current, cancellationToken).ConfigureAwait(false); } } @@ -460,6 +463,25 @@ public void Dispose() } } + public void Cancel() + { + bool signalWaiter; + lock (SyncObject) + { + Task ignored = _connection.SendRstStreamAsync(_streamId, Http2ProtocolErrorCode.Cancel); + _state = StreamState.Aborted; + + signalWaiter = _hasWaiter; + _hasWaiter = false; + } + if (signalWaiter) + { + _waitSource.SetResult(true); + } + + _connection.RemoveStream(this); + } + // This object is itself usable as a backing source for ValueTask. Since there's only ever one awaiter // for this object's state transitions at a time, we allow the object to be awaited directly. All functionality // associated with the implementation is just delegated to the ManualResetValueTaskSourceCore. @@ -532,9 +554,6 @@ protected override void Dispose(bool disposing) return; } - // Don't wait for completion, which could happen asynchronously. - ValueTask ignored = http2Stream._connection.SendEndStreamAsync(http2Stream.StreamId); - base.Dispose(disposing); } @@ -551,7 +570,7 @@ public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationTo return new ValueTask(Task.FromException(new ObjectDisposedException(nameof(Http2WriteStream)))); } - return http2Stream.SendDataAsync(buffer); + return http2Stream.SendDataAsync(buffer, cancellationToken); } public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask; diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs index f25c0cde5ab6..3861882eeab5 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs @@ -1071,28 +1071,5 @@ public bool IsUsable( public override bool Equals(object obj) => obj is CachedConnection && Equals((CachedConnection)obj); public override int GetHashCode() => _connection?.GetHashCode() ?? 0; } - - private sealed class TaskCompletionSourceWithCancellation : TaskCompletionSource - { - private CancellationToken _cancellationToken; - - public TaskCompletionSourceWithCancellation() : base(TaskCreationOptions.RunContinuationsAsynchronously) - { - } - - private void OnCancellation() - { - TrySetCanceled(_cancellationToken); - } - - public async Task WaitWithCancellationAsync(CancellationToken cancellationToken) - { - _cancellationToken = cancellationToken; - using (cancellationToken.Register(s => ((TaskCompletionSourceWithCancellation)s).OnCancellation(), this)) - { - return await Task.ConfigureAwait(false); - } - } - } } } diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/TaskCompletionSourceWithCancellation.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/TaskCompletionSourceWithCancellation.cs new file mode 100644 index 000000000000..d462b163590a --- /dev/null +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/TaskCompletionSourceWithCancellation.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.Http +{ + internal sealed class TaskCompletionSourceWithCancellation : TaskCompletionSource + { + private CancellationToken _cancellationToken; + + public TaskCompletionSourceWithCancellation() : base(TaskCreationOptions.RunContinuationsAsynchronously) + { + } + + private void OnCancellation() + { + TrySetCanceled(_cancellationToken); + } + + public async Task WaitWithCancellationAsync(CancellationToken cancellationToken) + { + _cancellationToken = cancellationToken; + using (cancellationToken.UnsafeRegister(s => ((TaskCompletionSourceWithCancellation)s).OnCancellation(), this)) + { + return await Task.ConfigureAwait(false); + } + } + } +} diff --git a/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs b/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs index 9c6e64cebbe9..77d4cafb8b39 100644 --- a/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs +++ b/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs @@ -343,6 +343,46 @@ await LoopbackServer.CreateServerAsync(async (server, url) => } } + [Fact] + public async Task SendAsync_Cancel_CancellationTokenPropagates() + { + TaskCompletionSource clientCanceled = new TaskCompletionSource(); + await LoopbackServerFactory.CreateClientAndServerAsync( + async uri => + { + var cts = new CancellationTokenSource(); + cts.Cancel(); + + using (HttpClient client = CreateHttpClient()) + { + OperationCanceledException ex = null; + try + { + await client.GetAsync(uri, cts.Token); + } + catch(OperationCanceledException e) + { + ex = e; + } + Assert.True(ex != null, "Expected OperationCancelledException, but no exception was thrown."); + + Assert.True(cts.Token.IsCancellationRequested, "cts token IsCancellationRequested"); + + if (!PlatformDetection.IsFullFramework) + { + // .NET Framework has bug where it doesn't propagate token information. + Assert.True(ex.CancellationToken.IsCancellationRequested, "exception token IsCancellationRequested"); + } + clientCanceled.SetResult(true); + } + }, + async server => + { + Task serverTask = server.HandleRequestAsync(); + await clientCanceled.Task; + }); + } + private async Task ValidateClientCancellationAsync(Func clientBodyAsync) { var stopwatch = Stopwatch.StartNew(); diff --git a/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs b/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs index 3e6c56d87e77..a82e73fed883 100644 --- a/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs +++ b/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs @@ -2,7 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Diagnostics; using System.Net.Test.Common; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -1108,5 +1110,157 @@ public async Task Http2_MaxConcurrentStreams_LimitEnforced() Assert.Equal(HttpStatusCode.OK, response.StatusCode); } } + + [OuterLoop("Uses Task.Delay")] + [ConditionalFact(nameof(SupportsAlpn))] + public async Task Http2_WaitingForStream_Cancellation() + { + HttpClientHandler handler = CreateHttpClientHandler(); + handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; + + using (var server = Http2LoopbackServer.CreateServer()) + using (var client = new HttpClient(handler)) + { + Task sendTask = client.GetAsync(server.Address); + + await server.EstablishConnectionAsync(); + server.IgnoreWindowUpdates(); + + // Process first request and send response. + int streamId = await server.ReadRequestHeaderAsync(); + await server.SendDefaultResponseAsync(streamId); + + HttpResponseMessage response = await sendTask; + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Change MaxConcurrentStreams setting and wait for ack. + // (We don't want to send any new requests until we receive the ack, otherwise we may have a timing issue.) + SettingsFrame settingsFrame = new SettingsFrame(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 0 }); + await server.WriteFrameAsync(settingsFrame); + Frame settingsAckFrame = await server.ReadFrameAsync(TimeSpan.FromSeconds(30)); + Assert.Equal(FrameType.Settings, settingsAckFrame.Type); + Assert.Equal(FrameFlags.Ack, settingsAckFrame.Flags); + + // Issue a new request, so that we can cancel it while it waits for a stream. + var cts = new CancellationTokenSource(); + sendTask = client.GetAsync(server.Address, cts.Token); + + // Make sure that the request makes it to the point where it's waiting for a connection. + // It's possible that we'll still initiate a cancellation before it makes it to the queue, + // but it should still behave in the same way if so. + await Task.Delay(500); + + Stopwatch stopwatch = Stopwatch.StartNew(); + cts.Cancel(); + + await Assert.ThrowsAnyAsync(async () => await sendTask); + + // Ensure that the cancellation occurs promptly + stopwatch.Stop(); + Assert.True(stopwatch.ElapsedMilliseconds < 30000); + + // As the client has not allocated a stream ID when the corresponding request is cancelled, + // we do not send a RST stream frame. + } + } + + [ConditionalFact(nameof(SupportsAlpn))] + public async Task Http2_WaitingOnWindowCredit_Cancellation() + { + // The goal of this test is to get the client into the state where it has sent the headers, + // but is waiting on window credit before it will send the body. We then issue a cancellation + // to ensure the request is cancelled as expected. + const int InitialWindowSize = 65535; + const int ContentSize = InitialWindowSize + 1; + + HttpClientHandler handler = CreateHttpClientHandler(); + handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; + TestHelper.EnsureHttp2Feature(handler); + + var content = new ByteArrayContent(TestHelper.GenerateRandomContent(ContentSize)); + + using (var server = Http2LoopbackServer.CreateServer()) + using (var client = new HttpClient(handler)) + { + var cts = new CancellationTokenSource(); + Task clientTask = client.PostAsync(server.Address, content, cts.Token); + + await server.EstablishConnectionAsync(); + + Frame frame = await server.ReadFrameAsync(TimeSpan.FromSeconds(30)); + int streamId = frame.StreamId; + Assert.Equal(FrameType.Headers, frame.Type); + Assert.Equal(FrameFlags.EndHeaders, frame.Flags); + + // Receive up to initial window size + int bytesReceived = 0; + while (bytesReceived < InitialWindowSize) + { + frame = await server.ReadFrameAsync(TimeSpan.FromSeconds(30)); + Assert.Equal(streamId, frame.StreamId); + Assert.Equal(FrameType.Data, frame.Type); + Assert.Equal(FrameFlags.None, frame.Flags); + Assert.True(frame.Length > 0); + + bytesReceived += frame.Length; + } + + // The client is waiting for more credit in order to send the last byte of the + // request body. Test cancellation at this point. + Stopwatch stopwatch = Stopwatch.StartNew(); + + cts.Cancel(); + await Assert.ThrowsAnyAsync(async () => await clientTask); + + // Ensure that the cancellation occurs promptly + stopwatch.Stop(); + Assert.True(stopwatch.ElapsedMilliseconds < 30000); + + // The server should receive a RstStream frame. + frame = await server.ReadFrameAsync(TimeSpan.FromSeconds(30)); + Assert.Equal(FrameType.RstStream, frame.Type); + } + } + + [OuterLoop("Uses Task.Delay")] + [ConditionalFact(nameof(SupportsAlpn))] + public async Task Http2_PendingSend_Cancellation() + { + // The goal of this test is to get the client into the state where it is sending content, + // but the send pends because the TCP window is full. + const int InitialWindowSize = 65535; + const int ContentSize = InitialWindowSize * 2; // Double the default TCP window size. + + HttpClientHandler handler = CreateHttpClientHandler(); + handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; + TestHelper.EnsureHttp2Feature(handler); + + var content = new ByteArrayContent(TestHelper.GenerateRandomContent(ContentSize)); + + using (var server = Http2LoopbackServer.CreateServer()) + using (var client = new HttpClient(handler)) + { + var cts = new CancellationTokenSource(); + + Task clientTask = client.PostAsync(server.Address, content, cts.Token); + + await server.EstablishConnectionAsync(); + + Frame frame = await server.ReadFrameAsync(TimeSpan.FromSeconds(30)); + int streamId = frame.StreamId; + Assert.Equal(FrameType.Headers, frame.Type); + Assert.Equal(FrameFlags.EndHeaders, frame.Flags); + + // Increase the size of the HTTP/2 Window, so that it is large enough to fill the + // TCP window when we do not perform any reads on the server side. + await server.WriteFrameAsync(new WindowUpdateFrame(InitialWindowSize, streamId)); + + // Give the client time to read the window update frame, and for the write to pend. + await Task.Delay(1000); + cts.Cancel(); + + await Assert.ThrowsAnyAsync(async () => await clientTask); + } + } } } diff --git a/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.cs b/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.cs index 63a68daa3b00..4e51c7346a62 100644 --- a/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.cs +++ b/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.cs @@ -367,35 +367,6 @@ public async Task GetAsync_ResponseContentAfterClientAndHandlerDispose_Success() } } - [OuterLoop("Uses external server")] - [Fact] - public async Task SendAsync_Cancel_CancellationTokenPropagates() - { - var cts = new CancellationTokenSource(); - cts.Cancel(); - using (HttpClient client = CreateHttpClient()) - { - var request = new HttpRequestMessage(HttpMethod.Post, Configuration.Http.RemoteEchoServer); - Task t = client.SendAsync(request, cts.Token); - OperationCanceledException ex; - if (PlatformDetection.IsUap) - { - ex = await Assert.ThrowsAsync(() => t); - } - else - { - ex = await Assert.ThrowsAsync(() => t); - } - - Assert.True(cts.Token.IsCancellationRequested, "cts token IsCancellationRequested"); - if (!PlatformDetection.IsFullFramework) - { - // .NET Framework has bug where it doesn't propagate token information. - Assert.True(ex.CancellationToken.IsCancellationRequested, "exception token IsCancellationRequested"); - } - } - } - [SkipOnTargetFramework(TargetFrameworkMonikers.Uap, "UAP HTTP stack doesn't support .Proxy property")] [Theory] [InlineData("[::1234]")] diff --git a/src/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs b/src/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs index c11693c2fa97..9a963d304b22 100644 --- a/src/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs +++ b/src/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs @@ -1665,4 +1665,11 @@ public SocketsHttpHandlerTest_HttpClientHandlerTest_Http2(ITestOutputHelper outp protected override bool UseSocketsHttpHandler => true; protected override bool UseHttp2LoopbackServer => true; } + + [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))] + public sealed class SocketsHttpHandler_HttpClientHandler_Cancellation_Test_Http2 : HttpClientHandler_Cancellation_Test + { + protected override bool UseSocketsHttpHandler => true; + protected override bool UseHttp2LoopbackServer => true; + } }