Skip to content

Commit

Permalink
Fix unclosed 'MemoryObjectReceiveStream' upon exception in 'BaseHTTPM…
Browse files Browse the repository at this point in the history
…iddleware' children (#2813)
  • Loading branch information
Kludex authored Dec 29, 2024
1 parent e16bacb commit 5a10fba
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 37 deletions.
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ filterwarnings = [
"ignore: starlette.middleware.wsgi is deprecated and will be removed in a future release.*:DeprecationWarning",
"ignore: Async generator 'starlette.requests.Request.stream' was garbage collected before it had been exhausted.*:ResourceWarning",
"ignore: Use 'content=<...>' to upload raw bytes/text content.:DeprecationWarning",
# TODO: This warning appeared when we bumped anyio to 4.4.0.
"ignore: Unclosed .MemoryObject(Send|Receive)Stream.:ResourceWarning",
]

[tool.coverage.run]
Expand Down
30 changes: 11 additions & 19 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import typing

import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette._utils import collapse_excgroups
from starlette.requests import ClientDisconnect, Request
Expand Down Expand Up @@ -107,9 +106,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

async def call_next(request: Request) -> Response:
app_exc: Exception | None = None
send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
send_stream, recv_stream = anyio.create_memory_object_stream()

async def receive_or_disconnect() -> Message:
if response_sent.is_set():
Expand All @@ -130,10 +126,6 @@ async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:

return message

async def close_recv_stream_on_response_sent() -> None:
await response_sent.wait()
recv_stream.close()

async def send_no_error(message: Message) -> None:
try:
await send_stream.send(message)
Expand All @@ -144,13 +136,12 @@ async def send_no_error(message: Message) -> None:
async def coro() -> None:
nonlocal app_exc

async with send_stream:
with send_stream:
try:
await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
app_exc = exc

task_group.start_soon(close_recv_stream_on_response_sent)
task_group.start_soon(coro)

try:
Expand All @@ -166,14 +157,13 @@ async def coro() -> None:
assert message["type"] == "http.response.start"

async def body_stream() -> typing.AsyncGenerator[bytes, None]:
async with recv_stream:
async for message in recv_stream:
assert message["type"] == "http.response.body"
body = message.get("body", b"")
if body:
yield body
if not message.get("more_body", False):
break
async for message in recv_stream:
assert message["type"] == "http.response.body"
body = message.get("body", b"")
if body:
yield body
if not message.get("more_body", False):
break

if app_exc is not None:
raise app_exc
Expand All @@ -182,11 +172,13 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
response.raw_headers = message["headers"]
return response

with collapse_excgroups():
send_stream, recv_stream = anyio.create_memory_object_stream[Message]()
with recv_stream, send_stream, collapse_excgroups():
async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()
recv_stream.close()

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
raise NotImplementedError() # pragma: no cover
Expand Down
30 changes: 14 additions & 16 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import anyio
import anyio.abc
import anyio.from_thread
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from anyio.streams.stapled import StapledObjectStream

from starlette._utils import is_async_callable
Expand Down Expand Up @@ -658,12 +657,12 @@ def __enter__(self) -> TestClient:
def reset_portal() -> None:
self.portal = None

send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None]
receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None]
send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
send1, receive1 = anyio.create_memory_object_stream(math.inf)
send2, receive2 = anyio.create_memory_object_stream(math.inf)
send1, receive1 = anyio.create_memory_object_stream[
typing.Union[typing.MutableMapping[str, typing.Any], None]
](math.inf)
send2, receive2 = anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]](math.inf)
for channel in (send1, send2, receive1, receive2):
stack.callback(channel.close)
self.stream_send = StapledObjectStream(send1, receive1)
self.stream_receive = StapledObjectStream(send2, receive2)
self.task = portal.start_task_soon(self.lifespan)
Expand Down Expand Up @@ -711,12 +710,11 @@ async def receive() -> typing.Any:
self.task.result()
return message

async with self.stream_send, self.stream_receive:
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()

0 comments on commit 5a10fba

Please sign in to comment.