Skip to content

Commit

Permalink
Fix: clean up all unclosed UploadFiles
Browse files Browse the repository at this point in the history
  • Loading branch information
gsakkis authored and provinzkraut committed Nov 29, 2024
1 parent ab75076 commit 9329f1d
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 26 deletions.
12 changes: 1 addition & 11 deletions litestar/connection/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,17 +278,7 @@ async def form(self) -> FormMultiDict:

self._connection_state.form = form_data

# form_data is a dict[str, list[str] | str | UploadFile]. Convert it to a
# list[tuple[str, str | UploadFile]] before passing it to FormMultiDict so
# multi-keys can be accessed properly
items = []
for k, v in form_data.items():
if isinstance(v, list):
for sv in v:
items.append((k, sv))
else:
items.append((k, v))
self._form = FormMultiDict(items)
self._form = FormMultiDict.from_form_data(cast("dict[str, Any]", form_data))

return self._form

Expand Down
21 changes: 21 additions & 0 deletions litestar/datastructures/multi_dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,27 @@ def copy(self) -> Self: # type: ignore[override]
class FormMultiDict(ImmutableMultiDict[Any]):
"""MultiDict for form data."""

@classmethod
def from_form_data(cls, form_data: dict[str, list[str] | str | UploadFile]) -> FormMultiDict:
"""Create a FormMultiDict from form data.
Args:
form_data: Form data to create the FormMultiDict from.
Returns:
A FormMultiDict instance
"""
# Convert form_data to a list[tuple[str, str | UploadFile]] before passing it
# to FormMultiDict so multi-keys can be accessed properly
items = []
for k, v in form_data.items():
if not isinstance(v, list):
items.append((k, v))
else:
for sv in v:
items.append((k, sv))
return cls(items)

async def close(self) -> None:
"""Close all files in the multi-dict.
Expand Down
2 changes: 2 additions & 0 deletions litestar/datastructures/upload_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ async def close(self) -> None:
Returns:
None.
"""
if self.file.closed:
return None
if self.rolled_to_disk:
return await sync_to_thread(self.file.close)
return self.file.close()
Expand Down
12 changes: 4 additions & 8 deletions litestar/routes/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from msgspec.msgpack import decode as _decode_msgpack_plain

from litestar.datastructures.upload_file import UploadFile
from litestar.datastructures.multi_dicts import FormMultiDict
from litestar.enums import HttpMethod, MediaType, ScopeType
from litestar.exceptions import ClientException, ImproperlyConfiguredException, SerializationException
from litestar.handlers.http_handlers import HTTPRouteHandler
Expand Down Expand Up @@ -86,8 +86,10 @@ async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None:
if after_response_handler := route_handler.resolve_after_response():
await after_response_handler(request)

if request._form is not Empty:
await request._form.close()
if form_data := scope.get("_form", {}):
await self._cleanup_temporary_files(form_data=cast("dict[str, Any]", form_data))
await FormMultiDict.from_form_data(cast("dict[str, Any]", form_data)).close()

def create_handler_map(self) -> None:
"""Parse the ``router_handlers`` of this route and return a mapping of
Expand Down Expand Up @@ -258,9 +260,3 @@ def options_handler(scope: Scope) -> Response:
include_in_schema=False,
sync_to_thread=False,
)(options_handler)

@staticmethod
async def _cleanup_temporary_files(form_data: dict[str, Any]) -> None:
for v in form_data.values():
if isinstance(v, UploadFile) and not v.file.closed:
await v.close()
14 changes: 7 additions & 7 deletions tests/unit/test_datastructures/test_multi_dicts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from unittest.mock import patch

import pytest
from pytest_mock import MockerFixture

from litestar.datastructures import UploadFile
from litestar.datastructures.multi_dicts import FormMultiDict, ImmutableMultiDict, MultiDict
Expand Down Expand Up @@ -34,20 +35,19 @@ def test_immutable_multi_dict_as_mutable() -> None:
assert multi.mutable_copy().dict() == MultiDict(data).dict()


async def test_form_multi_dict_close(mocker: MockerFixture) -> None:
close = mocker.patch("litestar.datastructures.multi_dicts.UploadFile.close")

async def test_form_multi_dict_close() -> None:
multi = FormMultiDict(
[
("foo", UploadFile(filename="foo", content_type="text/plain")),
("bar", UploadFile(filename="foo", content_type="text/plain")),
]
)

with patch("litestar.datastructures.multi_dicts.UploadFile.close") as mock_close:
await multi.close()
assert mock_close.call_count == 2
# calls the real UploadFile.close method to clean up
await multi.close()

assert close.call_count == 2


@pytest.mark.parametrize("type_", [MultiDict, ImmutableMultiDict])
def test_copy(type_: type[MultiDict | ImmutableMultiDict]) -> None:
Expand Down

0 comments on commit 9329f1d

Please sign in to comment.