Skip to content

Commit

Permalink
[stream] add API to stop receiving (fixes: #193)
Browse files Browse the repository at this point in the history
  • Loading branch information
jlaine committed Jul 19, 2021
1 parent 5e39294 commit 55b531a
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 28 deletions.
92 changes: 67 additions & 25 deletions src/aioquic/quic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@
PATH_CHALLENGE_FRAME_CAPACITY = 1 + 8
PATH_RESPONSE_FRAME_CAPACITY = 1 + 8
PING_FRAME_CAPACITY = 1
RESET_STREAM_CAPACITY = 1 + 3 * UINT_VAR_MAX_SIZE
RESET_STREAM_FRAME_CAPACITY = 1 + 3 * UINT_VAR_MAX_SIZE
RETIRE_CONNECTION_ID_CAPACITY = 1 + UINT_VAR_MAX_SIZE
STOP_SENDING_FRAME_CAPACITY = 1 + 2 * UINT_VAR_MAX_SIZE
STREAMS_BLOCKED_CAPACITY = 1 + UINT_VAR_MAX_SIZE
TRANSPORT_CLOSE_FRAME_CAPACITY = 1 + 3 * UINT_VAR_MAX_SIZE # + reason length

Expand Down Expand Up @@ -1055,6 +1056,24 @@ def send_stream_data(
stream = self._get_or_create_stream_for_send(stream_id)
stream.sender.write(data, end_stream=end_stream)

def stop_stream(self, stream_id: int, error_code: int) -> None:
"""
Request termination of the receiving part of a stream.
:param stream_id: The stream's ID.
:param error_code: An error code indicating why the stream is being stopped.
"""
if not self._stream_can_receive(stream_id):
raise ValueError(
"Cannot stop receiving on a local-initiated unidirectional stream"
)

stream = self._streams.get(stream_id, None)
if stream is None:
raise ValueError("Cannot stop receiving on an unknown stream")

stream.receiver.stop(error_code)

# Private

def _alpn_handler(self, alpn_protocol: str) -> None:
Expand Down Expand Up @@ -1204,16 +1223,15 @@ def _get_or_create_stream_for_send(self, stream_id: int) -> QuicStream:
This always occurs as a result of an API call.
"""
if stream_is_client_initiated(stream_id) != self._is_client:
if stream_id not in self._streams:
raise ValueError("Cannot send data on unknown peer-initiated stream")
if stream_is_unidirectional(stream_id):
raise ValueError(
"Cannot send data on peer-initiated unidirectional stream"
)
if not self._stream_can_send(stream_id):
raise ValueError("Cannot send data on peer-initiated unidirectional stream")

stream = self._streams.get(stream_id, None)
if stream is None:
# check initiator
if stream_is_client_initiated(stream_id) != self._is_client:
raise ValueError("Cannot send data on unknown peer-initiated stream")

# determine limits
if stream_is_unidirectional(stream_id):
max_stream_data_local = 0
Expand Down Expand Up @@ -1981,7 +1999,9 @@ def _handle_stop_sending_frame(
# check stream direction
self._assert_stream_can_send(frame_type, stream_id)

self._get_or_create_stream(frame_type, stream_id)
# reset the stream
stream = self._get_or_create_stream(frame_type, stream_id)
stream.sender.reset(error_code=QuicErrorCode.NO_ERROR)

def _handle_stream_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
Expand Down Expand Up @@ -2643,15 +2663,16 @@ def _write_application(
except QuicPacketBuilderStop:
break

# STREAM and RESET_STREAM
for stream in self._streams.values():
if stream.receiver.stop_pending:
# STOP_SENDING
self._write_stop_sending_frame(builder=builder, stream=stream)

if stream.sender.reset_pending:
self._write_reset_stream_frame(
builder=builder,
frame_type=QuicFrameType.RESET_STREAM,
stream=stream,
)
# RESET_STREAM
self._write_reset_stream_frame(builder=builder, stream=stream)
elif not stream.is_blocked and not stream.sender.buffer_is_empty:
# STREAM
self._remote_max_data_used += self._write_stream_frame(
builder=builder,
space=space,
Expand Down Expand Up @@ -2961,26 +2982,25 @@ def _write_ping_frame(
def _write_reset_stream_frame(
self,
builder: QuicPacketBuilder,
frame_type: QuicFrameType,
stream: QuicStream,
) -> None:
buf = builder.start_frame(
frame_type=frame_type,
capacity=RESET_STREAM_CAPACITY,
frame_type=QuicFrameType.RESET_STREAM,
capacity=RESET_STREAM_FRAME_CAPACITY,
handler=stream.sender.on_reset_delivery,
)
reset = stream.sender.get_reset_frame()
buf.push_uint_var(stream.stream_id)
buf.push_uint_var(reset.error_code)
buf.push_uint_var(reset.final_size)
frame = stream.sender.get_reset_frame()
buf.push_uint_var(frame.stream_id)
buf.push_uint_var(frame.error_code)
buf.push_uint_var(frame.final_size)

# log frame
if self._quic_logger is not None:
builder.quic_logger_frames.append(
self._quic_logger.encode_reset_stream_frame(
error_code=reset.error_code,
final_size=reset.final_size,
stream_id=stream.stream_id,
error_code=frame.error_code,
final_size=frame.final_size,
stream_id=frame.stream_id,
)
)

Expand All @@ -3001,6 +3021,28 @@ def _write_retire_connection_id_frame(
self._quic_logger.encode_retire_connection_id_frame(sequence_number)
)

def _write_stop_sending_frame(
self,
builder: QuicPacketBuilder,
stream: QuicStream,
) -> None:
buf = builder.start_frame(
frame_type=QuicFrameType.STOP_SENDING,
capacity=STOP_SENDING_FRAME_CAPACITY,
handler=stream.receiver.on_stop_sending_delivery,
)
frame = stream.receiver.get_stop_frame()
buf.push_uint_var(frame.stream_id)
buf.push_uint_var(frame.error_code)

# log frame
if self._quic_logger is not None:
builder.quic_logger_frames.append(
self._quic_logger.encode_stop_sending_frame(
error_code=frame.error_code, stream_id=frame.stream_id
)
)

def _write_stream_frame(
self,
builder: QuicPacketBuilder,
Expand Down
7 changes: 7 additions & 0 deletions src/aioquic/quic/packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,13 @@ class QuicFrameType(IntEnum):
class QuicResetStreamFrame:
error_code: int
final_size: int
stream_id: int


@dataclass
class QuicStopSendingFrame:
error_code: int
stream_id: int


@dataclass
Expand Down
52 changes: 50 additions & 2 deletions src/aioquic/quic/stream.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import Optional

from . import events
from .packet import QuicErrorCode, QuicResetStreamFrame, QuicStreamFrame
from .packet import (
QuicErrorCode,
QuicResetStreamFrame,
QuicStopSendingFrame,
QuicStreamFrame,
)
from .packet_builder import QuicDeliveryState
from .rangeset import RangeSet

Expand All @@ -11,15 +16,33 @@ class FinalSizeError(Exception):


class QuicStreamReceiver:
"""
The receive part of a QUIC stream.
It finishes:
- immediately for a send-only stream
- upon reception of a STREAM_RESET frame
- upon reception of a data frame with the FIN bit set
"""

def __init__(self, stream_id: Optional[int], readable: bool) -> None:
self.highest_offset = 0 # the highest offset ever seen
self.is_finished = False
self.stop_pending = False

self._buffer = bytearray()
self._buffer_start = 0 # the offset for the start of the buffer
self._final_size: Optional[int] = None
self._ranges = RangeSet()
self._stream_id = stream_id
self._stop_error_code: Optional[int] = None

def get_stop_frame(self) -> QuicStopSendingFrame:
self.stop_pending = False
return QuicStopSendingFrame(
error_code=self._stop_error_code,
stream_id=self._stream_id,
)

def handle_frame(
self, frame: QuicStreamFrame
Expand Down Expand Up @@ -96,6 +119,20 @@ def handle_reset(
self.is_finished = True
return events.StreamReset(error_code=error_code, stream_id=self._stream_id)

def on_stop_sending_delivery(self, delivery: QuicDeliveryState) -> None:
"""
Callback when a STOP_SENDING is ACK'd.
"""
if delivery != QuicDeliveryState.ACKED:
self.stop_pending = True

def stop(self, error_code: int = QuicErrorCode.NO_ERROR) -> None:
"""
Request the peer stop sending data on the QUIC stream.
"""
self._stop_error_code = error_code
self.stop_pending = True

def _pull_data(self) -> bytes:
"""
Remove data from the front of the buffer.
Expand All @@ -116,6 +153,15 @@ def _pull_data(self) -> bytes:


class QuicStreamSender:
"""
The send part of a QUIC stream.
It finishes:
- immediately for a receive-only stream
- upon acknowledgement of a STREAM_RESET frame
- upon acknowledgement of a data frame with the FIN bit set
"""

def __init__(self, stream_id: Optional[int], writable: bool) -> None:
self.buffer_is_empty = True
self.highest_offset = 0
Expand Down Expand Up @@ -193,7 +239,9 @@ def get_frame(
def get_reset_frame(self) -> QuicResetStreamFrame:
self.reset_pending = False
return QuicResetStreamFrame(
error_code=self._reset_error_code, final_size=self.highest_offset
error_code=self._reset_error_code,
final_size=self.highest_offset,
stream_id=self._stream_id,
)

def on_data_delivery(
Expand Down
46 changes: 45 additions & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,7 @@ def test_handle_new_connection_id_over_limit(self):

def test_handle_new_connection_id_with_retire_prior_to(self):
with client_and_server() as (client, server):
buf = Buffer(capacity=100)
buf = Buffer(capacity=42)
buf.push_uint_var(8) # sequence_number
buf.push_uint_var(2) # retire_prior_to
buf.push_uint_var(8)
Expand Down Expand Up @@ -2266,6 +2266,50 @@ def test_send_reset_stream(self):
client.reset_stream(0, QuicErrorCode.NO_ERROR)
self.assertEqual(roundtrip(client, server), (1, 1))

def test_send_stop_sending(self):
with client_and_server() as (client, server):
# check handshake completed
self.check_handshake(client=client, server=server)

# client creates bidirectional stream
client.send_stream_data(0, b"hello")
self.assertEqual(roundtrip(client, server), (1, 1))

# client sends STOP_SENDING frame
client.stop_stream(0, QuicErrorCode.NO_ERROR)
self.assertEqual(roundtrip(client, server), (1, 1))

# client receives STREAM_RESET frame
event = client.next_event()
self.assertEqual(type(event), events.StreamReset)
self.assertEqual(event.error_code, QuicErrorCode.NO_ERROR)
self.assertEqual(event.stream_id, 0)

def test_send_stop_sending_uni_stream(self):
with client_and_server() as (client, server):
# check handshake completed
self.check_handshake(client=client, server=server)

# client sends STOP_SENDING frame
with self.assertRaises(ValueError) as cm:
client.stop_stream(2, QuicErrorCode.NO_ERROR)
self.assertEqual(
str(cm.exception),
"Cannot stop receiving on a local-initiated unidirectional stream",
)

def test_send_stop_sending_unknown_stream(self):
with client_and_server() as (client, server):
# check handshake completed
self.check_handshake(client=client, server=server)

# client sends STOP_SENDING frame
with self.assertRaises(ValueError) as cm:
client.stop_stream(0, QuicErrorCode.NO_ERROR)
self.assertEqual(
str(cm.exception), "Cannot stop receiving on an unknown stream"
)

def test_send_stream_data_over_max_streams_bidi(self):
with client_and_server() as (client, server):
# create streams
Expand Down
41 changes: 41 additions & 0 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,47 @@ def test_receiver_reset_twice_final_size_error(self):
stream.receiver.handle_reset(final_size=5)
self.assertEqual(str(cm.exception), "Cannot change final size")

def test_receiver_stop(self):
stream = QuicStream()

# stop is requested
stream.receiver.stop(QuicErrorCode.NO_ERROR)
self.assertTrue(stream.receiver.stop_pending)

# stop is sent
frame = stream.receiver.get_stop_frame()
self.assertEqual(frame.error_code, QuicErrorCode.NO_ERROR)
self.assertFalse(stream.receiver.stop_pending)

# stop is acklowledged
stream.receiver.on_stop_sending_delivery(QuicDeliveryState.ACKED)
self.assertFalse(stream.receiver.stop_pending)

def test_receiver_stop_lost(self):
stream = QuicStream()

# stop is requested
stream.receiver.stop(QuicErrorCode.NO_ERROR)
self.assertTrue(stream.receiver.stop_pending)

# stop is sent
frame = stream.receiver.get_stop_frame()
self.assertEqual(frame.error_code, QuicErrorCode.NO_ERROR)
self.assertFalse(stream.receiver.stop_pending)

# stop is lost
stream.receiver.on_stop_sending_delivery(QuicDeliveryState.LOST)
self.assertTrue(stream.receiver.stop_pending)

# stop is sent again
frame = stream.receiver.get_stop_frame()
self.assertEqual(frame.error_code, QuicErrorCode.NO_ERROR)
self.assertFalse(stream.receiver.stop_pending)

# stop is acklowledged
stream.receiver.on_stop_sending_delivery(QuicDeliveryState.ACKED)
self.assertFalse(stream.receiver.stop_pending)

def test_sender_data(self):
stream = QuicStream()
self.assertEqual(stream.sender.next_offset, 0)
Expand Down

0 comments on commit 55b531a

Please sign in to comment.