Skip to content

Commit

Permalink
Fixes socket timeout on WS connection not respecting ws_connect's tim…
Browse files Browse the repository at this point in the history
…eouts

Added read_timeout property to ResponseHandler to allow override

After WS(S) connection is established, adjust `conn.proto.read_timeout` to
be the largest of the `read_timeout` and the `ws_connect`'s
`timeout` or `receive_timeout`, whichever are specified.

fixes aio-libs#8444
  • Loading branch information
arcivanov committed Jun 7, 2024
1 parent f662958 commit 28f126c
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGES/8444.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix ``ws_connect`` not respecting ``timeout`` nor ``receive_timeout`` on WS(S) connection.
-- by :user:`arcivanov`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Anes Abismail
Antoine Pietri
Anton Kasyanov
Anton Zhdan-Pushkin
Arcadiy Ivanov
Arie Bovenberg
Arseny Timoniq
Artem Yushkovskiy
Expand Down
15 changes: 15 additions & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,21 @@ async def _ws_connect(
assert conn is not None
conn_proto = conn.protocol
assert conn_proto is not None

# For WS connection the sock_read must be either receive_timeout
# or timeout (whichever is specified), unless read_timeout is greater
# None == no timeout, i.e. infinite timeout
if ws_timeout.ws_receive is None:
# Reset regardless
conn_proto.read_timeout = None
elif conn_proto.read_timeout is None:
# We're already at no timeout
pass
else:
conn_proto.read_timeout = max(
ws_timeout.ws_receive, conn_proto.read_timeout
)

transport = conn.transport
assert transport is not None
reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue(
Expand Down
8 changes: 8 additions & 0 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,14 @@ def _reschedule_timeout(self) -> None:
def start_timeout(self) -> None:
self._reschedule_timeout()

@property
def read_timeout(self) -> Optional[float]:
return self._read_timeout

@read_timeout.setter
def read_timeout(self, read_timeout: Optional[float]) -> None:
self._read_timeout = read_timeout

def _on_read_timeout(self) -> None:
exc = SocketTimeoutError("Timeout on reading data from socket")
self.set_exception(exc)
Expand Down
107 changes: 107 additions & 0 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import aiohttp
from aiohttp import client, hdrs
from aiohttp.client_ws import ClientWSTimeout
from aiohttp.http import WS_KEY
from aiohttp.streams import EofStream
from aiohttp.test_utils import make_mocked_coro
Expand Down Expand Up @@ -39,6 +40,7 @@ async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None:
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -54,6 +56,94 @@ async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None:
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]


async def test_ws_connect_read_timeout_is_reset_to_inf(
ws_key: Any, loop: Any, key_data: Any
) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = 0.5
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)

res = await aiohttp.ClientSession().ws_connect(
"http://test.org", protocols=("t1", "t2", "chat")
)

assert isinstance(res, client.ClientWebSocketResponse)
assert res.protocol == "chat"
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
assert resp.connection.protocol.read_timeout is None


async def test_ws_connect_read_timeout_stays_inf(
ws_key: Any, loop: Any, key_data: Any
) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)

res = await aiohttp.ClientSession().ws_connect(
"http://test.org",
protocols=("t1", "t2", "chat"),
timeout=ClientWSTimeout(0.5),
)

assert isinstance(res, client.ClientWebSocketResponse)
assert res.protocol == "chat"
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
assert resp.connection.protocol.read_timeout is None


async def test_ws_connect_read_timeout_reset_to_max(
ws_key: Any, loop: Any, key_data: Any
) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = 0.5
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)

res = await aiohttp.ClientSession().ws_connect(
"http://test.org",
protocols=("t1", "t2", "chat"),
timeout=ClientWSTimeout(1.0),
)

assert isinstance(res, client.ClientWebSocketResponse)
assert res.protocol == "chat"
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
assert resp.connection.protocol.read_timeout == 1.0


async def test_ws_connect_with_origin(key_data: Any, loop: Any) -> None:
resp = mock.Mock()
resp.status = 403
Expand Down Expand Up @@ -84,6 +174,7 @@ async def test_ws_connect_with_params(ws_key: Any, loop: Any, key_data: Any) ->
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down Expand Up @@ -111,6 +202,7 @@ def read(self, decode=False):
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down Expand Up @@ -233,6 +325,7 @@ async def mock_get(*args, **kwargs):
hdrs.SEC_WEBSOCKET_ACCEPT: accept,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
return resp

with mock.patch("aiohttp.client.os") as m_os:
Expand Down Expand Up @@ -263,6 +356,7 @@ async def test_close(loop: Any, ws_key: Any, key_data: Any) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -303,6 +397,7 @@ async def test_close_eofstream(loop: Any, ws_key: Any, key_data: Any) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -333,6 +428,7 @@ async def test_close_exc(loop: Any, ws_key: Any, key_data: Any) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -365,6 +461,7 @@ async def test_close_exc2(loop: Any, ws_key: Any, key_data: Any) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -397,6 +494,7 @@ async def test_send_data_after_close(ws_key: Any, key_data: Any, loop: Any) -> N
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down Expand Up @@ -425,6 +523,7 @@ async def test_send_data_type_errors(ws_key: Any, key_data: Any, loop: Any) -> N
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand All @@ -451,6 +550,7 @@ async def test_reader_read_exception(ws_key: Any, key_data: Any, loop: Any) -> N
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
hresp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -519,6 +619,7 @@ async def test_ws_connect_non_overlapped_protocols(
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -543,6 +644,7 @@ async def test_ws_connect_non_overlapped_protocols_2(
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -567,6 +669,7 @@ async def test_ws_connect_deflate(loop: Any, ws_key: Any, key_data: Any) -> None
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -592,6 +695,7 @@ async def test_ws_connect_deflate_per_message(
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -626,6 +730,7 @@ async def test_ws_connect_deflate_server_not_support(
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -652,6 +757,7 @@ async def test_ws_connect_deflate_notakeover(
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; "
"client_no_context_takeover",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -678,6 +784,7 @@ async def test_ws_connect_deflate_client_wbits(
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; "
"client_max_window_bits=10",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down

0 comments on commit 28f126c

Please sign in to comment.