diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index 09e99264..c9de2804 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -122,6 +122,7 @@ def get_override_parameters(self): def _process_override_parameters(self): result = [] + exclusions = [] for parameter in self.get_override_parameters(): if isinstance(parameter, OpenApiParameter): if is_basic_type(parameter.type): @@ -130,16 +131,20 @@ def _process_override_parameters(self): schema = self.resolve_serializer(parameter.type, 'request').ref else: schema = parameter.type - result.append(build_parameter_type( - name=parameter.name, - schema=schema, - location=parameter.location, - required=parameter.required, - description=parameter.description, - enum=parameter.enum, - deprecated=parameter.deprecated, - examples=build_examples_list(parameter.examples), - )) + + if parameter.exclude: + exclusions.append((parameter.name, parameter.location)) + else: + result.append(build_parameter_type( + name=parameter.name, + schema=schema, + location=parameter.location, + required=parameter.required, + description=parameter.description, + enum=parameter.enum, + deprecated=parameter.deprecated, + examples=build_examples_list(parameter.examples), + )) elif is_serializer(parameter): # explode serializer into separate parameters. defaults to QUERY location mapped = self._map_serializer(parameter, 'request') @@ -152,7 +157,7 @@ def _process_override_parameters(self): )) else: warn(f'could not resolve parameter annotation {parameter}. skipping.') - return result + return result, exclusions def _get_format_parameters(self): parameters = [] @@ -170,7 +175,8 @@ def _get_parameters(self): def dict_helper(parameters): return {(p['name'], p['in']): p for p in parameters} - override_parameters = dict_helper(self._process_override_parameters()) + override_parameters, excluded_parameters = self._process_override_parameters() + override_parameters = dict_helper(override_parameters) # remove overridden path parameters beforehand so that there are no irrelevant warnings. path_variables = [ v for v in uritemplate.variables(self.path) if (v, 'path') not in override_parameters @@ -182,6 +188,8 @@ def dict_helper(parameters): **dict_helper(self._get_format_parameters()), } # override/add @extend_schema parameters + for key in excluded_parameters: + del parameters[key] for key, parameter in override_parameters.items(): parameters[key] = parameter diff --git a/drf_spectacular/utils.py b/drf_spectacular/utils.py index dc14a9b9..4df26c5c 100644 --- a/drf_spectacular/utils.py +++ b/drf_spectacular/utils.py @@ -108,7 +108,8 @@ def __init__( description='', enum=None, deprecated=False, - examples: Optional[List[OpenApiExample]] = None + examples: Optional[List[OpenApiExample]] = None, + exclude=False, ): self.name = name self.type = type @@ -118,6 +119,7 @@ def __init__( self.enum = enum self.deprecated = deprecated self.examples = examples or [] + self.exclude = exclude def extend_schema( diff --git a/tests/test_regressions.py b/tests/test_regressions.py index 60ba2de0..06b23657 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -19,7 +19,7 @@ from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import ( OpenApiParameter, extend_schema, extend_schema_field, extend_schema_serializer, - inline_serializer, + extend_schema_view, inline_serializer, ) from drf_spectacular.validation import validate_schema from tests import generate_schema, get_request_schema, get_response_schema @@ -1139,3 +1139,29 @@ def view_func(request, format=None): schema = generate_schema('x', view_function=view_func) assert get_response_schema(schema['paths']['/x']['get'])['type'] == 'string' + + +def test_exclude_discovered_parameter(no_warnings): + class M8(models.Model): + pass + + class XSerializer(serializers.ModelSerializer): + class Meta: + fields = '__all__' + model = M8 + + @extend_schema_view(list=extend_schema(parameters=[ + # keep 'offset', remove 'limit', and add 'random' + OpenApiParameter('limit', exclude=True), + OpenApiParameter('random', bool), + ])) + class XViewset(viewsets.ReadOnlyModelViewSet): + queryset = M8.objects.all() + serializer_class = XSerializer + pagination_class = pagination.LimitOffsetPagination + + schema = generate_schema('x', XViewset) + parameters = schema['paths']['/x/']['get']['parameters'] + assert len(parameters) == 2 + assert parameters[0]['name'] == 'offset' + assert parameters[1]['name'] == 'random'