From 0a0c1b48f66817aedcf85de86396ecd4efb77f8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=85dne=20Hovda?= Date: Thu, 29 Feb 2024 23:57:18 +0100 Subject: [PATCH 1/6] Fix X-Forwarded-Proto when the proxy already sets it to "ws" or "wss" Minor fix for https://github.com/encode/uvicorn/pull/2043 Traefik already sets the X-Forwarded-Proto headers to ws or wss for websockets. https://github.com/traefik/traefik/blob/c1ef7429771104e79f2e87b236b21495cb5765f0/pkg/middlewares/forwardedheaders/forwarded_header.go#L150 This change should make sure we don't overwrite those values. --- uvicorn/middleware/proxy_headers.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/uvicorn/middleware/proxy_headers.py b/uvicorn/middleware/proxy_headers.py index 28277e1d6..5d100ab92 100644 --- a/uvicorn/middleware/proxy_headers.py +++ b/uvicorn/middleware/proxy_headers.py @@ -63,9 +63,10 @@ async def __call__( headers[b"x-forwarded-proto"].decode("latin1").strip() ) if scope["type"] == "websocket": - scope["scheme"] = ( - "wss" if x_forwarded_proto == "https" else "ws" - ) + if x_forwarded_proto == "http": + scope["scheme"] = "ws" + elif x_forwarded_proto == "https": + scope["scheme"] = "wss" else: scope["scheme"] = x_forwarded_proto From 8a6502264bd0b061da4fb8c3c2fabf111f851b5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=85dne=20Hovda?= Date: Fri, 1 Mar 2024 00:01:47 +0100 Subject: [PATCH 2/6] Fix the logic --- uvicorn/middleware/proxy_headers.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/uvicorn/middleware/proxy_headers.py b/uvicorn/middleware/proxy_headers.py index 5d100ab92..5c001b04f 100644 --- a/uvicorn/middleware/proxy_headers.py +++ b/uvicorn/middleware/proxy_headers.py @@ -62,11 +62,10 @@ async def __call__( x_forwarded_proto = ( headers[b"x-forwarded-proto"].decode("latin1").strip() ) - if scope["type"] == "websocket": - if x_forwarded_proto == "http": - scope["scheme"] = "ws" - elif x_forwarded_proto == "https": - scope["scheme"] = "wss" + if scope["type"] == "websocket" and x_forwarded_proto == "http": + scope["scheme"] = "ws" + elif scope["type"] == "websocket" and x_forwarded_proto == "https": + scope["scheme"] = "wss" else: scope["scheme"] = x_forwarded_proto From 28f9d3c012792626657a51e2e02359de5f963c5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=85dne=20Hovda?= Date: Fri, 1 Mar 2024 00:09:39 +0100 Subject: [PATCH 3/6] Update test_proxy_headers.py Test whether passing "wss" in X-Forwarded-Proto works --- tests/middleware/test_proxy_headers.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/middleware/test_proxy_headers.py b/tests/middleware/test_proxy_headers.py index 53a4e70db..bb65f8542 100644 --- a/tests/middleware/test_proxy_headers.py +++ b/tests/middleware/test_proxy_headers.py @@ -142,3 +142,8 @@ async def websocket_app(scope, receive, send): async with websockets.client.connect(url, extra_headers=headers) as websocket: data = await websocket.recv() assert data == "wss://1.2.3.4:0" + + headers = {"X-Forwarded-Proto": "wss", "X-Forwarded-For": "1.2.3.4"} + async with websockets.client.connect(url, extra_headers=headers) as websocket: + data = await websocket.recv() + assert data == "wss://1.2.3.4:0" From 4da45431109869bbb230182d1c1ea745d65028cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=85dne=20Hovda?= Date: Fri, 1 Mar 2024 00:45:07 +0100 Subject: [PATCH 4/6] Simplify the logic (probably more ways to write this... lmk which you prefer) --- uvicorn/middleware/proxy_headers.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/uvicorn/middleware/proxy_headers.py b/uvicorn/middleware/proxy_headers.py index 5c001b04f..ff339c5bf 100644 --- a/uvicorn/middleware/proxy_headers.py +++ b/uvicorn/middleware/proxy_headers.py @@ -62,12 +62,10 @@ async def __call__( x_forwarded_proto = ( headers[b"x-forwarded-proto"].decode("latin1").strip() ) - if scope["type"] == "websocket" and x_forwarded_proto == "http": - scope["scheme"] = "ws" - elif scope["type"] == "websocket" and x_forwarded_proto == "https": - scope["scheme"] = "wss" - else: - scope["scheme"] = x_forwarded_proto + scope["scheme"] = x_forwarded_proto + + if scope["type"] == "websocket": + scope["scheme"] = x_forwarded_proto.replace("http", "ws") if b"x-forwarded-for" in headers: # Determine the client address from the last trusted IP in the From 4de216e9a5a6d7d06a10cf3955f2fd1e1757db0d Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 1 Mar 2024 09:17:12 +0100 Subject: [PATCH 5/6] Update tests and min implementation --- tests/middleware/test_proxy_headers.py | 20 +++++++++++++------- uvicorn/middleware/proxy_headers.py | 3 ++- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/middleware/test_proxy_headers.py b/tests/middleware/test_proxy_headers.py index bb65f8542..f86573fb5 100644 --- a/tests/middleware/test_proxy_headers.py +++ b/tests/middleware/test_proxy_headers.py @@ -115,7 +115,18 @@ async def test_proxy_headers_invalid_x_forwarded_for() -> None: @pytest.mark.anyio +@pytest.mark.parametrize( + "x_forwarded_proto,addr", + [ + ("http", "ws://1.2.3.4:0"), + ("https", "wss://1.2.3.4:0"), + ("ws", "ws://1.2.3.4:0"), + ("wss", "wss://1.2.3.4:0"), + ], +) async def test_proxy_headers_websocket_x_forwarded_proto( + x_forwarded_proto: str, + addr: str, ws_protocol_cls: "Type[WSProtocol | WebSocketProtocol]", http_protocol_cls: "Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, @@ -138,12 +149,7 @@ async def websocket_app(scope, receive, send): async with run_server(config): url = f"ws://127.0.0.1:{unused_tcp_port}" - headers = {"X-Forwarded-Proto": "https", "X-Forwarded-For": "1.2.3.4"} - async with websockets.client.connect(url, extra_headers=headers) as websocket: - data = await websocket.recv() - assert data == "wss://1.2.3.4:0" - - headers = {"X-Forwarded-Proto": "wss", "X-Forwarded-For": "1.2.3.4"} + headers = {"X-Forwarded-Proto": x_forwarded_proto, "X-Forwarded-For": "1.2.3.4"} async with websockets.client.connect(url, extra_headers=headers) as websocket: data = await websocket.recv() - assert data == "wss://1.2.3.4:0" + assert data == addr diff --git a/uvicorn/middleware/proxy_headers.py b/uvicorn/middleware/proxy_headers.py index ff339c5bf..7b5e8299a 100644 --- a/uvicorn/middleware/proxy_headers.py +++ b/uvicorn/middleware/proxy_headers.py @@ -62,10 +62,11 @@ async def __call__( x_forwarded_proto = ( headers[b"x-forwarded-proto"].decode("latin1").strip() ) - scope["scheme"] = x_forwarded_proto if scope["type"] == "websocket": scope["scheme"] = x_forwarded_proto.replace("http", "ws") + else: + scope["scheme"] = x_forwarded_proto if b"x-forwarded-for" in headers: # Determine the client address from the last trusted IP in the From 16ac4bfe097bda63c5ef17b212536feea17ed943 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 1 Mar 2024 09:17:47 +0100 Subject: [PATCH 6/6] Remove new line --- uvicorn/middleware/proxy_headers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/uvicorn/middleware/proxy_headers.py b/uvicorn/middleware/proxy_headers.py index 7b5e8299a..1c254416e 100644 --- a/uvicorn/middleware/proxy_headers.py +++ b/uvicorn/middleware/proxy_headers.py @@ -62,7 +62,6 @@ async def __call__( x_forwarded_proto = ( headers[b"x-forwarded-proto"].decode("latin1").strip() ) - if scope["type"] == "websocket": scope["scheme"] = x_forwarded_proto.replace("http", "ws") else: