diff --git a/CHANGES/8481.bugfix.rst b/CHANGES/8481.bugfix.rst new file mode 100644 index 00000000000..b185780174e --- /dev/null +++ b/CHANGES/8481.bugfix.rst @@ -0,0 +1,2 @@ +Fixed the incorrect rejection of ``ws://`` and ``wss://`` urls +-- by :user:` AraHaan`. diff --git a/aiohttp/client.py b/aiohttp/client.py index d47d0facc27..b2ee5b40604 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -211,6 +211,8 @@ class ClientTimeout: # https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2 IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"}) HTTP_SCHEMA_SET = frozenset({"http", "https", ""}) +WS_SCHEMA_SET = frozenset({"ws", "wss"}) +ALLOWED_PROTOCOL_SCHEMA_SET = HTTP_SCHEMA_SET | WS_SCHEMA_SET _RetType = TypeVar("_RetType") _CharsetResolver = Callable[[ClientResponse, bytes], str] @@ -505,7 +507,7 @@ async def _request( except ValueError as e: raise InvalidUrlClientError(str_or_url) from e - if url.scheme not in HTTP_SCHEMA_SET: + if url.scheme not in ALLOWED_PROTOCOL_SCHEMA_SET: raise NonHttpUrlClientError(url) skip_headers = set(self._skip_auto_headers) diff --git a/tests/conftest.py b/tests/conftest.py index fcdb482a59f..cc3c108847f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,19 @@ import asyncio +import base64 import os import socket import ssl import sys -from hashlib import md5, sha256 +from hashlib import md5, sha1, sha256 from pathlib import Path from tempfile import TemporaryDirectory +from typing import Any from unittest import mock from uuid import uuid4 import pytest +from aiohttp.http import WS_KEY from aiohttp.test_utils import loop_context try: @@ -208,3 +211,18 @@ def start_connection(): spec_set=True, ) as start_connection_mock: yield start_connection_mock + + +@pytest.fixture +def key_data(): + return os.urandom(16) + + +@pytest.fixture +def key(key_data: Any): + return base64.b64encode(key_data) + + +@pytest.fixture +def ws_key(key: Any): + return base64.b64encode(sha1(key + WS_KEY).digest()).decode() diff --git a/tests/test_client_session.py b/tests/test_client_session.py index 416b6bbce5d..52b4cb2e1c9 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -471,7 +471,61 @@ async def create_connection(req, traces, timeout): c.__del__() -async def test_cookie_jar_usage(loop, aiohttp_client) -> None: +@pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss"]) +async def test_ws_connect_allowed_protocols( + create_session: Any, + create_mocked_conn: Any, + protocol: str, + ws_key: Any, + key_data: Any, +) -> None: + resp = mock.create_autospec(aiohttp.ClientResponse) + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + } + resp.url = URL(f"{protocol}://example.com") + resp.cookies = SimpleCookie() + resp.start = mock.AsyncMock() + + req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) + req_factory = mock.Mock(return_value=req) + req.send = mock.AsyncMock(return_value=resp) + + session = await create_session(request_class=req_factory) + + connections = [] + original_connect = session._connector.connect + + async def connect(req, traces, timeout): + conn = await original_connect(req, traces, timeout) + connections.append(conn) + return conn + + async def create_connection(req, traces, timeout): + return create_mocked_conn() + + connector = session._connector + with mock.patch.object(connector, "connect", connect), mock.patch.object( + connector, "_create_connection", create_connection + ), mock.patch.object(connector, "_release"), mock.patch( + "aiohttp.client.os" + ) as m_os: + m_os.urandom.return_value = key_data + await session.ws_connect(f"{protocol}://example.com") + + # normally called during garbage collection. triggers an exception + # if the connection wasn't already closed + for c in connections: + c.close() + del c + + await session.close() + + +async def test_cookie_jar_usage(loop: Any, aiohttp_client: Any) -> None: req_url = None jar = mock.Mock() diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index f0b7757e420..4be404f7752 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -2,6 +2,7 @@ import base64 import hashlib import os +from typing import Any from unittest import mock import pytest @@ -13,22 +14,7 @@ from aiohttp.test_utils import make_mocked_coro -@pytest.fixture -def key_data(): - return os.urandom(16) - - -@pytest.fixture -def key(key_data): - return base64.b64encode(key_data) - - -@pytest.fixture -def ws_key(key): - return base64.b64encode(hashlib.sha1(key + WS_KEY).digest()).decode() - - -async def test_ws_connect(ws_key, loop, key_data) -> None: +async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None: resp = mock.Mock() resp.status = 101 resp.headers = {