Skip to content

Commit

Permalink
Response headers support
Browse files Browse the repository at this point in the history
  • Loading branch information
p1c2u committed May 19, 2021
1 parent c4fab4c commit d06d03f
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 58 deletions.
27 changes: 27 additions & 0 deletions openapi_core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,33 @@ class OpenAPIError(Exception):
pass


class OpenAPIHeaderError(OpenAPIError):
pass


class MissingHeaderError(OpenAPIHeaderError):
"""Missing header error"""
pass


@attr.s(hash=True)
class MissingHeader(MissingHeaderError):
name = attr.ib()

def __str__(self):
return "Missing header (without default value): {0}".format(
self.name)


@attr.s(hash=True)
class MissingRequiredHeader(MissingHeaderError):
name = attr.ib()

def __str__(self):
return "Missing required header: {0}".format(self.name)



class OpenAPIParameterError(OpenAPIError):
pass

Expand Down
50 changes: 35 additions & 15 deletions openapi_core/schema/parameters.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,54 @@
from __future__ import division


def get_aslist(param):
"""Checks if parameter is described as list for simpler scenarios"""
def get_aslist(param_or_header):
"""Checks if parameter/header is described as list for simpler scenarios"""
# if schema is not defined it's a complex scenario
if 'schema' not in param:
if 'schema' not in param_or_header:
return False

param_schema = param / 'schema'
schema_type = param_schema.getkey('type', 'any')
schema = param_or_header / 'schema'
schema_type = schema.getkey('type', 'any')
# TODO: resolve for 'any' schema type
return schema_type in ['array', 'object']


def get_style(param):
"""Checks parameter style for simpler scenarios"""
if 'style' in param:
return param['style']
def get_style(param_or_header):
"""Checks parameter/header style for simpler scenarios"""
if 'style' in param_or_header:
return param_or_header['style']

# if "in" not defined then it's a Header
location = param_or_header.getkey('in', 'header')

# determine default
return (
'simple' if param['in'] in ['path', 'header'] else 'form'
'simple' if location in ['path', 'header'] else 'form'
)


def get_explode(param):
"""Checks parameter explode for simpler scenarios"""
if 'explode' in param:
return param['explode']
def get_explode(param_or_header):
"""Checks parameter/header explode for simpler scenarios"""
if 'explode' in param_or_header:
return param_or_header['explode']

# determine default
style = get_style(param)
style = get_style(param_or_header)
return style == 'form'


def get_value(param_or_header, location, name=None):
"""Returns parameter/header value from specific location"""
name = name or param_or_header['name']

if name not in location:
raise KeyError

aslist = get_aslist(param_or_header)
explode = get_explode(param_or_header)
if aslist and explode:
if hasattr(location, 'getall'):
return location.getall(name)
return location.getlist(name)

return location[name]
5 changes: 4 additions & 1 deletion openapi_core/testing/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
class MockResponseFactory(object):

@classmethod
def create(cls, data, status_code=200, mimetype='application/json'):
def create(
cls, data, status_code=200, headers=None,
mimetype='application/json'):
return OpenAPIResponse(
data=data,
status_code=status_code,
headers=headers or {},
mimetype=mimetype,
)
58 changes: 23 additions & 35 deletions openapi_core/validation/request/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@

from openapi_core.casting.schemas.exceptions import CastError
from openapi_core.deserializing.exceptions import DeserializeError
from openapi_core.deserializing.parameters.factories import (
ParameterDeserializersFactory,
)
from openapi_core.exceptions import (
MissingRequiredParameter, MissingParameter,
MissingRequiredRequestBody, MissingRequestBody,
Expand Down Expand Up @@ -46,10 +43,6 @@ def schema_unmarshallers_factory(self):
def security_provider_factory(self):
return SecurityProviderFactory()

@property
def parameter_deserializers_factory(self):
return ParameterDeserializersFactory()

def validate(self, request):
try:
path, operation, _, path_result, _ = self._find_path(request)
Expand Down Expand Up @@ -177,35 +170,23 @@ def _get_parameters(self, request, params):
return RequestParameters(**locations), errors

def _get_parameter(self, param, request):
if param.getkey('deprecated', False):
name = param['name']
deprecated = param.getkey('deprecated', False)
if deprecated:
warnings.warn(
"{0} parameter is deprecated".format(param['name']),
"{0} parameter is deprecated".format(name),
DeprecationWarning,
)

param_location = param['in']
location = request.parameters[param_location]
try:
raw_value = self._get_parameter_value(param, request)
except MissingParameter:
if 'schema' not in param:
raise
schema = param / 'schema'
if 'default' not in schema:
raise
casted = schema['default']
else:
# Simple scenario
if 'content' not in param:
deserialised = self._deserialise_parameter(param, raw_value)
schema = param / 'schema'
# Complex scenario
else:
content = param / 'content'
mimetype, media_type = next(content.items())
deserialised = self._deserialise_data(mimetype, raw_value)
schema = media_type / 'schema'
casted = self._cast(schema, deserialised)
unmarshalled = self._unmarshal(schema, casted)
return unmarshalled
return self._get_param_or_header_value(param, location)
except KeyError:
required = param.getkey('required', False)
if required:
raise MissingRequiredParameter(name)
raise MissingParameter(name)

def _get_body(self, request, operation):
if 'requestBody' not in operation:
Expand Down Expand Up @@ -274,13 +255,20 @@ def _get_parameter_value(self, param, request):

return location[param['name']]

def _get_location_value(self, name, param, request):
name = param['name']
param_location = param['in']
location = request.parameters[param_location]
try:
return get_location_value(location, name)
except KeyError:
if header.getkey('required', False):
raise MissingRequiredParameter(name)
raise MissingParameter(name)

def _get_body_value(self, request_body, request):
if not request.body:
if request_body.getkey('required', False):
raise MissingRequiredRequestBody(request)
raise MissingRequestBody(request)
return request.body

def _deserialise_parameter(self, param, value):
deserializer = self.parameter_deserializers_factory.create(param)
return deserializer(value)
6 changes: 4 additions & 2 deletions openapi_core/validation/response/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""OpenAPI core validation response datatypes module"""
import attr
from werkzeug.datastructures import Headers

from openapi_core.validation.datatypes import BaseValidationResult

Expand All @@ -13,14 +14,15 @@ class OpenAPIResponse(object):
The response body, as string.
status_code
The status code as integer.
headers
Response headers as Headers.
mimetype
Lowercase content type without charset.
"""

data = attr.ib()
status_code = attr.ib()

mimetype = attr.ib()
headers = attr.ib(factory=Headers, converter=Headers)


@attr.s
Expand Down
48 changes: 43 additions & 5 deletions openapi_core/validation/response/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from openapi_core.casting.schemas.exceptions import CastError
from openapi_core.deserializing.exceptions import DeserializeError
from openapi_core.exceptions import MissingResponseContent
from openapi_core.exceptions import (
MissingHeader, MissingRequiredHeader, MissingResponseContent,
)
from openapi_core.templating.media_types.exceptions import MediaTypeFinderError
from openapi_core.templating.paths.exceptions import PathError
from openapi_core.templating.responses.exceptions import ResponseFinderError
Expand Down Expand Up @@ -117,12 +119,48 @@ def _get_data(self, response, operation_response):
return data, []

def _get_headers(self, response, operation_response):
errors = []
if 'headers' not in operation_response:
return {}, []

# @todo: implement
headers = {}
headers = operation_response / 'headers'

return headers, errors
errors = []
validated = {}
for name, header in headers.items():
# ignore Content-Type header
if name == "Content-Type":
continue
try:
value = self._get_header(name, header, response)
except MissingHeader:
continue
except (
MissingRequiredHeader, DeserializeError,
CastError, ValidateError, UnmarshalError,
) as exc:
errors.append(exc)
continue
else:
validated[name] = value

return validated, errors

def _get_header(self, name, header, response):
deprecated = header.getkey('deprecated', False)
if deprecated:
warnings.warn(
"{0} header is deprecated".format(name),
DeprecationWarning,
)

try:
return self._get_param_or_header_value(
header, response.headers, name=name)
except KeyError:
required = header.getkey('required', False)
if required:
raise MissingRequiredHeader(name)
raise MissingHeader(name)

def _get_data_value(self, response):
if not response.data:
Expand Down
39 changes: 39 additions & 0 deletions openapi_core/validation/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from openapi_core.deserializing.media_types.factories import (
MediaTypeDeserializersFactory,
)
from openapi_core.deserializing.parameters.factories import (
ParameterDeserializersFactory,
)
from openapi_core.schema.parameters import get_value
from openapi_core.templating.paths.finders import PathFinder
from openapi_core.unmarshalling.schemas.util import build_format_checker

Expand Down Expand Up @@ -36,6 +40,10 @@ def media_type_deserializers_factory(self):
return MediaTypeDeserializersFactory(
self.custom_media_type_deserializers)

@property
def parameter_deserializers_factory(self):
return ParameterDeserializersFactory()

@property
def schema_unmarshallers_factory(self):
raise NotImplementedError
Expand All @@ -52,10 +60,41 @@ def _deserialise_data(self, mimetype, value):
deserializer = self.media_type_deserializers_factory.create(mimetype)
return deserializer(value)

def _deserialise_parameter(self, param, value):
deserializer = self.parameter_deserializers_factory.create(param)
return deserializer(value)

def _cast(self, schema, value):
caster = self.schema_casters_factory.create(schema)
return caster(value)

def _unmarshal(self, schema, value):
unmarshaller = self.schema_unmarshallers_factory.create(schema)
return unmarshaller(value)

def _get_param_or_header_value(self, param_or_header, location, name=None):
required = param_or_header.getkey('required', False)
try:
raw_value = get_value(param_or_header, location, name=name)
except KeyError:
if 'schema' not in param_or_header:
raise
schema = param_or_header / 'schema'
if 'default' not in schema:
raise
casted = schema['default']
else:
# Simple scenario
if 'content' not in param_or_header:
deserialised = self._deserialise_parameter(
param_or_header, raw_value)
schema = param_or_header / 'schema'
# Complex scenario
else:
content = param_or_header / 'content'
mimetype, media_type = next(content.items())
deserialised = self._deserialise_data(mimetype, raw_value)
schema = media_type / 'schema'
casted = self._cast(schema, deserialised)
unmarshalled = self._unmarshal(schema, casted)
return unmarshalled

0 comments on commit d06d03f

Please sign in to comment.