Skip to content

Commit

Permalink
add exclusion for discovered parameters #212
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Dec 19, 2020
1 parent 469c484 commit 283d630
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 14 deletions.
32 changes: 20 additions & 12 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def get_override_parameters(self):

def _process_override_parameters(self):
result = []
exclusions = []
for parameter in self.get_override_parameters():
if isinstance(parameter, OpenApiParameter):
if is_basic_type(parameter.type):
Expand All @@ -130,16 +131,20 @@ def _process_override_parameters(self):
schema = self.resolve_serializer(parameter.type, 'request').ref
else:
schema = parameter.type
result.append(build_parameter_type(
name=parameter.name,
schema=schema,
location=parameter.location,
required=parameter.required,
description=parameter.description,
enum=parameter.enum,
deprecated=parameter.deprecated,
examples=build_examples_list(parameter.examples),
))

if parameter.exclude:
exclusions.append((parameter.name, parameter.location))
else:
result.append(build_parameter_type(
name=parameter.name,
schema=schema,
location=parameter.location,
required=parameter.required,
description=parameter.description,
enum=parameter.enum,
deprecated=parameter.deprecated,
examples=build_examples_list(parameter.examples),
))
elif is_serializer(parameter):
# explode serializer into separate parameters. defaults to QUERY location
mapped = self._map_serializer(parameter, 'request')
Expand All @@ -152,7 +157,7 @@ def _process_override_parameters(self):
))
else:
warn(f'could not resolve parameter annotation {parameter}. skipping.')
return result
return result, exclusions

def _get_format_parameters(self):
parameters = []
Expand All @@ -170,7 +175,8 @@ def _get_parameters(self):
def dict_helper(parameters):
return {(p['name'], p['in']): p for p in parameters}

override_parameters = dict_helper(self._process_override_parameters())
override_parameters, excluded_parameters = self._process_override_parameters()
override_parameters = dict_helper(override_parameters)
# remove overridden path parameters beforehand so that there are no irrelevant warnings.
path_variables = [
v for v in uritemplate.variables(self.path) if (v, 'path') not in override_parameters
Expand All @@ -182,6 +188,8 @@ def dict_helper(parameters):
**dict_helper(self._get_format_parameters()),
}
# override/add @extend_schema parameters
for key in excluded_parameters:
del parameters[key]
for key, parameter in override_parameters.items():
parameters[key] = parameter

Expand Down
4 changes: 3 additions & 1 deletion drf_spectacular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def __init__(
description='',
enum=None,
deprecated=False,
examples: Optional[List[OpenApiExample]] = None
examples: Optional[List[OpenApiExample]] = None,
exclude=False,
):
self.name = name
self.type = type
Expand All @@ -118,6 +119,7 @@ def __init__(
self.enum = enum
self.deprecated = deprecated
self.examples = examples or []
self.exclude = exclude


def extend_schema(
Expand Down
28 changes: 27 additions & 1 deletion tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import (
OpenApiParameter, extend_schema, extend_schema_field, extend_schema_serializer,
inline_serializer,
extend_schema_view, inline_serializer,
)
from drf_spectacular.validation import validate_schema
from tests import generate_schema, get_request_schema, get_response_schema
Expand Down Expand Up @@ -1139,3 +1139,29 @@ def view_func(request, format=None):

schema = generate_schema('x', view_function=view_func)
assert get_response_schema(schema['paths']['/x']['get'])['type'] == 'string'


def test_exclude_discovered_parameter(no_warnings):
class M8(models.Model):
pass

class XSerializer(serializers.ModelSerializer):
class Meta:
fields = '__all__'
model = M8

@extend_schema_view(list=extend_schema(parameters=[
# keep 'offset', remove 'limit', and add 'random'
OpenApiParameter('limit', exclude=True),
OpenApiParameter('random', bool),
]))
class XViewset(viewsets.ReadOnlyModelViewSet):
queryset = M8.objects.all()
serializer_class = XSerializer
pagination_class = pagination.LimitOffsetPagination

schema = generate_schema('x', XViewset)
parameters = schema['paths']['/x/']['get']['parameters']
assert len(parameters) == 2
assert parameters[0]['name'] == 'offset'
assert parameters[1]['name'] == 'random'

0 comments on commit 283d630

Please sign in to comment.