From 303483412dc5b420d09c1421792b3f8b99c323e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 18:30:36 +0100 Subject: [PATCH] Support recv() after the connection is closed. Fix #1538. --- docs/project/changelog.rst | 7 ++++ src/websockets/asyncio/messages.py | 20 ++++-------- src/websockets/sync/messages.py | 27 ++++++++-------- tests/asyncio/test_connection.py | 8 ++--- tests/asyncio/test_messages.py | 52 ++++++++++++++++++++++++++++++ tests/sync/test_connection.py | 19 ++++------- tests/sync/test_messages.py | 52 ++++++++++++++++++++++++++++++ 7 files changed, 142 insertions(+), 43 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 9e6a9d11..074a81c8 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -34,6 +34,13 @@ notice. .. _14.0: +Bug fixes +......... + +* Once the connection is closed, messages previously received and buffered can + be read in the :mod:`asyncio` and :mod:`threading` implementations, just like + in the legacy implementation. + 14.0 ---- diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 814a3c03..14ea7bf9 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -40,9 +40,11 @@ def put(self, item: T) -> None: if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_result(None) - async def get(self) -> T: + async def get(self, block: bool = True) -> T: """Remove and return an item from the queue, waiting if necessary.""" if not self.queue: + if not block: + raise EOFError("stream of frames ended") assert self.get_waiter is None, "cannot call get() concurrently" self.get_waiter = self.loop.create_future() try: @@ -133,12 +135,8 @@ async def get(self, decode: bool | None = None) -> Data: :meth:`get_iter` concurrently. """ - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution @@ -146,7 +144,7 @@ async def get(self, decode: bool | None = None) -> Data: try: # First frame - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) self.maybe_resume() assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY if decode is None: @@ -156,7 +154,7 @@ async def get(self, decode: bool | None = None) -> Data: # Following frames, for fragmented messages while not frame.fin: try: - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) except asyncio.CancelledError: # Put frames already received back into the queue # so that future calls to get() can return them. @@ -200,12 +198,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: :meth:`get_iter` concurrently. """ - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution @@ -216,7 +210,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # First frame try: - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) except asyncio.CancelledError: self.get_in_progress = False raise @@ -236,7 +230,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # previous fragments — we're streaming them. Canceling get_iter() # here will leave the assembler in a stuck state. Future calls to # get() or get_iter() will raise ConcurrencyError. - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) self.maybe_resume() assert frame.opcode is OP_CONT if decode: diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index ce08172b..af8635f1 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -69,10 +69,16 @@ def __init__( def get_next_frame(self, timeout: float | None = None) -> Frame: # Helper to factor out the logic for getting the next frame from the # queue, while handling timeouts and reaching the end of the stream. - try: - frame = self.frames.get(timeout=timeout) - except queue.Empty: - raise TimeoutError(f"timed out in {timeout:.1f}s") from None + if self.closed: + try: + frame = self.frames.get(block=False) + except queue.Empty: + raise EOFError("stream of frames ended") from None + else: + try: + frame = self.frames.get(block=True, timeout=timeout) + except queue.Empty: + raise TimeoutError(f"timed out in {timeout:.1f}s") from None if frame is None: raise EOFError("stream of frames ended") return frame @@ -87,7 +93,7 @@ def reset_queue(self, frames: Iterable[Frame]) -> None: queued = [] try: while True: - queued.append(self.frames.get_nowait()) + queued.append(self.frames.get(block=False)) except queue.Empty: pass for frame in frames: @@ -123,9 +129,6 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ with self.mutex: - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -194,9 +197,6 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: """ with self.mutex: - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -288,5 +288,6 @@ def close(self) -> None: self.closed = True - # Unblock get() or get_iter(). - self.frames.put(None) + if self.get_in_progress: + # Unblock get() or get_iter(). + self.frames.put(None) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 902b3b84..b1c57c8c 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -793,14 +793,12 @@ async def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - async def test_close_does_not_wait_for_recv(self): - # Closing the connection discards messages buffered in the assembler. - # This is allowed by the RFC: - # > However, there is no guarantee that the endpoint that has already - # > sent a Close frame will continue to process data. + async def test_close_preserves_queued_messages(self): + """close preserves messages buffered in the assembler.""" await self.remote_connection.send("😀") await self.connection.close() + self.assertEqual(await self.connection.recv(), "😀") with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 181ffd37..566f71ce 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -395,6 +395,58 @@ async def test_get_iter_fails_after_close(self): async for _ in self.assembler.get_iter(): self.fail("no fragment expected") + async def test_get_queued_message_after_close(self): + """get returns a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_iter_queued_message_after_close(self): + """get_iter yields a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_queued_fragmented_message_after_close(self): + """get reassembles a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_iter_queued_fragmented_message_after_close(self): + """get_iter yields a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_partially_queued_fragmented_message_after_close(self): + """get raises EOF on a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_partially_queued_fragmented_message_after_close(self): + """get_iter yields a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + fragments = [] + with self.assertRaises(EOFError): + async for fragment in self.assembler.get_iter(): + fragments.append(fragment) + self.assertEqual(fragments, [b"t", b"e"]) + async def test_put_fails_after_close(self): """put raises EOFError after close is called.""" self.assembler.close() diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index abdfd3f7..408b9697 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -543,17 +543,12 @@ def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - def test_close_does_not_wait_for_recv(self): - # Closing the connection discards messages buffered in the assembler. - # This is allowed by the RFC: - # > However, there is no guarantee that the endpoint that has already - # > sent a Close frame will continue to process data. + def test_close_preserves_queued_messages(self): + """close preserves messages buffered in the assembler.""" self.remote_connection.send("😀") self.connection.close() - close_thread = threading.Thread(target=self.connection.close) - close_thread.start() - + self.assertEqual(self.connection.recv(), "😀") with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -576,10 +571,10 @@ def test_close_idempotency(self): def test_close_idempotency_race_condition(self): """close waits if the connection is already closing.""" - self.connection.close_timeout = 5 * MS + self.connection.close_timeout = 6 * MS def closer(): - with self.delay_frames_rcvd(3 * MS): + with self.delay_frames_rcvd(4 * MS): self.connection.close() close_thread = threading.Thread(target=closer) @@ -591,14 +586,14 @@ def closer(): # Connection isn't closed yet. with self.assertRaises(TimeoutError): - self.connection.recv(timeout=0) + self.connection.recv(timeout=MS) self.connection.close() self.assertNoFrameSent() # Connection is closed now. with self.assertRaises(ConnectionClosedOK): - self.connection.recv(timeout=0) + self.connection.recv(timeout=MS) close_thread.join() diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 02513894..9ebe4508 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -374,6 +374,58 @@ def test_get_iter_fails_after_close(self): for _ in self.assembler.get_iter(): self.fail("no fragment expected") + def test_get_queued_message_after_close(self): + """get returns a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + message = self.assembler.get() + self.assertEqual(message, "café") + + def test_get_iter_queued_message_after_close(self): + """get_iter yields a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + def test_get_queued_fragmented_message_after_close(self): + """get reassembles a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + self.assembler.close() + message = self.assembler.get() + self.assertEqual(message, b"tea") + + def test_get_iter_queued_fragmented_message_after_close(self): + """get_iter yields a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + def test_get_partially_queued_fragmented_message_after_close(self): + """get raises EOF on a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.get() + + def test_get_iter_partially_queued_fragmented_message_after_close(self): + """get_iter yields a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + fragments = [] + with self.assertRaises(EOFError): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + self.assertEqual(fragments, [b"t", b"e"]) + def test_put_fails_after_close(self): """put raises EOFError after close is called.""" self.assembler.close()