Refactor the Multipart parsing into a Sans-IO layer
This allows it to be used in async (ASGI) frameworks. It also
hopefully makes the code a little clearer to follow.

This removes ``Content-Transfer-Encoding`` support, as RFC 7578
states it is deprecated and

    Currently, no deployed implementations that send such bodies have
    been discovered.

This removes the IE6 name fix as IE6 is no longer supported (either
directly or by Werkzeug).

This requires the dataclasses backport, which I think is OK as
dataclasses are in the stdlib from 3.7, and 3.6 has less than a year
until EOL.

The API is based on that successfully used by h11.
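The decoder is consumed as an event loop in the style h11 popularised. A minimal sketch of that usage, assuming the names and signatures visible in the formparser.py diff below (`MultipartDecoder`, `receive_data`, `next_event`, and the `Field`/`File`/`Data`/`NeedData`/`Epilogue` events); it is not the exact sansio.multipart source:

```python
from itertools import chain

from werkzeug.sansio.multipart import (
    Data,
    Epilogue,
    Field,
    File,
    MultipartDecoder,
    NeedData,
)


def consume(chunks, boundary: bytes):
    # Signature assumed from the diff; formparser.py also passes a
    # max_form_memory_size as the second argument.
    parser = MultipartDecoder(boundary)
    # formparser.py appends None to the chunk iterator to mark end of stream.
    for chunk in chain(chunks, [None]):
        parser.receive_data(chunk)
        event = parser.next_event()
        while not isinstance(event, (Epilogue, NeedData)):
            if isinstance(event, (Field, File)):
                print("new part:", event.name)
            elif isinstance(event, Data):
                print(len(event.data), "bytes, more coming:", event.more_data)
            event = parser.next_event()
```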
pgjones committed Jan 25, 2021
1 parent 20dae30 commit 90f54f3
Showing 10 changed files with 417 additions and 401 deletions.
2 changes: 0 additions & 2 deletions docs/http.rst
@@ -158,5 +158,3 @@ environments for unittesting you might want to use the
 .. autoclass:: FormDataParser
 
 .. autofunction:: parse_form_data
-
-.. autofunction:: parse_multipart_headers
6 changes: 5 additions & 1 deletion setup.py
@@ -1,4 +1,8 @@
 from setuptools import setup
 
 # Metadata goes in setup.cfg. These are here for GitHub's dependency graph.
-setup(name="Werkzeug", extras_require={"watchdog": ["watchdog"]})
+setup(
+    name="Werkzeug",
+    install_requires=["dataclasses; python_version < '3.7'"],
+    extras_require={"watchdog": ["watchdog"]},
+)
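The backport is needed because the commit uses dataclasses, presumably for the new event types. A hypothetical shape for them, inferred only from how the formparser.py diff below accesses them (`event.name`, `event.filename`, `event.headers`, `event.data`, `event.more_data`); the real definitions live in `werkzeug/sansio/multipart.py`:

```python
from dataclasses import dataclass  # stdlib on 3.7+, backport on 3.6

from werkzeug.datastructures import Headers


@dataclass
class Field:
    name: str
    headers: Headers


@dataclass
class File:
    name: str
    filename: str
    headers: Headers


@dataclass
class Data:
    data: bytes
    more_data: bool
```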
2 changes: 1 addition & 1 deletion src/werkzeug/_internal.py
@@ -131,7 +131,7 @@ def _to_str(x, charset=_default_encoding, errors="strict", allow_none_charset=Fa
     if x is None or isinstance(x, str):
         return x
 
-    if not isinstance(x, bytes):
+    if not isinstance(x, (bytes, bytearray)):
         return str(x)
 
     if charset is None:
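A likely motivation for widening this check (an assumption; the diff itself doesn't say): a parser that accumulates incoming chunks in a `bytearray` hands out `bytearray` slices, which are not `bytes`:

```python
buf = bytearray(b"field=value")  # e.g. an internal receive buffer
chunk = buf[:5]                  # slices of a bytearray are bytearrays

print(isinstance(chunk, bytes))  # False
print(chunk.decode("utf-8"))     # 'field' -- so _to_str should accept it too
```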
308 changes: 73 additions & 235 deletions src/werkzeug/formparser.py
@@ -1,21 +1,26 @@
import codecs
import typing as t
import warnings
from functools import update_wrapper
from io import BytesIO
from itertools import chain
from itertools import repeat
from itertools import tee
from typing import Union

from . import exceptions
from ._internal import _to_str
from .datastructures import FileStorage
from .datastructures import Headers
from .datastructures import MultiDict
from .http import parse_options_header
from .sansio.multipart import Data
from .sansio.multipart import Epilogue
from .sansio.multipart import Field
from .sansio.multipart import File
from .sansio.multipart import MultipartDecoder
from .sansio.multipart import NeedData
from .urls import url_decode_stream
from .wsgi import _make_chunk_iter
from .wsgi import get_content_length
from .wsgi import get_input_stream
from .wsgi import make_line_iter

# there are some platforms where SpooledTemporaryFile is not available.
# In that case we need to provide a fallback.
@@ -323,12 +328,16 @@ def _line_parse(line: str) -> t.Tuple[str, bool]:
def parse_multipart_headers(iterable: t.Iterable[bytes]) -> Headers:
"""Parses multipart headers from an iterable that yields lines (including
the trailing newline symbol). The iterable has to be newline terminated.
Iteration stops at the line where the headers end, so the rest of the
iterable can still be consumed.

:param iterable: iterable of newline-terminated byte strings
"""
warnings.warn(
"'parse_multipart_headers' is deprecated and will be removed in"
" Werkzeug version 2.1.",
DeprecationWarning,
stacklevel=2,
)
result: t.List[t.Tuple[str, str]] = []

for b_line in iterable:
@@ -354,12 +363,6 @@ def parse_multipart_headers(iterable: t.Iterable[bytes]) -> Headers:
return Headers(result)


_begin_form = "begin_form"
_begin_file = "begin_file"
_cont = "cont"
_end = "end"


class MultiPartParser:
def __init__(
self,
@@ -392,43 +395,9 @@ def __init__(
assert buffer_size >= 1024, "buffer size has to be at least 1KB"
self.buffer_size = buffer_size

def _fix_ie_filename(self, filename: str) -> str:
"""Internet Explorer 6 transmits the full file name if a file is
uploaded. This function strips the full path if it thinks the
filename is Windows-like absolute.
"""
if filename[1:3] == ":\\" or filename[:2] == "\\\\":
return filename.split("\\")[-1]

return filename

def _find_terminator(self, iterator: t.Iterable[bytes]) -> bytes:
"""The terminator might have some additional newlines before it.
There is at least one application that sends additional newlines
before headers (the python setuptools package).
"""
for line in iterator:
if not line:
break

line = line.strip()

if line:
return line

return b""

def fail(self, message: str) -> "t.NoReturn":
raise ValueError(message)

def get_part_encoding(self, headers: Headers) -> t.Optional[str]:
transfer_encoding: t.Optional[str] = headers.get("content-transfer-encoding")

if transfer_encoding in {"base64", "quoted-printable"}:
return transfer_encoding

return None

def get_part_charset(self, headers: Headers) -> str:
# Figure out input charset for current part
content_type = headers.get("content-type")
@@ -440,209 +409,78 @@ def get_part_charset(self, headers: Headers) -> str:
return self.charset

def start_file_streaming(
self, filename: str, headers: Headers, total_content_length: int
) -> t.Tuple[str, t.BinaryIO]:
if isinstance(filename, bytes):
filename = filename.decode(self.charset, self.errors)

filename = self._fix_ie_filename(filename)
content_type = headers.get("content-type")
self, event: File, total_content_length: int
) -> t.BinaryIO:
content_type = event.headers.get("content-type")

try:
content_length = int(headers["content-length"])
content_length = int(event.headers["content-length"])
except (KeyError, ValueError):
content_length = 0

container = self.stream_factory(
total_content_length=total_content_length,
filename=filename,
filename=event.filename,
content_type=content_type,
content_length=content_length,
)
return filename, container
return container

def in_memory_threshold_reached(self, size: int) -> None:
raise exceptions.RequestEntityTooLarge()
def parse(
self, stream: t.BinaryIO, boundary: bytes, content_length: int
) -> t.Tuple[MultiDict, MultiDict]:
container: t.Union[t.BinaryIO, t.List[bytes]]
_write: t.Callable[[bytes], t.Any]

def parse_lines(
self,
file: t.BinaryIO,
boundary: bytes,
content_length: int,
cap_at_buffer: bool = True,
) -> t.Iterator[
t.Tuple[
str, t.Union[t.Tuple[Headers, str], t.Tuple[Headers, str, str], bytes, None]
]
]:
"""Generate parts of
``('begin_form', (headers, name))``
``('begin_file', (headers, name, filename))``
``('cont', bytes)``
``('end', None)``
Always obeys the grammar
parts = ( begin_form cont* end |
begin_file cont* end )*
"""
next_part = b"--" + boundary
last_part = next_part + b"--"
iterator = chain(
make_line_iter(
file,
_make_chunk_iter(
stream,
limit=content_length,
buffer_size=self.buffer_size,
cap_at_buffer=cap_at_buffer,
),
repeat(b""),
)
terminator = self._find_terminator(iterator)

if terminator == last_part:
return
elif terminator != next_part:
self.fail("Expected boundary at start of multipart data")

while terminator != last_part:
headers = parse_multipart_headers(iterator)
disposition = headers.get("content-disposition")

if disposition is None:
self.fail("Missing Content-Disposition header")

disposition, extra = parse_options_header(disposition)
transfer_encoding = self.get_part_encoding(headers)
name = t.cast(str, extra.get("name"))
filename = extra.get("filename")

# if no content type is given we stream into memory. A list is
# used as a temporary container.
if filename is None:
yield _begin_form, (headers, name)
# otherwise we parse the rest of the headers and ask the stream
# factory for something we can write in.
else:
yield _begin_file, (headers, name, filename)

buf = b""

for line in iterator:
if not line:
self.fail("unexpected end of stream")

if line[:2] == b"--":
terminator = line.rstrip()

if terminator in {next_part, last_part}:
break

if transfer_encoding is not None:
if transfer_encoding == "base64":
transfer_encoding = "base64_codec"

try:
line = codecs.decode(line, transfer_encoding) # type: ignore
except Exception:
self.fail("could not decode transfer encoded chunk")

# we have something in the buffer from the last iteration.
# this is usually a newline delimiter.
if buf:
yield _cont, buf

# If the line ends with windows CRLF we write everything except
# the last two bytes. In all other cases however we write
# everything except the last byte. If it was a newline, that's
# fine, otherwise it does not matter because we will write it
# the next iteration. this ensures we do not write the
# final newline into the stream. That way we do not have to
# truncate the stream. However we do have to make sure that
# if something else than a newline is in there we write it
# out.
if line[-2:] == b"\r\n":
buf = b"\r\n"
cutoff = -2
else:
buf = line[-1:]
cutoff = -1

yield _cont, line[:cutoff]
else:
raise ValueError("unexpected end of part")

# if we have a leftover in the buffer that is not a newline
# character we have to flush it, otherwise we will chop off
# certain values.
if buf not in {b"", b"\r", b"\n", b"\r\n"}:
yield _cont, buf

yield _end, None

def parse_parts(
self, file: t.BinaryIO, boundary: bytes, content_length: int
) -> t.Iterator[t.Tuple[str, t.Tuple[str, t.Union[str, FileStorage]]]]:
"""Generate ``('file', (name, val))`` and
``('form', (name, val))`` parts.
"""
in_memory = 0
guard_memory: bool
is_file: bool
container: t.Union[t.BinaryIO, t.List[bytes]]
_write: t.Callable[[bytes], t.Any]
headers: Headers
name: str
filename: str

for ellt, ell in self.parse_lines(file, boundary, content_length):
if ellt == _begin_file:
headers, name, filename = t.cast(t.Tuple[Headers, str, str], ell)
is_file = True
guard_memory = False
filename, container = self.start_file_streaming(
filename, headers, content_length
)
_write = container.write

elif ellt == _begin_form:
headers, name = t.cast(t.Tuple[Headers, str], ell)
is_file = False
container = []
_write = container.append
guard_memory = self.max_form_memory_size is not None

elif ellt == _cont:
ell = t.cast(bytes, ell)
_write(ell)
# if we write into memory and there is a memory size limit we
# count the number of bytes in memory and raise an exception if
# there is too much data in memory.
if guard_memory:
in_memory += len(ell)

if in_memory > self.max_form_memory_size: # type: ignore
self.in_memory_threshold_reached(in_memory)

elif ellt == _end:
if is_file:
container = t.cast(t.BinaryIO, container)
container.seek(0)
yield (
"file",
(name, FileStorage(container, filename, name, headers=headers)),
)
else:
part_charset = self.get_part_charset(headers)
yield (
"form",
(name, b"".join(container).decode(part_charset, self.errors)),
)

def parse(
self, file: t.BinaryIO, boundary: bytes, content_length: int
) -> t.Tuple[MultiDict, MultiDict]:
form_stream, file_stream = tee(
self.parse_parts(file, boundary, content_length), 2
[None],
)
form = (v for t, v in form_stream if t == "form")
files = (v for t, v in file_stream if t == "file")
return self.cls(form), self.cls(files)
parser = MultipartDecoder(boundary, self.max_form_memory_size)

fields = []
files = []

current_part: Union[Field, File]
for data in iterator:
parser.receive_data(data)
event = parser.next_event()
while not isinstance(event, (Epilogue, NeedData)):
if isinstance(event, Field):
current_part = event
container = []
_write = container.append
elif isinstance(event, File):
current_part = event
container = self.start_file_streaming(event, content_length)
_write = container.write
elif isinstance(event, Data):
_write(event.data)
if not event.more_data:
if isinstance(current_part, Field):
value = b"".join(container).decode(
self.get_part_charset(current_part.headers), self.errors
)
fields.append((current_part.name, value))
else:
container = t.cast(t.BinaryIO, container)
container.seek(0)
files.append(
(
current_part.name,
FileStorage(
container,
current_part.filename,
current_part.name,
headers=current_part.headers,
),
)
)

event = parser.next_event()

return self.cls(fields), self.cls(files)
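Putting the rewritten `parse()` together, a sketch of calling it directly (normally `parse_form_data()` in this module does the equivalent wiring from a WSGI environ; the one-part body here is an illustrative assumption):

```python
from io import BytesIO

from werkzeug.formparser import MultiPartParser

body = (
    b"--boundary\r\n"
    b'Content-Disposition: form-data; name="name"\r\n'
    b"\r\n"
    b"value\r\n"
    b"--boundary--\r\n"
)

parser = MultiPartParser()
form, files = parser.parse(BytesIO(body), b"boundary", len(body))
print(form["name"])  # 'value'
```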