Skip to content

Commit

Permalink
Polish form validation (#1604)
Browse files Browse the repository at this point in the history
* Create MediaTypeDict class for range matching

* Extract parsing instead of using jsonifier

* Add default validator class as parameter defaults

* Clean up form validation
  • Loading branch information
RobbeSneyders committed Nov 4, 2022
1 parent 9d7258c commit a1b1f53
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 67 deletions.
40 changes: 22 additions & 18 deletions connexion/middleware/request_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,24 +80,28 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
# TODO: Validate parameters

# Validate body
try:
body_validator = self._validator_map["body"][mime_type] # type: ignore
except KeyError:
logging.info(
f"Skipping validation. No validator registered for content type: "
f"{mime_type}."
)
else:
validator = body_validator(
scope,
receive,
schema=self._operation.body_schema(mime_type),
nullable=utils.is_nullable(self._operation.body_definition(mime_type)),
encoding=encoding,
strict_validation=self.strict_validation,
uri_parser=self._operation._uri_parsing_decorator,
)
receive_fn = await validator.wrapped_receive()
schema = self._operation.body_schema(mime_type)
if schema:
try:
body_validator = self._validator_map["body"][mime_type] # type: ignore
except KeyError:
logging.info(
f"Skipping validation. No validator registered for content type: "
f"{mime_type}."
)
else:
validator = body_validator(
scope,
receive,
schema=schema,
nullable=utils.is_nullable(
self._operation.body_definition(mime_type)
),
encoding=encoding,
strict_validation=self.strict_validation,
uri_parser=self._operation._uri_parsing_decorator,
)
receive_fn = await validator.wrapped_receive()

await self.next_app(scope, receive_fn, send)

Expand Down
96 changes: 47 additions & 49 deletions connexion/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
receive: Receive,
*,
schema: dict,
validator: t.Type[Draft4Validator] = None,
validator: t.Type[Draft4Validator] = Draft4RequestValidator,
nullable=False,
encoding: str,
**kwargs,
Expand All @@ -47,8 +47,7 @@ def __init__(
self.schema = schema
self.has_default = schema.get("default", False)
self.nullable = nullable
validator_cls = validator or Draft4RequestValidator
self.validator = validator_cls(schema, format_checker=draft4_format_checker)
self.validator = validator(schema, format_checker=draft4_format_checker)
self.encoding = encoding
self._messages: t.List[t.MutableMapping[str, t.Any]] = []

Expand All @@ -69,23 +68,25 @@ def validate(self, body: dict):
)
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")

@staticmethod
def parse(body: str) -> dict:
try:
return json.loads(body)
except json.decoder.JSONDecodeError as e:
raise BadRequestProblem(str(e))

async def wrapped_receive(self) -> Receive:
more_body = True
while more_body:
message = await self._receive()
self._messages.append(message)
more_body = message.get("more_body", False)

# TODO: make json library pluggable
bytes_body = b"".join([message.get("body", b"") for message in self._messages])
decoded_body = bytes_body.decode(self.encoding)

if decoded_body and not (self.nullable and is_null(decoded_body)):
try:
body = json.loads(decoded_body)
except json.decoder.JSONDecodeError as e:
raise BadRequestProblem(str(e))

body = self.parse(decoded_body)
self.validate(body)

async def receive() -> t.MutableMapping[str, t.Any]:
Expand All @@ -105,7 +106,7 @@ def __init__(
send: Send,
*,
schema: dict,
validator: t.Type[Draft4Validator] = None,
validator: t.Type[Draft4Validator] = Draft4ResponseValidator,
nullable=False,
encoding: str,
) -> None:
Expand All @@ -114,8 +115,7 @@ def __init__(
self.schema = schema
self.has_default = schema.get("default", False)
self.nullable = nullable
validator_cls = validator or Draft4ResponseValidator
self.validator = validator_cls(schema, format_checker=draft4_format_checker)
self.validator = validator(schema, format_checker=draft4_format_checker)
self.encoding = encoding
self._messages: t.List[t.MutableMapping[str, t.Any]] = []

Expand Down Expand Up @@ -151,7 +151,6 @@ async def send(self, message: t.MutableMapping[str, t.Any]) -> None:
if message["type"] == "http.response.start" or message.get("more_body", False):
return

# TODO: make json library pluggable
bytes_body = b"".join([message.get("body", b"") for message in self._messages])
decoded_body = bytes_body.decode(self.encoding)

Expand Down Expand Up @@ -238,44 +237,43 @@ def validate(self, data: FormData) -> None:
if errors:
raise ExtraParameterProblem(errors, [])

if data:
props = self.schema.get("properties", {})
errs = []
if self.uri_parser is not None:
# TODO: Make more efficient
# Flask splits up file uploads and text input in `files` and `form`,
# while starlette puts them both in `form`
form_keys = {k for k, v in data.items() if isinstance(v, str)}
file_data = {k: v for k, v in data.items() if isinstance(v, UploadFile)}
data = {k: data.getlist(k) for k in form_keys}
data = self.uri_parser.resolve_form(data)
# Add the files again
data.update(file_data)
else:
data = dict(data) # TODO: preserve multi-item?
for k, param_defn in props.items():
if k in data:
if param_defn.get("format", "") == "binary":
# Replace files with empty strings for validation
data[k] = ""
continue

try:
data[k] = coerce_type(param_defn, data[k], "requestBody", k)
except TypeValidationError as e:
logger.exception(e)
errs += [str(e)]
if errs:
raise BadRequestProblem(detail=errs)
props = self.schema.get("properties", {})
errs = []
if self.uri_parser is not None:
# Don't parse file_data
form_data = {}
file_data = {}
for k, v in data.items():
if isinstance(v, str):
form_data[k] = data.getlist(k)
elif isinstance(v, UploadFile):
file_data[k] = data.getlist(k)

data = self.uri_parser.resolve_form(form_data)
# Add the files again
data.update(file_data)
else:
data = {k: data.getlist(k) for k in data}

for k, param_defn in props.items():
if k in data:
if param_defn.get("format", "") == "binary":
# Replace files with empty strings for validation
data[k] = ""
continue

try:
data[k] = coerce_type(param_defn, data[k], "requestBody", k)
except TypeValidationError as e:
logger.exception(e)
errs += [str(e)]

if errs:
raise BadRequestProblem(detail=errs)

self._validate(data)

async def wrapped_receive(self) -> Receive:

if not self.schema:
# swagger 2
return self._receive

async def stream() -> t.AsyncGenerator[bytes, None]:
more_body = True
while more_body:
Expand All @@ -288,8 +286,8 @@ async def stream() -> t.AsyncGenerator[bytes, None]:
form_parser = self.form_parser_cls(self.headers, stream())
form = await form_parser.parse()

if not (self.nullable and is_null(form)):
self.validate(form or {})
if form and not (self.nullable and is_null(form)):
self.validate(form)

async def receive() -> t.MutableMapping[str, t.Any]:
while self._messages:
Expand Down

0 comments on commit a1b1f53

Please sign in to comment.