diff --git a/pyproject.toml b/pyproject.toml index 785c00eff..840ada3a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ filterwarnings = [ 'ignore: \"watchgod\" is deprecated\, you should switch to watchfiles \(`pip install watchfiles`\)\.:DeprecationWarning', "ignore:Uvicorn's native WSGI implementation is deprecated.*:DeprecationWarning", "ignore: 'cgi' is deprecated and slated for removal in Python 3.13:DeprecationWarning", + "ignore: remove second argument of ws_handler:DeprecationWarning:websockets" ] [tool.coverage.run] diff --git a/requirements.txt b/requirements.txt index b500435b4..848e6a290 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ h11 @ git+https://github.com/python-hyper/h11.git@master # Explicit optionals a2wsgi==1.10.6 wsproto==1.2.0 -websockets==12.0 +websockets==13.1 # Packaging build==1.2.1 diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 074c01d03..15ccfdd7d 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -9,6 +9,7 @@ import websockets import websockets.client import websockets.exceptions +from typing_extensions import TypedDict from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory from websockets.typing import Subprotocol @@ -20,7 +21,9 @@ ASGISendCallable, Scope, WebSocketCloseEvent, + WebSocketConnectEvent, WebSocketDisconnectEvent, + WebSocketReceiveEvent, WebSocketResponseStartEvent, ) from uvicorn.config import Config @@ -71,7 +74,7 @@ async def asgi(self): break -async def wsresponse(url): +async def wsresponse(url: str): """ A simple websocket connection request and response helper """ @@ -114,26 +117,21 @@ def app(scope: Scope): "missing or empty sec-websocket-key header", # wsproto "failed to open a websocket connection: missing " "sec-websocket-key header", "failed to open a websocket connection: missing or empty " "sec-websocket-key header", + "failed to open a websocket connection: missing sec-websocket-key header; 'sec-websocket-key'", ] ) async def test_accept_connection(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept"}) - async def open_connection(url): + async def open_connection(url: str): async with websockets.client.connect(url) as websocket: return websocket.open - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}") assert is_open @@ -141,16 +139,10 @@ async def open_connection(url): async def test_shutdown(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept"}) - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config) as server: async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}"): # Attempt shutdown while connection is still open @@ -161,21 +153,15 @@ async def test_supports_permessage_deflate_extension( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept"}) - async def open_connection(url): + async def open_connection(url: str): extension_factories = [ClientPerMessageDeflateFactory()] async with websockets.client.connect(url, extensions=extension_factories) as websocket: return [extension.name for extension in websocket.extensions] - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): extension_names = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}") assert "permessage-deflate" in extension_names @@ -185,7 +171,7 @@ async def test_can_disable_permessage_deflate_extension( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept"}) async def open_connection(url: str): @@ -210,7 +196,7 @@ async def open_connection(url: str): async def test_close_connection(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.close"}) async def open_connection(url: str): @@ -220,13 +206,7 @@ async def open_connection(url: str): return False return True # pragma: no cover - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}") assert not is_open @@ -234,7 +214,7 @@ async def open_connection(url: str): async def test_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): headers = self.scope.get("headers") headers = dict(headers) # type: ignore assert headers[b"host"].startswith(b"127.0.0.1") # type: ignore @@ -245,13 +225,7 @@ async def open_connection(url: str): async with websockets.client.connect(url, extra_headers=[("username", "abraĆ£o")]) as websocket: return websocket.open - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}") assert is_open @@ -259,20 +233,14 @@ async def open_connection(url: str): async def test_extra_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept", "headers": [(b"extra", b"header")]}) async def open_connection(url: str): async with websockets.client.connect(url) as websocket: return websocket.response_headers - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): extra_headers = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}") assert extra_headers.get("extra") == "header" @@ -280,7 +248,7 @@ async def open_connection(url: str): async def test_path_and_raw_path(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): path = self.scope.get("path") raw_path = self.scope.get("raw_path") assert path == "/one/two" @@ -291,13 +259,7 @@ async def open_connection(url: str): async with websockets.client.connect(url) as websocket: return websocket.open - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}/one%2Ftwo") assert is_open @@ -307,7 +269,7 @@ async def test_send_text_data_to_client( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept"}) await self.send({"type": "websocket.send", "text": "123"}) @@ -315,13 +277,7 @@ async def get_data(url: str): async with websockets.client.connect(url) as websocket: return await websocket.recv() - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): data = await get_data(f"ws://127.0.0.1:{unused_tcp_port}") assert data == "123" @@ -331,7 +287,7 @@ async def test_send_binary_data_to_client( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept"}) await self.send({"type": "websocket.send", "bytes": b"123"}) @@ -339,13 +295,7 @@ async def get_data(url: str): async with websockets.client.connect(url) as websocket: return await websocket.recv() - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): data = await get_data(f"ws://127.0.0.1:{unused_tcp_port}") assert data == b"123" @@ -355,7 +305,7 @@ async def test_send_and_close_connection( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept"}) await self.send({"type": "websocket.send", "text": "123"}) await self.send({"type": "websocket.close"}) @@ -370,13 +320,7 @@ async def get_data(url: str): is_open = False return (data, is_open) - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): (data, is_open) = await get_data(f"ws://127.0.0.1:{unused_tcp_port}") assert data == "123" @@ -387,11 +331,12 @@ async def test_send_text_data_to_server( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept"}) - async def websocket_receive(self, message): + async def websocket_receive(self, message: WebSocketReceiveEvent): _text = message.get("text") + assert _text is not None await self.send({"type": "websocket.send", "text": _text}) async def send_text(url: str): @@ -399,13 +344,7 @@ async def send_text(url: str): await websocket.send("abc") return await websocket.recv() - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): data = await send_text(f"ws://127.0.0.1:{unused_tcp_port}") assert data == "abc" @@ -415,11 +354,12 @@ async def test_send_binary_data_to_server( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept"}) - async def websocket_receive(self, message): + async def websocket_receive(self, message: WebSocketReceiveEvent): _bytes = message.get("bytes") + assert _bytes is not None await self.send({"type": "websocket.send", "bytes": _bytes}) async def send_text(url: str): @@ -427,13 +367,7 @@ async def send_text(url: str): await websocket.send(b"abc") return await websocket.recv() - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): data = await send_text(f"ws://127.0.0.1:{unused_tcp_port}") assert data == b"abc" @@ -443,7 +377,7 @@ async def test_send_after_protocol_close( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept"}) await self.send({"type": "websocket.send", "text": "123"}) await self.send({"type": "websocket.close"}) @@ -460,13 +394,7 @@ async def get_data(url: str): is_open = False return (data, is_open) - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): (data, is_open) = await get_data(f"ws://127.0.0.1:{unused_tcp_port}") assert data == "123" @@ -480,13 +408,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable async def connect(url: str): await websockets.client.connect(url) - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: await connect(f"ws://127.0.0.1:{unused_tcp_port}") @@ -502,13 +424,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable async def connect(url: str): await websockets.client.connect(url) - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: await connect(f"ws://127.0.0.1:{unused_tcp_port}") @@ -520,21 +436,12 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable await send({"type": "websocket.accept"}) await send({"type": "websocket.accept"}) - async def connect(url: str): - async with websockets.client.connect(url) as websocket: - _ = await websocket.recv() - - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): - with pytest.raises(websockets.exceptions.ConnectionClosed) as exc_info: - await connect(f"ws://127.0.0.1:{unused_tcp_port}") - assert exc_info.value.code == 1006 + async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket: + with pytest.raises(websockets.exceptions.ConnectionClosed): + _ = await websocket.recv() + assert websocket.close_code == 1006 async def test_asgi_return_value(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): @@ -547,29 +454,16 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable await send({"type": "websocket.accept"}) return 123 - async def connect(url: str): - async with websockets.client.connect(url) as websocket: - _ = await websocket.recv() - - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): - with pytest.raises(websockets.exceptions.ConnectionClosed) as exc_info: - await connect(f"ws://127.0.0.1:{unused_tcp_port}") - assert exc_info.value.code == 1006 + async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket: + with pytest.raises(websockets.exceptions.ConnectionClosed): + _ = await websocket.recv() + assert websocket.close_code == 1006 @pytest.mark.parametrize("code", [None, 1000, 1001]) -@pytest.mark.parametrize( - "reason", - [None, "test", False], - ids=["none_as_reason", "normal_reason", "without_reason"], -) +@pytest.mark.parametrize("reason", [None, "test", False], ids=["none_as_reason", "normal_reason", "without_reason"]) async def test_app_close( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, @@ -595,24 +489,15 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable elif message["type"] == "websocket.disconnect": break - async def websocket_session(url: str): - async with websockets.client.connect(url) as websocket: + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) + async with run_server(config): + async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket: await websocket.ping() await websocket.send("abc") - await websocket.recv() - - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) - async with run_server(config): - with pytest.raises(websockets.exceptions.ConnectionClosed) as exc_info: - await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") - assert exc_info.value.code == (code or 1000) - assert exc_info.value.reason == (reason or "") + with pytest.raises(websockets.exceptions.ConnectionClosed): + await websocket.recv() + assert websocket.close_code == (code or 1000) + assert websocket.close_reason == (reason or "") async def test_client_close(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): @@ -636,13 +521,7 @@ async def websocket_session(url: str): await websocket.send("abc") await websocket.close(code=1001, reason="custom reason") - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") @@ -699,13 +578,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable except OSError: got_disconnect_event = True - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): url = f"ws://127.0.0.1:{unused_tcp_port}" async with websockets.client.connect(url): @@ -743,13 +616,7 @@ async def websocket_session(uri: str): }, ) - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): task = asyncio.create_task(websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")) await asyncio.sleep(0.1) @@ -787,13 +654,7 @@ async def websocket_session(uri: str): websocket = ws_connection await server_shutdown_event.wait() - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): task = asyncio.create_task(websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")) await asyncio.sleep(0.1) @@ -815,7 +676,7 @@ async def test_subprotocols( unused_tcp_port: int, ): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept", "subprotocol": subprotocol}) async def get_subprotocol(url: str): @@ -824,13 +685,7 @@ async def get_subprotocol(url: str): ) as websocket: return websocket.subprotocol - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): accepted_subprotocol = await get_subprotocol(f"ws://127.0.0.1:{unused_tcp_port}") assert accepted_subprotocol == subprotocol @@ -863,18 +718,14 @@ async def test_send_binary_data_to_server_bigger_than_default_on_websockets( unused_tcp_port: int, ): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept"}) - async def websocket_receive(self, message): + async def websocket_receive(self, message: WebSocketReceiveEvent): _bytes = message.get("bytes") + assert _bytes is not None await self.send({"type": "websocket.send", "bytes": _bytes}) - async def send_text(url: str): - async with websockets.client.connect(url, max_size=client_size_sent) as ws: - await ws.send(b"\x01" * client_size_sent) - return await ws.recv() - config = Config( app=App, ws=WebSocketProtocol, @@ -884,13 +735,15 @@ async def send_text(url: str): port=unused_tcp_port, ) async with run_server(config): - if expected_result == 0: - data = await send_text(f"ws://127.0.0.1:{unused_tcp_port}") - assert data == b"\x01" * client_size_sent - else: - with pytest.raises(websockets.exceptions.ConnectionClosedError) as e: - data = await send_text(f"ws://127.0.0.1:{unused_tcp_port}") - assert e.value.code == expected_result + async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}", max_size=client_size_sent) as ws: + await ws.send(b"\x01" * client_size_sent) + if expected_result == 0: + data = await ws.recv() + assert data == b"\x01" * client_size_sent + else: + with pytest.raises(websockets.exceptions.ConnectionClosedError): + await ws.recv() + assert ws.close_code == expected_result async def test_server_reject_connection( @@ -914,34 +767,31 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable # See https://github.com/encode/uvicorn/issues/244 disconnected_message = await receive() - async def websocket_session(url): + async def websocket_session(url: str): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: async with websockets.client.connect(url): pass # pragma: no cover assert exc_info.value.status_code == 403 - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") assert disconnected_message == {"type": "websocket.disconnect", "code": 1006} +class EmptyDict(TypedDict): ... + + async def test_server_reject_connection_with_response( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): - disconnected_message = {} + disconnected_message: WebSocketDisconnectEvent | EmptyDict = {} - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): nonlocal disconnected_message assert scope["type"] == "websocket" - assert "websocket.http.response" in scope["extensions"] + assert "extensions" in scope and "websocket.http.response" in scope["extensions"] # Pull up first recv message. message = await receive() @@ -952,18 +802,12 @@ async def app(scope, receive, send): await response(scope, receive, send) disconnected_message = await receive() - async def websocket_session(url): + async def websocket_session(url: str): response = await wsresponse(url) assert response.status_code == 400 assert response.content == b"goodbye" - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") @@ -994,13 +838,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable ], } ) - await send( - { - "type": "websocket.http.response.body", - "body": b"x" * 10, - "more_body": True, - } - ) + await send({"type": "websocket.http.response.body", "body": b"x" * 10, "more_body": True}) await send({"type": "websocket.http.response.body", "body": b"y" * 10}) disconnected_message = await receive() @@ -1009,13 +847,7 @@ async def websocket_session(url: str): assert response.status_code == 400 assert response.content == (b"x" * 10) + (b"y" * 10) - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") @@ -1027,33 +859,28 @@ async def test_server_reject_connection_with_invalid_status( ): # this test checks that even if there is an error in the response, the server # can successfully send a 500 error back to the client - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "websocket" - assert "websocket.http.response" in scope["extensions"] + assert "extensions" in scope and "websocket.http.response" in scope["extensions"] # Pull up first recv message. message = await receive() assert message["type"] == "websocket.connect" - message = { - "type": "websocket.http.response.start", - "status": 700, # invalid status code - "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")], - } - await send(message) + await send( + { + "type": "websocket.http.response.start", + "status": 700, # invalid status code + "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")], + } + ) - async def websocket_session(url): + async def websocket_session(url: str): response = await wsresponse(url) assert response.status_code == 500 assert response.content == b"Internal Server Error" - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") @@ -1071,16 +898,10 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable message = await receive() assert message["type"] == "websocket.connect" - await send( - { - "type": "websocket.http.response.start", - "status": 403, - "headers": [], - } - ) + await send({"type": "websocket.http.response.start", "status": 403, "headers": []}) await send({"type": "websocket.http.response.body", "body": b"hardbody"}) - async def websocket_session(url): + async def websocket_session(url: str): response = await wsresponse(url) assert response.status_code == 403 assert response.content == b"hardbody" @@ -1091,13 +912,7 @@ async def websocket_session(url): # websockets automatically adds a content-length assert response.headers["content-length"] == "8" - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") @@ -1105,15 +920,15 @@ async def websocket_session(url): async def test_server_reject_connection_with_invalid_msg( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "websocket" - assert "websocket.http.response" in scope["extensions"] + assert "extensions" in scope and "websocket.http.response" in scope["extensions"] # Pull up first recv message. - message = await receive() - assert message["type"] == "websocket.connect" + message_rcvd = await receive() + assert message_rcvd["type"] == "websocket.connect" - message = { + message: WebSocketResponseStartEvent = { "type": "websocket.http.response.start", "status": 404, "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")], @@ -1122,19 +937,13 @@ async def app(scope, receive, send): # send invalid message. This will raise an exception here await send(message) - async def websocket_session(url): + async def websocket_session(url: str): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: async with websockets.client.connect(url): pass # pragma: no cover assert exc_info.value.status_code == 404 - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") @@ -1142,35 +951,30 @@ async def websocket_session(url): async def test_server_reject_connection_with_missing_body( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "websocket" - assert "websocket.http.response" in scope["extensions"] + assert "extensions" in scope and "websocket.http.response" in scope["extensions"] # Pull up first recv message. message = await receive() assert message["type"] == "websocket.connect" - message = { - "type": "websocket.http.response.start", - "status": 404, - "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")], - } - await send(message) + await send( + { + "type": "websocket.http.response.start", + "status": 404, + "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")], + } + ) # no further message - async def websocket_session(url): + async def websocket_session(url: str): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: async with websockets.client.connect(url): pass # pragma: no cover assert exc_info.value.status_code == 404 - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") @@ -1211,13 +1015,7 @@ async def websocket_session(url: str): pass # pragma: no cover assert exc_info.value.status_code == 404 - config = Config( - app=app, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") @@ -1229,40 +1027,33 @@ async def websocket_session(url: str): async def test_server_can_read_messages_in_buffer_after_close( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): - frames = [] - disconnect_message = {} + frames: list[bytes] = [] + disconnect_message: WebSocketDisconnectEvent | EmptyDict = {} class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept"}) # Ensure server doesn't start reading frames from read buffer until # after client has sent close frame, but server is still able to # read these frames await asyncio.sleep(0.2) - async def websocket_disconnect(self, message): + async def websocket_disconnect(self, message: WebSocketDisconnectEvent): nonlocal disconnect_message disconnect_message = message - async def websocket_receive(self, message): - frames.append(message.get("bytes")) + async def websocket_receive(self, message: WebSocketReceiveEvent): + _bytes = message.get("bytes") + assert _bytes is not None + frames.append(_bytes) - async def send_text(url: str): - async with websockets.client.connect(url) as websocket: + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) + async with run_server(config): + async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket: await websocket.send(b"abc") await websocket.send(b"abc") await websocket.send(b"abc") - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) - async with run_server(config): - await send_text(f"ws://127.0.0.1:{unused_tcp_port}") - assert frames == [b"abc", b"abc", b"abc"] assert disconnect_message == {"type": "websocket.disconnect", "code": 1000, "reason": ""} @@ -1271,20 +1062,14 @@ async def test_default_server_headers( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept"}) async def open_connection(url: str): async with websockets.client.connect(url) as websocket: return websocket.response_headers - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): headers = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}") assert headers.get("server") == "uvicorn" and "date" in headers @@ -1292,7 +1077,7 @@ async def open_connection(url: str): async def test_no_server_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept"}) async def open_connection(url: str): @@ -1315,7 +1100,7 @@ async def open_connection(url: str): @skip_if_no_wsproto async def test_no_date_header_on_wsproto(http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send({"type": "websocket.accept"}) async def open_connection(url: str): @@ -1339,7 +1124,7 @@ async def test_multiple_server_header( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): await self.send( { "type": "websocket.accept", @@ -1354,25 +1139,19 @@ async def open_connection(url: str): async with websockets.client.connect(url) as websocket: return websocket.response_headers - config = Config( - app=App, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="off", - port=unused_tcp_port, - ) + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): headers = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}") assert headers.get_all("Server") == ["uvicorn", "over-ridden", "another-value"] async def test_lifespan_state(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): - expected_states = [ + expected_states: list[dict[str, typing.Any]] = [ {"a": 123, "b": [1]}, {"a": 123, "b": [1, 2]}, ] - actual_states = [] + actual_states: list[dict[str, typing.Any]] = [] async def lifespan_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): message = await receive() @@ -1385,7 +1164,8 @@ async def lifespan_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISen await send({"type": "lifespan.shutdown.complete"}) class App(WebSocketResponse): - async def websocket_connect(self, message): + async def websocket_connect(self, message: WebSocketConnectEvent): + assert "state" in self.scope actual_states.append(deepcopy(self.scope["state"])) self.scope["state"]["a"] = 456 self.scope["state"]["b"].append(2) @@ -1400,13 +1180,7 @@ async def app_wrapper(scope: Scope, receive: ASGIReceiveCallable, send: ASGISend return await lifespan_app(scope, receive, send) return await App(scope, receive, send) - config = Config( - app=app_wrapper, - ws=ws_protocol_cls, - http=http_protocol_cls, - lifespan="on", - port=unused_tcp_port, - ) + config = Config(app=app_wrapper, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="on", port=unused_tcp_port) async with run_server(config): is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}") assert is_open diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index c0700d4d3..af66c29b3 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -224,9 +224,7 @@ def send_500_response(self) -> None: # itself (see https://github.com/encode/uvicorn/issues/920) self.handshake_started_event.set() - async def ws_handler( # type: ignore[override] - self, protocol: WebSocketServerProtocol, path: str - ) -> Any: + async def ws_handler(self, protocol: WebSocketServerProtocol, path: str) -> Any: # type: ignore[override] """ This is the main handler function for the 'websockets' implementation to call into. We just wait for close then return, and instead allow @@ -359,9 +357,7 @@ async def asgi_send(self, message: ASGISendEvent) -> None: msg = "Unexpected ASGI message '%s', after sending 'websocket.close' " "or response already completed." raise RuntimeError(msg % message_type) - async def asgi_receive( - self, - ) -> WebSocketDisconnectEvent | WebSocketConnectEvent | WebSocketReceiveEvent: + async def asgi_receive(self) -> WebSocketDisconnectEvent | WebSocketConnectEvent | WebSocketReceiveEvent: if not self.connect_sent: self.connect_sent = True return {"type": "websocket.connect"} @@ -378,11 +374,11 @@ async def asgi_receive( try: data = await self.recv() - except ConnectionClosed as exc: + except ConnectionClosed: self.closed_event.set() if self.ws_server.closing: return {"type": "websocket.disconnect", "code": 1012} - return {"type": "websocket.disconnect", "code": exc.code, "reason": exc.reason} + return {"type": "websocket.disconnect", "code": self.close_code or 1005, "reason": self.close_reason} if isinstance(data, str): return {"type": "websocket.receive", "text": data}