Skip to content

Commit

Permalink
Backport: #8444 into 3.9
Browse files Browse the repository at this point in the history
Please see #8445 for the source PR
  • Loading branch information
arcivanov committed Jun 8, 2024
1 parent e4a63ff commit 6b2a182
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGES/8444.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix ``ws_connect`` not respecting `receive_timeout`` on WS(S) connection.
-- by :user:`arcivanov`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Anes Abismail
Antoine Pietri
Anton Kasyanov
Anton Zhdan-Pushkin
Arcadiy Ivanov
Arseny Timoniq
Artem Yushkovskiy
Arthur Darcet
Expand Down
10 changes: 10 additions & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,16 @@ async def _ws_connect(
assert conn is not None
conn_proto = conn.protocol
assert conn_proto is not None

# For WS connection the read_timeout must be either receive_timeout or greater
# None == no timeout, i.e. infinite timeout, so None is the max timeout possible
if receive_timeout is None:
# Reset regardless
conn_proto.read_timeout = receive_timeout
elif conn_proto.read_timeout is not None:
# If read_timeout was set check which wins
conn_proto.read_timeout = max(receive_timeout, conn_proto.read_timeout)

transport = conn.transport
assert transport is not None
reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue(
Expand Down
8 changes: 8 additions & 0 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,14 @@ def _reschedule_timeout(self) -> None:
def start_timeout(self) -> None:
self._reschedule_timeout()

@property
def read_timeout(self) -> Optional[float]:
return self._read_timeout

@read_timeout.setter
def read_timeout(self, read_timeout: Optional[float]) -> None:
self._read_timeout = read_timeout

def _on_read_timeout(self) -> None:
exc = ServerTimeoutError("Timeout on reading data from socket")
self.set_exception(exc)
Expand Down
104 changes: 104 additions & 0 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import base64
import hashlib
import os
from typing import Any
from unittest import mock

import pytest
Expand Down Expand Up @@ -37,6 +38,7 @@ async def test_ws_connect(ws_key, loop, key_data) -> None:
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -52,6 +54,91 @@ async def test_ws_connect(ws_key, loop, key_data) -> None:
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]


async def test_ws_connect_read_timeout_is_reset_to_inf(
ws_key: Any, loop: Any, key_data: Any
) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = 0.5
with mock.patch("aiohttp.client.os") as m_os, mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)

res = await aiohttp.ClientSession().ws_connect(
"http://test.org", protocols=("t1", "t2", "chat")
)

assert isinstance(res, client.ClientWebSocketResponse)
assert res.protocol == "chat"
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
assert resp.connection.protocol.read_timeout is None


async def test_ws_connect_read_timeout_stays_inf(
ws_key: Any, loop: Any, key_data: Any
) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os, mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)

res = await aiohttp.ClientSession().ws_connect(
"http://test.org",
protocols=("t1", "t2", "chat"),
receive_timeout=0.5,
)

assert isinstance(res, client.ClientWebSocketResponse)
assert res.protocol == "chat"
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
assert resp.connection.protocol.read_timeout is None


async def test_ws_connect_read_timeout_reset_to_max(
ws_key: Any, loop: Any, key_data: Any
) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = 0.5
with mock.patch("aiohttp.client.os") as m_os, mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)

res = await aiohttp.ClientSession().ws_connect(
"http://test.org",
protocols=("t1", "t2", "chat"),
receive_timeout=1.0,
)

assert isinstance(res, client.ClientWebSocketResponse)
assert res.protocol == "chat"
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
assert resp.connection.protocol.read_timeout == 1.0


async def test_ws_connect_with_origin(key_data, loop) -> None:
resp = mock.Mock()
resp.status = 403
Expand Down Expand Up @@ -82,6 +169,7 @@ async def test_ws_connect_with_params(ws_key, loop, key_data) -> None:
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -107,6 +195,7 @@ def read(self, decode=False):
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down Expand Up @@ -229,6 +318,7 @@ async def mock_get(*args, **kwargs):
hdrs.SEC_WEBSOCKET_ACCEPT: accept,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
return resp

with mock.patch("aiohttp.client.os") as m_os:
Expand Down Expand Up @@ -259,6 +349,7 @@ async def test_close(loop, ws_key, key_data) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -299,6 +390,7 @@ async def test_close_eofstream(loop, ws_key, key_data) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -329,6 +421,7 @@ async def test_close_exc(loop, ws_key, key_data) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -361,6 +454,7 @@ async def test_close_exc2(loop, ws_key, key_data) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -395,6 +489,7 @@ async def test_send_data_after_close(ws_key, key_data, loop) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down Expand Up @@ -423,6 +518,7 @@ async def test_send_data_type_errors(ws_key, key_data, loop) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -451,6 +547,7 @@ async def test_reader_read_exception(ws_key, key_data, loop) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
hresp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -515,6 +612,7 @@ async def test_ws_connect_non_overlapped_protocols(ws_key, loop, key_data) -> No
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -537,6 +635,7 @@ async def test_ws_connect_non_overlapped_protocols_2(ws_key, loop, key_data) ->
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -561,6 +660,7 @@ async def test_ws_connect_deflate(loop, ws_key, key_data) -> None:
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -584,6 +684,7 @@ async def test_ws_connect_deflate_per_message(loop, ws_key, key_data) -> None:
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -616,6 +717,7 @@ async def test_ws_connect_deflate_server_not_support(loop, ws_key, key_data) ->
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -640,6 +742,7 @@ async def test_ws_connect_deflate_notakeover(loop, ws_key, key_data) -> None:
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; "
"client_no_context_takeover",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -664,6 +767,7 @@ async def test_ws_connect_deflate_client_wbits(loop, ws_key, key_data) -> None:
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; "
"client_max_window_bits=10",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down

0 comments on commit 6b2a182

Please sign in to comment.