Skip to content

Commit

Permalink
Remove unnecessary code paths in keepalive().
Browse files Browse the repository at this point in the history
Also add comments in tests to clarify the intended sequence.
  • Loading branch information
aaugustin committed Aug 20, 2024
1 parent 453e55a commit 9d355bf
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 22 deletions.
14 changes: 9 additions & 5 deletions src/websockets/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,10 @@ async def keepalive(self) -> None:
if self.ping_timeout is not None:
try:
async with asyncio_timeout(self.ping_timeout):
# connection_lost cancels keepalive immediately
# after setting a ConnectionClosed exception on
# pong_waiter. A CancelledError is raised here,
# not a ConnectionClosed exception.
latency = await pong_waiter
self.logger.debug("% received keepalive pong")
except asyncio.TimeoutError:
Expand All @@ -733,9 +737,10 @@ async def keepalive(self) -> None:
CloseCode.INTERNAL_ERROR,
"keepalive ping timeout",
)
break
except ConnectionClosed:
pass
raise AssertionError(
"send_context() should wait for connection_lost(), "
"which cancels keepalive()"
)
except Exception:
self.logger.error("keepalive ping failed", exc_info=True)

Expand Down Expand Up @@ -913,8 +918,7 @@ def connection_lost(self, exc: Exception | None) -> None:
self.set_recv_exc(exc)
self.recv_messages.close()
self.abort_pings()
# If keepalive() was waiting for a pong, abort_pings() terminated it.
# If it was sleeping until the next ping, we need to cancel it now

if self.keepalive_task is not None:
self.keepalive_task.cancel()

Expand Down
50 changes: 33 additions & 17 deletions tests/asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,12 +890,25 @@ async def test_pong_explicit_binary(self):

@patch("random.getrandbits")
async def test_keepalive(self, getrandbits):
"""keepalive sends pings."""
"""keepalive sends pings at ping_interval and measures latency."""
self.connection.ping_interval = 2 * MS
getrandbits.return_value = 1918987876
self.connection.start_keepalive()
self.assertEqual(self.connection.latency, 0)
# 2 ms: keepalive() sends a ping frame.
# 2.x ms: a pong frame is received.
await asyncio.sleep(3 * MS)
# 3 ms: check that the ping frame was sent.
await self.assertFrameSent(Frame(Opcode.PING, b"rand"))
self.assertGreater(self.connection.latency, 0)
self.assertLess(self.connection.latency, MS)

async def test_disable_keepalive(self):
"""keepalive is disabled when ping_interval is None."""
self.connection.ping_interval = None
self.connection.start_keepalive()
await asyncio.sleep(3 * MS)
await self.assertNoFrameSent()

@patch("random.getrandbits")
async def test_keepalive_times_out(self, getrandbits):
Expand All @@ -905,13 +918,14 @@ async def test_keepalive_times_out(self, getrandbits):
async with self.drop_frames_rcvd():
getrandbits.return_value = 1918987876
self.connection.start_keepalive()
# 4 ms: keepalive() sends a ping frame.
await asyncio.sleep(4 * MS)
# Exiting the context manager sleeps for MS.
await self.assertFrameSent(Frame(Opcode.PING, b"rand"))
await asyncio.sleep(MS)
await self.assertFrameSent(
Frame(Opcode.CLOSE, b"\x03\xf3keepalive ping timeout")
)
# 4.x ms: a pong frame is dropped.
# 6 ms: no pong frame is received; the connection is closed.
await asyncio.sleep(2 * MS)
# 7 ms: check that the connection is closed.
self.assertEqual(self.connection.state, State.CLOSED)

@patch("random.getrandbits")
async def test_keepalive_ignores_timeout(self, getrandbits):
Expand All @@ -921,18 +935,14 @@ async def test_keepalive_ignores_timeout(self, getrandbits):
async with self.drop_frames_rcvd():
getrandbits.return_value = 1918987876
self.connection.start_keepalive()
# 4 ms: keepalive() sends a ping frame.
await asyncio.sleep(4 * MS)
# Exiting the context manager sleeps for MS.
await self.assertFrameSent(Frame(Opcode.PING, b"rand"))
await asyncio.sleep(MS)
await self.assertNoFrameSent()

async def test_disable_keepalive(self):
"""keepalive is disabled when ping_interval is None."""
self.connection.ping_interval = None
self.connection.start_keepalive()
await asyncio.sleep(3 * MS)
await self.assertNoFrameSent()
# 4.x ms: a pong frame is dropped.
# 6 ms: no pong frame is received; the connection remains open.
await asyncio.sleep(2 * MS)
# 7 ms: check that the connection is still open.
self.assertEqual(self.connection.state, State.OPEN)

async def test_keepalive_terminates_while_sleeping(self):
"""keepalive task terminates while waiting to send a ping."""
Expand All @@ -945,21 +955,27 @@ async def test_keepalive_terminates_while_sleeping(self):
async def test_keepalive_terminates_while_waiting_for_pong(self):
"""keepalive task terminates while waiting to receive a pong."""
self.connection.ping_interval = 2 * MS
self.connection.ping_timeout = 2 * MS
async with self.drop_frames_rcvd():
self.connection.start_keepalive()
# 2 ms: keepalive() sends a ping frame.
await asyncio.sleep(2 * MS)
# Exiting the context manager sleeps for MS.
# 2.x ms: a pong frame is dropped.
# 3 ms: close the connection before ping_timeout elapses.
await self.connection.close()
self.assertTrue(self.connection.keepalive_task.done())

async def test_keepalive_reports_errors(self):
"""keepalive reports unexpected errors in logs."""
self.connection.ping_interval = 2 * MS
# Inject a fault by raising an exception in a pending pong waiter.
async with self.drop_frames_rcvd():
self.connection.start_keepalive()
# 2 ms: keepalive() sends a ping frame.
await asyncio.sleep(2 * MS)
# Exiting the context manager sleeps for MS.
# 2.x ms: a pong frame is dropped.
# 3 ms: inject a fault: raise an exception in the pending pong waiter.
pong_waiter = next(iter(self.connection.pong_waiters.values()))[0]
with self.assertLogs("websockets", logging.ERROR) as logs:
pong_waiter.set_exception(Exception("BOOM"))
Expand Down

0 comments on commit 9d355bf

Please sign in to comment.