diff --git a/litestar/connection/request.py b/litestar/connection/request.py index 23c60f0b3c..154af41a29 100644 --- a/litestar/connection/request.py +++ b/litestar/connection/request.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic +from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic, cast from litestar._multipart import parse_content_header, parse_multipart_form from litestar._parsers import parse_url_encoded_form_data @@ -222,17 +222,7 @@ async def form(self) -> FormMultiDict: self._connection_state.form = form_data - # form_data is a dict[str, list[str] | str | UploadFile]. Convert it to a - # list[tuple[str, str | UploadFile]] before passing it to FormMultiDict so - # multi-keys can be accessed properly - items = [] - for k, v in form_data.items(): - if isinstance(v, list): - for sv in v: - items.append((k, sv)) - else: - items.append((k, v)) - self._form = FormMultiDict(items) + self._form = FormMultiDict.from_form_data(cast("dict[str, Any]", form_data)) return self._form diff --git a/litestar/datastructures/multi_dicts.py b/litestar/datastructures/multi_dicts.py index 7702e1a8d5..733be6be0b 100644 --- a/litestar/datastructures/multi_dicts.py +++ b/litestar/datastructures/multi_dicts.py @@ -95,6 +95,27 @@ def copy(self) -> Self: # type: ignore[override] class FormMultiDict(ImmutableMultiDict[Any]): """MultiDict for form data.""" + @classmethod + def from_form_data(cls, form_data: dict[str, list[str] | str | UploadFile]) -> FormMultiDict: + """Create a FormMultiDict from form data. + + Args: + form_data: Form data to create the FormMultiDict from. + + Returns: + A FormMultiDict instance + """ + # Convert form_data to a list[tuple[str, str | UploadFile]] before passing it + # to FormMultiDict so multi-keys can be accessed properly + items = [] + for k, v in form_data.items(): + if not isinstance(v, list): + items.append((k, v)) + else: + for sv in v: + items.append((k, sv)) + return cls(items) + async def close(self) -> None: """Close all files in the multi-dict. diff --git a/litestar/datastructures/upload_file.py b/litestar/datastructures/upload_file.py index 09ad2d32ab..d330a747f9 100644 --- a/litestar/datastructures/upload_file.py +++ b/litestar/datastructures/upload_file.py @@ -93,6 +93,8 @@ async def close(self) -> None: Returns: None. """ + if self.file.closed: + return None if self.rolled_to_disk: return await sync_to_thread(self.file.close) return self.file.close() diff --git a/litestar/routes/http.py b/litestar/routes/http.py index a8e47f1b99..2547fdce9b 100644 --- a/litestar/routes/http.py +++ b/litestar/routes/http.py @@ -5,7 +5,7 @@ from msgspec.msgpack import decode as _decode_msgpack_plain -from litestar.datastructures.upload_file import UploadFile +from litestar.datastructures.multi_dicts import FormMultiDict from litestar.enums import HttpMethod, MediaType, ScopeType from litestar.exceptions import ClientException, ImproperlyConfiguredException, SerializationException from litestar.handlers.http_handlers import HTTPRouteHandler @@ -86,8 +86,10 @@ async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None: if after_response_handler := route_handler.resolve_after_response(): await after_response_handler(request) + if request._form is not Empty: + await request._form.close() if form_data := scope.get("_form", {}): - await self._cleanup_temporary_files(form_data=cast("dict[str, Any]", form_data)) + await FormMultiDict.from_form_data(cast("dict[str, Any]", form_data)).close() def create_handler_map(self) -> None: """Parse the ``router_handlers`` of this route and return a mapping of @@ -258,9 +260,3 @@ def options_handler(scope: Scope) -> Response: include_in_schema=False, sync_to_thread=False, )(options_handler) - - @staticmethod - async def _cleanup_temporary_files(form_data: dict[str, Any]) -> None: - for v in form_data.values(): - if isinstance(v, UploadFile) and not v.file.closed: - await v.close() diff --git a/tests/unit/test_datastructures/test_multi_dicts.py b/tests/unit/test_datastructures/test_multi_dicts.py index 78fec65f69..7ec7f386ca 100644 --- a/tests/unit/test_datastructures/test_multi_dicts.py +++ b/tests/unit/test_datastructures/test_multi_dicts.py @@ -1,7 +1,8 @@ from __future__ import annotations +from unittest.mock import patch + import pytest -from pytest_mock import MockerFixture from litestar.datastructures import UploadFile from litestar.datastructures.multi_dicts import FormMultiDict, ImmutableMultiDict, MultiDict @@ -34,20 +35,19 @@ def test_immutable_multi_dict_as_mutable() -> None: assert multi.mutable_copy().dict() == MultiDict(data).dict() -async def test_form_multi_dict_close(mocker: MockerFixture) -> None: - close = mocker.patch("litestar.datastructures.multi_dicts.UploadFile.close") - +async def test_form_multi_dict_close() -> None: multi = FormMultiDict( [ ("foo", UploadFile(filename="foo", content_type="text/plain")), ("bar", UploadFile(filename="foo", content_type="text/plain")), ] ) - + with patch("litestar.datastructures.multi_dicts.UploadFile.close") as mock_close: + await multi.close() + assert mock_close.call_count == 2 + # calls the real UploadFile.close method to clean up await multi.close() - assert close.call_count == 2 - @pytest.mark.parametrize("type_", [MultiDict, ImmutableMultiDict]) def test_copy(type_: type[MultiDict | ImmutableMultiDict]) -> None: