From 54f0320b0fd8dc4b848d289bc0a91eebbfe60296 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jeremy=20Lain=C3=A9?= Date: Sun, 18 Jul 2021 23:38:25 +0200 Subject: [PATCH] =?UTF-8?q?[stream]=C2=A0add=20API=20to=20stop=20receiving?= =?UTF-8?q?=20(fixes:=20#193)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/aioquic/quic/connection.py | 92 +++++++++++++++++++++++++--------- src/aioquic/quic/packet.py | 7 +++ src/aioquic/quic/stream.py | 52 ++++++++++++++++++- tests/test_connection.py | 46 ++++++++++++++++- tests/test_stream.py | 16 ++++++ 5 files changed, 185 insertions(+), 28 deletions(-) diff --git a/src/aioquic/quic/connection.py b/src/aioquic/quic/connection.py index 3bb2dc97c..47955e030 100644 --- a/src/aioquic/quic/connection.py +++ b/src/aioquic/quic/connection.py @@ -95,8 +95,9 @@ PATH_CHALLENGE_FRAME_CAPACITY = 1 + 8 PATH_RESPONSE_FRAME_CAPACITY = 1 + 8 PING_FRAME_CAPACITY = 1 -RESET_STREAM_CAPACITY = 1 + 3 * UINT_VAR_MAX_SIZE +RESET_STREAM_FRAME_CAPACITY = 1 + 3 * UINT_VAR_MAX_SIZE RETIRE_CONNECTION_ID_CAPACITY = 1 + UINT_VAR_MAX_SIZE +STOP_SENDING_FRAME_CAPACITY = 1 + 2 * UINT_VAR_MAX_SIZE STREAMS_BLOCKED_CAPACITY = 1 + UINT_VAR_MAX_SIZE TRANSPORT_CLOSE_FRAME_CAPACITY = 1 + 3 * UINT_VAR_MAX_SIZE # + reason length @@ -1055,6 +1056,24 @@ def send_stream_data( stream = self._get_or_create_stream_for_send(stream_id) stream.sender.write(data, end_stream=end_stream) + def stop_stream(self, stream_id: int, error_code: int) -> None: + """ + Request termination of the receiving part of a stream. + + :param stream_id: The stream's ID. + :param error_code: An error code indicating why the stream is being stopped. + """ + if not self._stream_can_receive(stream_id): + raise ValueError( + "Cannot stop receiving on a local-initiated unidirectional stream" + ) + + stream = self._streams.get(stream_id, None) + if stream is None: + raise ValueError("Cannot stop receiving on an unknown stream") + + stream.receiver.stop(error_code) + # Private def _alpn_handler(self, alpn_protocol: str) -> None: @@ -1204,16 +1223,15 @@ def _get_or_create_stream_for_send(self, stream_id: int) -> QuicStream: This always occurs as a result of an API call. """ - if stream_is_client_initiated(stream_id) != self._is_client: - if stream_id not in self._streams: - raise ValueError("Cannot send data on unknown peer-initiated stream") - if stream_is_unidirectional(stream_id): - raise ValueError( - "Cannot send data on peer-initiated unidirectional stream" - ) + if not self._stream_can_send(stream_id): + raise ValueError("Cannot send data on peer-initiated unidirectional stream") stream = self._streams.get(stream_id, None) if stream is None: + # check initiator + if stream_is_client_initiated(stream_id) != self._is_client: + raise ValueError("Cannot send data on unknown peer-initiated stream") + # determine limits if stream_is_unidirectional(stream_id): max_stream_data_local = 0 @@ -1981,7 +1999,9 @@ def _handle_stop_sending_frame( # check stream direction self._assert_stream_can_send(frame_type, stream_id) - self._get_or_create_stream(frame_type, stream_id) + # reset the stream + stream = self._get_or_create_stream(frame_type, stream_id) + stream.sender.reset(error_code=QuicErrorCode.NO_ERROR) def _handle_stream_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer @@ -2643,15 +2663,16 @@ def _write_application( except QuicPacketBuilderStop: break - # STREAM and RESET_STREAM for stream in self._streams.values(): + if stream.receiver.stop_pending: + # STOP_SENDING + self._write_stop_sending_frame(builder=builder, stream=stream) + if stream.sender.reset_pending: - self._write_reset_stream_frame( - builder=builder, - frame_type=QuicFrameType.RESET_STREAM, - stream=stream, - ) + # RESET_STREAM + self._write_reset_stream_frame(builder=builder, stream=stream) elif not stream.is_blocked and not stream.sender.buffer_is_empty: + # STREAM self._remote_max_data_used += self._write_stream_frame( builder=builder, space=space, @@ -2961,26 +2982,25 @@ def _write_ping_frame( def _write_reset_stream_frame( self, builder: QuicPacketBuilder, - frame_type: QuicFrameType, stream: QuicStream, ) -> None: buf = builder.start_frame( - frame_type=frame_type, - capacity=RESET_STREAM_CAPACITY, + frame_type=QuicFrameType.RESET_STREAM, + capacity=RESET_STREAM_FRAME_CAPACITY, handler=stream.sender.on_reset_delivery, ) - reset = stream.sender.get_reset_frame() - buf.push_uint_var(stream.stream_id) - buf.push_uint_var(reset.error_code) - buf.push_uint_var(reset.final_size) + frame = stream.sender.get_reset_frame() + buf.push_uint_var(frame.stream_id) + buf.push_uint_var(frame.error_code) + buf.push_uint_var(frame.final_size) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_reset_stream_frame( - error_code=reset.error_code, - final_size=reset.final_size, - stream_id=stream.stream_id, + error_code=frame.error_code, + final_size=frame.final_size, + stream_id=frame.stream_id, ) ) @@ -3001,6 +3021,28 @@ def _write_retire_connection_id_frame( self._quic_logger.encode_retire_connection_id_frame(sequence_number) ) + def _write_stop_sending_frame( + self, + builder: QuicPacketBuilder, + stream: QuicStream, + ) -> None: + buf = builder.start_frame( + frame_type=QuicFrameType.STOP_SENDING, + capacity=STOP_SENDING_FRAME_CAPACITY, + handler=stream.receiver.on_stop_sending_delivery, + ) + frame = stream.receiver.get_stop_frame() + buf.push_uint_var(frame.stream_id) + buf.push_uint_var(frame.error_code) + + # log frame + if self._quic_logger is not None: + builder.quic_logger_frames.append( + self._quic_logger.encode_stop_sending_frame( + error_code=frame.error_code, stream_id=frame.stream_id + ) + ) + def _write_stream_frame( self, builder: QuicPacketBuilder, diff --git a/src/aioquic/quic/packet.py b/src/aioquic/quic/packet.py index 77d2c8b83..60c639b39 100644 --- a/src/aioquic/quic/packet.py +++ b/src/aioquic/quic/packet.py @@ -465,6 +465,13 @@ class QuicFrameType(IntEnum): class QuicResetStreamFrame: error_code: int final_size: int + stream_id: int + + +@dataclass +class QuicStopSendingFrame: + error_code: int + stream_id: int @dataclass diff --git a/src/aioquic/quic/stream.py b/src/aioquic/quic/stream.py index 433c34ca3..e8d2d1251 100644 --- a/src/aioquic/quic/stream.py +++ b/src/aioquic/quic/stream.py @@ -1,7 +1,12 @@ from typing import Optional from . import events -from .packet import QuicErrorCode, QuicResetStreamFrame, QuicStreamFrame +from .packet import ( + QuicErrorCode, + QuicResetStreamFrame, + QuicStopSendingFrame, + QuicStreamFrame, +) from .packet_builder import QuicDeliveryState from .rangeset import RangeSet @@ -11,15 +16,33 @@ class FinalSizeError(Exception): class QuicStreamReceiver: + """ + The receive part of a QUIC stream. + + It finishes: + - immediately for a send-only stream + - upon reception of a STREAM_RESET frame + - upon reception of a data frame with the FIN bit set + """ + def __init__(self, stream_id: Optional[int], readable: bool) -> None: self.highest_offset = 0 # the highest offset ever seen self.is_finished = False + self.stop_pending = False self._buffer = bytearray() self._buffer_start = 0 # the offset for the start of the buffer self._final_size: Optional[int] = None self._ranges = RangeSet() self._stream_id = stream_id + self._stop_error_code: Optional[int] = None + + def get_stop_frame(self) -> QuicStopSendingFrame: + self.stop_pending = False + return QuicStopSendingFrame( + error_code=self._stop_error_code, + stream_id=self._stream_id, + ) def handle_frame( self, frame: QuicStreamFrame @@ -96,6 +119,20 @@ def handle_reset( self.is_finished = True return events.StreamReset(error_code=error_code, stream_id=self._stream_id) + def on_stop_sending_delivery(self, delivery: QuicDeliveryState) -> None: + """ + Callback when a STOP_SENDING is ACK'd. + """ + if delivery != QuicDeliveryState.ACKED: + self.stop_pending = True + + def stop(self, error_code: int = QuicErrorCode.NO_ERROR) -> None: + """ + Request the peer stop sending data on the QUIC stream. + """ + self._stop_error_code = error_code + self.stop_pending = True + def _pull_data(self) -> bytes: """ Remove data from the front of the buffer. @@ -116,6 +153,15 @@ def _pull_data(self) -> bytes: class QuicStreamSender: + """ + The send part of a QUIC stream. + + It finishes: + - immediately for a receive-only stream + - upon acknowledgement of a STREAM_RESET frame + - upon acknowledgement of a data frame with the FIN bit set + """ + def __init__(self, stream_id: Optional[int], writable: bool) -> None: self.buffer_is_empty = True self.highest_offset = 0 @@ -193,7 +239,9 @@ def get_frame( def get_reset_frame(self) -> QuicResetStreamFrame: self.reset_pending = False return QuicResetStreamFrame( - error_code=self._reset_error_code, final_size=self.highest_offset + error_code=self._reset_error_code, + final_size=self.highest_offset, + stream_id=self._stream_id, ) def on_data_delivery( diff --git a/tests/test_connection.py b/tests/test_connection.py index 82fc59d4d..91e7e411e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1354,7 +1354,7 @@ def test_handle_new_connection_id_over_limit(self): def test_handle_new_connection_id_with_retire_prior_to(self): with client_and_server() as (client, server): - buf = Buffer(capacity=100) + buf = Buffer(capacity=42) buf.push_uint_var(8) # sequence_number buf.push_uint_var(2) # retire_prior_to buf.push_uint_var(8) @@ -2266,6 +2266,50 @@ def test_send_reset_stream(self): client.reset_stream(0, QuicErrorCode.NO_ERROR) self.assertEqual(roundtrip(client, server), (1, 1)) + def test_send_stop_sending(self): + with client_and_server() as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server) + + # client creates bidirectional stream + client.send_stream_data(0, b"hello") + self.assertEqual(roundtrip(client, server), (1, 1)) + + # client sends STOP_SENDING frame + client.stop_stream(0, QuicErrorCode.NO_ERROR) + self.assertEqual(roundtrip(client, server), (1, 1)) + + # client receives STREAM_RESET frame + event = client.next_event() + self.assertEqual(type(event), events.StreamReset) + self.assertEqual(event.error_code, QuicErrorCode.NO_ERROR) + self.assertEqual(event.stream_id, 0) + + def test_send_stop_sending_uni_stream(self): + with client_and_server() as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server) + + # client sends STOP_SENDING frame + with self.assertRaises(ValueError) as cm: + client.stop_stream(2, QuicErrorCode.NO_ERROR) + self.assertEqual( + str(cm.exception), + "Cannot stop receiving on a local-initiated unidirectional stream", + ) + + def test_send_stop_sending_unknown_stream(self): + with client_and_server() as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server) + + # client sends STOP_SENDING frame + with self.assertRaises(ValueError) as cm: + client.stop_stream(0, QuicErrorCode.NO_ERROR) + self.assertEqual( + str(cm.exception), "Cannot stop receiving on an unknown stream" + ) + def test_send_stream_data_over_max_streams_bidi(self): with client_and_server() as (client, server): # create streams diff --git a/tests/test_stream.py b/tests/test_stream.py index 64b045cd3..74364b38b 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -293,6 +293,22 @@ def test_receiver_reset_twice_final_size_error(self): stream.receiver.handle_reset(final_size=5) self.assertEqual(str(cm.exception), "Cannot change final size") + def test_receiver_stop(self): + stream = QuicStream() + + # stop is requested + stream.receiver.stop(QuicErrorCode.NO_ERROR) + self.assertTrue(stream.receiver.stop_pending) + + # stop is sent + frame = stream.receiver.get_stop_frame() + self.assertEqual(frame.error_code, QuicErrorCode.NO_ERROR) + self.assertFalse(stream.receiver.stop_pending) + + # stop is acklowledged + stream.receiver.on_stop_sending_delivery(QuicDeliveryState.ACKED) + self.assertFalse(stream.receiver.stop_pending) + def test_sender_data(self): stream = QuicStream() self.assertEqual(stream.sender.next_offset, 0)