From 28f126cb673c258d8369a4fc2bcdf0ea4eef7a22 Mon Sep 17 00:00:00 2001 From: Arcadiy Ivanov Date: Fri, 7 Jun 2024 16:45:58 -0400 Subject: [PATCH] Fixes socket timeout on WS connection not respecting ws_connect's timeouts Added read_timeout property to ResponseHandler to allow override After WS(S) connection is established, adjust `conn.proto.read_timeout` to be the largest of the `read_timeout` and the `ws_connect`'s `timeout` or `receive_timeout`, whichever are specified. fixes #8444 --- CHANGES/8444.bugfix | 2 + CONTRIBUTORS.txt | 1 + aiohttp/client.py | 15 ++++++ aiohttp/client_proto.py | 8 +++ tests/test_client_ws.py | 107 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 133 insertions(+) create mode 100644 CHANGES/8444.bugfix diff --git a/CHANGES/8444.bugfix b/CHANGES/8444.bugfix new file mode 100644 index 00000000000..2fc1d5d829c --- /dev/null +++ b/CHANGES/8444.bugfix @@ -0,0 +1,2 @@ +Fix ``ws_connect`` not respecting ``timeout`` nor ``receive_timeout`` on WS(S) connection. +-- by :user:`arcivanov`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 870819b4b8c..3f4a257a678 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -47,6 +47,7 @@ Anes Abismail Antoine Pietri Anton Kasyanov Anton Zhdan-Pushkin +Arcadiy Ivanov Arie Bovenberg Arseny Timoniq Artem Yushkovskiy diff --git a/aiohttp/client.py b/aiohttp/client.py index 7a4db0db476..ef8ea5167ad 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -939,6 +939,21 @@ async def _ws_connect( assert conn is not None conn_proto = conn.protocol assert conn_proto is not None + + # For WS connection the sock_read must be either receive_timeout + # or timeout (whichever is specified), unless read_timeout is greater + # None == no timeout, i.e. infinite timeout + if ws_timeout.ws_receive is None: + # Reset regardless + conn_proto.read_timeout = None + elif conn_proto.read_timeout is None: + # We're already at no timeout + pass + else: + conn_proto.read_timeout = max( + ws_timeout.ws_receive, conn_proto.read_timeout + ) + transport = conn.transport assert transport is not None reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue( diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index 7a247e1c591..ff76a7289b5 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -240,6 +240,14 @@ def _reschedule_timeout(self) -> None: def start_timeout(self) -> None: self._reschedule_timeout() + @property + def read_timeout(self) -> Optional[float]: + return self._read_timeout + + @read_timeout.setter + def read_timeout(self, read_timeout: Optional[float]) -> None: + self._read_timeout = read_timeout + def _on_read_timeout(self) -> None: exc = SocketTimeoutError("Timeout on reading data from socket") self.set_exception(exc) diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index 06cf2a12066..51d3245fa35 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -10,6 +10,7 @@ import aiohttp from aiohttp import client, hdrs +from aiohttp.client_ws import ClientWSTimeout from aiohttp.http import WS_KEY from aiohttp.streams import EofStream from aiohttp.test_utils import make_mocked_coro @@ -39,6 +40,7 @@ async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None: hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -54,6 +56,94 @@ async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None: assert hdrs.ORIGIN not in m_req.call_args[1]["headers"] +async def test_ws_connect_read_timeout_is_reset_to_inf( + ws_key: Any, loop: Any, key_data: Any +) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", + } + resp.connection.protocol.read_timeout = 0.5 + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + + res = await aiohttp.ClientSession().ws_connect( + "http://test.org", protocols=("t1", "t2", "chat") + ) + + assert isinstance(res, client.ClientWebSocketResponse) + assert res.protocol == "chat" + assert hdrs.ORIGIN not in m_req.call_args[1]["headers"] + assert resp.connection.protocol.read_timeout is None + + +async def test_ws_connect_read_timeout_stays_inf( + ws_key: Any, loop: Any, key_data: Any +) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", + } + resp.connection.protocol.read_timeout = None + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + + res = await aiohttp.ClientSession().ws_connect( + "http://test.org", + protocols=("t1", "t2", "chat"), + timeout=ClientWSTimeout(0.5), + ) + + assert isinstance(res, client.ClientWebSocketResponse) + assert res.protocol == "chat" + assert hdrs.ORIGIN not in m_req.call_args[1]["headers"] + assert resp.connection.protocol.read_timeout is None + + +async def test_ws_connect_read_timeout_reset_to_max( + ws_key: Any, loop: Any, key_data: Any +) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", + } + resp.connection.protocol.read_timeout = 0.5 + with mock.patch("aiohttp.client.os") as m_os: + with mock.patch("aiohttp.client.ClientSession.request") as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + + res = await aiohttp.ClientSession().ws_connect( + "http://test.org", + protocols=("t1", "t2", "chat"), + timeout=ClientWSTimeout(1.0), + ) + + assert isinstance(res, client.ClientWebSocketResponse) + assert res.protocol == "chat" + assert hdrs.ORIGIN not in m_req.call_args[1]["headers"] + assert resp.connection.protocol.read_timeout == 1.0 + + async def test_ws_connect_with_origin(key_data: Any, loop: Any) -> None: resp = mock.Mock() resp.status = 403 @@ -84,6 +174,7 @@ async def test_ws_connect_with_params(ws_key: Any, loop: Any, key_data: Any) -> hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -111,6 +202,7 @@ def read(self, decode=False): hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -233,6 +325,7 @@ async def mock_get(*args, **kwargs): hdrs.SEC_WEBSOCKET_ACCEPT: accept, hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", } + resp.connection.protocol.read_timeout = None return resp with mock.patch("aiohttp.client.os") as m_os: @@ -263,6 +356,7 @@ async def test_close(loop: Any, ws_key: Any, key_data: Any) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -303,6 +397,7 @@ async def test_close_eofstream(loop: Any, ws_key: Any, key_data: Any) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -333,6 +428,7 @@ async def test_close_exc(loop: Any, ws_key: Any, key_data: Any) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -365,6 +461,7 @@ async def test_close_exc2(loop: Any, ws_key: Any, key_data: Any) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -397,6 +494,7 @@ async def test_send_data_after_close(ws_key: Any, key_data: Any, loop: Any) -> N hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -425,6 +523,7 @@ async def test_send_data_type_errors(ws_key: Any, key_data: Any, loop: Any) -> N hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -451,6 +550,7 @@ async def test_reader_read_exception(ws_key: Any, key_data: Any, loop: Any) -> N hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + hresp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -519,6 +619,7 @@ async def test_ws_connect_non_overlapped_protocols( hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -543,6 +644,7 @@ async def test_ws_connect_non_overlapped_protocols_2( hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -567,6 +669,7 @@ async def test_ws_connect_deflate(loop: Any, ws_key: Any, key_data: Any) -> None hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -592,6 +695,7 @@ async def test_ws_connect_deflate_per_message( hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -626,6 +730,7 @@ async def test_ws_connect_deflate_server_not_support( hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -652,6 +757,7 @@ async def test_ws_connect_deflate_notakeover( hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; " "client_no_context_takeover", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -678,6 +784,7 @@ async def test_ws_connect_deflate_client_wbits( hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; " "client_max_window_bits=10", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data