diff --git a/tests/test_extensions.py b/tests/test_extensions.py index b154b48a..d32f3bd3 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING from unittest import mock +from django.contrib.auth.models import User from rest_framework import fields, mixins, pagination, permissions, serializers, viewsets from rest_framework.authentication import BaseAuthentication from rest_framework.decorators import api_view @@ -15,7 +16,7 @@ ResolvedComponent, build_array_type, build_basic_type, build_object_type, force_instance, ) from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import Direction, extend_schema, extend_schema_field +from drf_spectacular.utils import Direction, extend_schema, extend_schema_field, extend_schema_view from tests import generate_schema, get_response_schema from tests.models import SimpleModel, SimpleSerializer @@ -284,7 +285,7 @@ def __init__(self, serializer_class, pagination_class, **kwargs): super().__init__(**kwargs) class PaginationWrapperExtension(OpenApiSerializerExtension): - target_class = PaginationWrapper + target_class = PaginationWrapper # this can also be an import string def get_name(self, auto_schema, direction): return auto_schema.get_paginated_name( @@ -326,3 +327,66 @@ class XViewset(viewsets.ModelViewSet): assert schema['components']['schemas']['PaginatedSimpleList']['properties']['results'] == { '$ref': '#/components/schemas/Simple' } + + +def test_serializer_with_dynamic_fields(no_warnings): + class DynamicFieldsModelSerializer(serializers.ModelSerializer): + """ + A ModelSerializer that takes an additional `fields` argument that + controls which fields should be displayed. + + Taken from (only added ref_name) + https://www.django-rest-framework.org/api-guide/serializers/#dynamically-modifying-fields + """ + def __init__(self, *args, **kwargs): + # Don't pass the 'fields' arg up to the superclass + fields = kwargs.pop('fields', None) + self.ref_name = kwargs.pop('ref_name', None) # only change to original version! + + # Instantiate the superclass normally + super().__init__(*args, **kwargs) + + if fields is not None: + # Drop any fields that are not specified in the `fields` argument. + allowed = set(fields) + existing = set(self.fields) + for field_name in existing - allowed: + self.fields.pop(field_name) + + class UserSerializer(DynamicFieldsModelSerializer): + class Meta: + model = User + fields = ['id', 'username', 'email'] + + class DynamicFieldsModelSerializerExtension(OpenApiSerializerExtension): + target_class = DynamicFieldsModelSerializer # this can also be an import string + match_subclasses = True + + def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction): + return auto_schema._map_serializer(self.target, direction, bypass_extensions=True) + + def get_name(self, auto_schema, direction): + return self.target.ref_name + + @extend_schema_view( + list=extend_schema(responses=UserSerializer(fields=['id'], ref_name='CompactUser')) + ) + class XViewset(viewsets.ModelViewSet): + serializer_class = UserSerializer + queryset = User.objects.none() + + schema = generate_schema('x', XViewset) + + assert schema['components']['schemas']['User']['properties'] == { + 'id': {'type': 'integer', 'readOnly': True}, + 'username': { + 'type': 'string', + 'description': 'Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.', + 'pattern': '^[\\w.@+-]+$', + 'maxLength': 150 + }, + 'email': {'type': 'string', 'format': 'email', 'title': 'Email address', 'maxLength': 254} + } + assert schema['components']['schemas']['CompactUser']['properties'] == { + 'id': {'type': 'integer', 'readOnly': True} + }