diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 9e7ea3d8..005e9b4b 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -723,6 +723,10 @@ async def keepalive(self) -> None: if self.ping_timeout is not None: try: async with asyncio_timeout(self.ping_timeout): + # connection_lost cancels keepalive immediately + # after setting a ConnectionClosed exception on + # pong_waiter. A CancelledError is raised here, + # not a ConnectionClosed exception. latency = await pong_waiter self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: @@ -733,9 +737,10 @@ async def keepalive(self) -> None: CloseCode.INTERNAL_ERROR, "keepalive ping timeout", ) - break - except ConnectionClosed: - pass + raise AssertionError( + "send_context() should wait for connection_lost(), " + "which cancels keepalive()" + ) except Exception: self.logger.error("keepalive ping failed", exc_info=True) @@ -913,8 +918,7 @@ def connection_lost(self, exc: Exception | None) -> None: self.set_recv_exc(exc) self.recv_messages.close() self.abort_pings() - # If keepalive() was waiting for a pong, abort_pings() terminated it. - # If it was sleeping until the next ping, we need to cancel it now + if self.keepalive_task is not None: self.keepalive_task.cancel() diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 29bb0041..59218de4 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -890,12 +890,25 @@ async def test_pong_explicit_binary(self): @patch("random.getrandbits") async def test_keepalive(self, getrandbits): - """keepalive sends pings.""" + """keepalive sends pings at ping_interval and measures latency.""" self.connection.ping_interval = 2 * MS getrandbits.return_value = 1918987876 self.connection.start_keepalive() + self.assertEqual(self.connection.latency, 0) + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is received. await asyncio.sleep(3 * MS) + # 3 ms: check that the ping frame was sent. await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + self.assertGreater(self.connection.latency, 0) + self.assertLess(self.connection.latency, MS) + + async def test_disable_keepalive(self): + """keepalive is disabled when ping_interval is None.""" + self.connection.ping_interval = None + self.connection.start_keepalive() + await asyncio.sleep(3 * MS) + await self.assertNoFrameSent() @patch("random.getrandbits") async def test_keepalive_times_out(self, getrandbits): @@ -905,13 +918,14 @@ async def test_keepalive_times_out(self, getrandbits): async with self.drop_frames_rcvd(): getrandbits.return_value = 1918987876 self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. await asyncio.sleep(4 * MS) # Exiting the context manager sleeps for MS. - await self.assertFrameSent(Frame(Opcode.PING, b"rand")) - await asyncio.sleep(MS) - await self.assertFrameSent( - Frame(Opcode.CLOSE, b"\x03\xf3keepalive ping timeout") - ) + # 4.x ms: a pong frame is dropped. + # 6 ms: no pong frame is received; the connection is closed. + await asyncio.sleep(2 * MS) + # 7 ms: check that the connection is closed. + self.assertEqual(self.connection.state, State.CLOSED) @patch("random.getrandbits") async def test_keepalive_ignores_timeout(self, getrandbits): @@ -921,18 +935,14 @@ async def test_keepalive_ignores_timeout(self, getrandbits): async with self.drop_frames_rcvd(): getrandbits.return_value = 1918987876 self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. await asyncio.sleep(4 * MS) # Exiting the context manager sleeps for MS. - await self.assertFrameSent(Frame(Opcode.PING, b"rand")) - await asyncio.sleep(MS) - await self.assertNoFrameSent() - - async def test_disable_keepalive(self): - """keepalive is disabled when ping_interval is None.""" - self.connection.ping_interval = None - self.connection.start_keepalive() - await asyncio.sleep(3 * MS) - await self.assertNoFrameSent() + # 4.x ms: a pong frame is dropped. + # 6 ms: no pong frame is received; the connection remains open. + await asyncio.sleep(2 * MS) + # 7 ms: check that the connection is still open. + self.assertEqual(self.connection.state, State.OPEN) async def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" @@ -945,21 +955,27 @@ async def test_keepalive_terminates_while_sleeping(self): async def test_keepalive_terminates_while_waiting_for_pong(self): """keepalive task terminates while waiting to receive a pong.""" self.connection.ping_interval = 2 * MS + self.connection.ping_timeout = 2 * MS async with self.drop_frames_rcvd(): self.connection.start_keepalive() + # 2 ms: keepalive() sends a ping frame. await asyncio.sleep(2 * MS) # Exiting the context manager sleeps for MS. + # 2.x ms: a pong frame is dropped. + # 3 ms: close the connection before ping_timeout elapses. await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) async def test_keepalive_reports_errors(self): """keepalive reports unexpected errors in logs.""" self.connection.ping_interval = 2 * MS - # Inject a fault by raising an exception in a pending pong waiter. async with self.drop_frames_rcvd(): self.connection.start_keepalive() + # 2 ms: keepalive() sends a ping frame. await asyncio.sleep(2 * MS) # Exiting the context manager sleeps for MS. + # 2.x ms: a pong frame is dropped. + # 3 ms: inject a fault: raise an exception in the pending pong waiter. pong_waiter = next(iter(self.connection.pong_waiters.values()))[0] with self.assertLogs("websockets", logging.ERROR) as logs: pong_waiter.set_exception(Exception("BOOM"))