Skip to content

Commit

Permalink
Merge commit from fork
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Oct 15, 2024
1 parent e116840 commit fd038f3
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 7 deletions.
11 changes: 7 additions & 4 deletions starlette/formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ class FormMessage(Enum):
class MultipartPart:
content_disposition: bytes | None = None
field_name: str = ""
data: bytes = b""
data: bytearray = field(default_factory=bytearray)
file: UploadFile | None = None
item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)


def _user_safe_decode(src: bytes, codec: str) -> str:
def _user_safe_decode(src: bytes | bytearray, codec: str) -> str:
try:
return src.decode(codec)
except (UnicodeDecodeError, LookupError):
Expand Down Expand Up @@ -117,7 +117,8 @@ async def parse(self) -> FormData:


class MultiPartParser:
max_file_size = 1024 * 1024
max_file_size = 1024 * 1024 # 1MB
max_part_size = 1024 * 1024 # 1MB

def __init__(
self,
Expand Down Expand Up @@ -149,7 +150,9 @@ def on_part_begin(self) -> None:
def on_part_data(self, data: bytes, start: int, end: int) -> None:
message_bytes = data[start:end]
if self._current_part.file is None:
self._current_part.data += message_bytes
if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
self._current_part.data.extend(message_bytes)
else:
self._file_parts_to_write.append((self._current_part, message_bytes))

Expand Down
41 changes: 38 additions & 3 deletions tests/test_formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,9 +640,7 @@ def test_max_files_is_customizable_low_raises(
assert res.text == "Too many files. Maximum number of files is 1."


def test_max_fields_is_customizable_high(
test_client_factory: TestClientFactory,
) -> None:
def test_max_fields_is_customizable_high(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(make_app_max_parts(max_fields=2000, max_files=2000))
fields = []
for i in range(2000):
Expand All @@ -664,3 +662,40 @@ def test_max_fields_is_customizable_high(
"content": "",
"content_type": None,
}


@pytest.mark.parametrize(
"app,expectation",
[
(app, pytest.raises(MultiPartException)),
(Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
],
)
def test_max_part_size_exceeds_limit(
app: ASGIApp,
expectation: typing.ContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
boundary = "------------------------4K1ON9fZkj9uCUmqLHRbbR"

multipart_data = (
f"--{boundary}\r\n"
f'Content-Disposition: form-data; name="small"\r\n\r\n'
"small content\r\n"
f"--{boundary}\r\n"
f'Content-Disposition: form-data; name="large"\r\n\r\n'
+ ("x" * 1024 * 1024 + "x") # 1MB + 1 byte of data
+ "\r\n"
f"--{boundary}--\r\n"
).encode("utf-8")

headers = {
"Content-Type": f"multipart/form-data; boundary={boundary}",
"Transfer-Encoding": "chunked",
}

with expectation:
response = client.post("/", data=multipart_data, headers=headers) # type: ignore
assert response.status_code == 400
assert response.text == "Part exceeded maximum size of 1024KB."

0 comments on commit fd038f3

Please sign in to comment.