From 5a10fba2990a10def60a4dbcdcf6f19b439b264a Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 29 Dec 2024 13:32:00 +0100 Subject: [PATCH] Fix unclosed 'MemoryObjectReceiveStream' upon exception in 'BaseHTTPMiddleware' children (#2813) --- pyproject.toml | 2 -- starlette/middleware/base.py | 30 +++++++++++------------------- starlette/testclient.py | 30 ++++++++++++++---------------- 3 files changed, 25 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 50a53caf6..95f195c50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,8 +84,6 @@ filterwarnings = [ "ignore: starlette.middleware.wsgi is deprecated and will be removed in a future release.*:DeprecationWarning", "ignore: Async generator 'starlette.requests.Request.stream' was garbage collected before it had been exhausted.*:ResourceWarning", "ignore: Use 'content=<...>' to upload raw bytes/text content.:DeprecationWarning", - # TODO: This warning appeared when we bumped anyio to 4.4.0. - "ignore: Unclosed .MemoryObject(Send|Receive)Stream.:ResourceWarning", ] [tool.coverage.run] diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index f51b13f73..6e37c6f60 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -3,7 +3,6 @@ import typing import anyio -from anyio.abc import ObjectReceiveStream, ObjectSendStream from starlette._utils import collapse_excgroups from starlette.requests import ClientDisconnect, Request @@ -107,9 +106,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def call_next(request: Request) -> Response: app_exc: Exception | None = None - send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]] - recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] - send_stream, recv_stream = anyio.create_memory_object_stream() async def receive_or_disconnect() -> Message: if response_sent.is_set(): @@ -130,10 +126,6 @@ async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T: return message - async def close_recv_stream_on_response_sent() -> None: - await response_sent.wait() - recv_stream.close() - async def send_no_error(message: Message) -> None: try: await send_stream.send(message) @@ -144,13 +136,12 @@ async def send_no_error(message: Message) -> None: async def coro() -> None: nonlocal app_exc - async with send_stream: + with send_stream: try: await self.app(scope, receive_or_disconnect, send_no_error) except Exception as exc: app_exc = exc - task_group.start_soon(close_recv_stream_on_response_sent) task_group.start_soon(coro) try: @@ -166,14 +157,13 @@ async def coro() -> None: assert message["type"] == "http.response.start" async def body_stream() -> typing.AsyncGenerator[bytes, None]: - async with recv_stream: - async for message in recv_stream: - assert message["type"] == "http.response.body" - body = message.get("body", b"") - if body: - yield body - if not message.get("more_body", False): - break + async for message in recv_stream: + assert message["type"] == "http.response.body" + body = message.get("body", b"") + if body: + yield body + if not message.get("more_body", False): + break if app_exc is not None: raise app_exc @@ -182,11 +172,13 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: response.raw_headers = message["headers"] return response - with collapse_excgroups(): + send_stream, recv_stream = anyio.create_memory_object_stream[Message]() + with recv_stream, send_stream, collapse_excgroups(): async with anyio.create_task_group() as task_group: response = await self.dispatch_func(request, call_next) await response(scope, wrapped_receive, send) response_sent.set() + recv_stream.close() async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: raise NotImplementedError() # pragma: no cover diff --git a/starlette/testclient.py b/starlette/testclient.py index 9a0abbd7b..4f9788feb 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -14,7 +14,6 @@ import anyio import anyio.abc import anyio.from_thread -from anyio.abc import ObjectReceiveStream, ObjectSendStream from anyio.streams.stapled import StapledObjectStream from starlette._utils import is_async_callable @@ -658,12 +657,12 @@ def __enter__(self) -> TestClient: def reset_portal() -> None: self.portal = None - send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None] - receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None] - send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]] - receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] - send1, receive1 = anyio.create_memory_object_stream(math.inf) - send2, receive2 = anyio.create_memory_object_stream(math.inf) + send1, receive1 = anyio.create_memory_object_stream[ + typing.Union[typing.MutableMapping[str, typing.Any], None] + ](math.inf) + send2, receive2 = anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]](math.inf) + for channel in (send1, send2, receive1, receive2): + stack.callback(channel.close) self.stream_send = StapledObjectStream(send1, receive1) self.stream_receive = StapledObjectStream(send2, receive2) self.task = portal.start_task_soon(self.lifespan) @@ -711,12 +710,11 @@ async def receive() -> typing.Any: self.task.result() return message - async with self.stream_send, self.stream_receive: - await self.stream_receive.send({"type": "lifespan.shutdown"}) - message = await receive() - assert message["type"] in ( - "lifespan.shutdown.complete", - "lifespan.shutdown.failed", - ) - if message["type"] == "lifespan.shutdown.failed": - await receive() + await self.stream_receive.send({"type": "lifespan.shutdown"}) + message = await receive() + assert message["type"] in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + ) + if message["type"] == "lifespan.shutdown.failed": + await receive()