diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index 800cddc5..62283da3 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -391,11 +391,11 @@ def _follow_field_source(model, path: List[str]): else: if isinstance(field_or_property, (property, cached_property)) or callable(field_or_property): if isinstance(field_or_property, property): - target_model = typing.get_type_hints(field_or_property.fget).get('return') + target_model = _follow_return_type(field_or_property.fget) elif isinstance(field_or_property, cached_property): - target_model = typing.get_type_hints(field_or_property.func).get('return') + target_model = _follow_return_type(field_or_property.func) else: - target_model = typing.get_type_hints(field_or_property).get('return') + target_model = _follow_return_type(field_or_property) if not target_model: raise UnableToProceedError( f'could not follow field source through intermediate property "{path[0]}" ' @@ -408,6 +408,25 @@ def _follow_field_source(model, path: List[str]): return _follow_field_source(target_model, path[1:]) +def _follow_return_type(a_callable): + target_type = typing.get_type_hints(a_callable).get('return') + if target_type is None: + return target_type + origin, args = _get_type_hint_origin(target_type) + if origin is typing.Union: + type_args = [arg for arg in args if arg is not type(None)] # noqa: E721 + if len(type_args) > 1: + warn( + f'could not traverse Union type, because we don\'t know which type to choose ' + f'from {type_args}. Consider terminating "source" on a custom property ' + f'that indicates the expected Optional/Union type. Defaulting to "string"' + ) + return target_type + # Optional: + return type_args[0] + return target_type + + def follow_field_source(model, path): """ a model traversal chain "foreignkey.foreignkey.value" can either end with an actual model field diff --git a/tests/test_fields.py b/tests/test_fields.py index 21009266..fd899f82 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -2,6 +2,7 @@ import uuid from datetime import date, datetime, timedelta from decimal import Decimal +from typing import Optional import pytest from django import __version__ as DJANGO_VERSION @@ -48,6 +49,10 @@ def nested(self) -> 'SubObject': def model_instance(self) -> 'AllFields': return self._instance + @property + def optional_int(self) -> Optional[int]: + return 1 + class AllFields(models.Model): # basics @@ -127,6 +132,10 @@ def sub_object(self) -> SubObject: def sub_object_cached(self) -> SubObject: return SubObject(self) + @property + def optional_sub_object(self) -> Optional[SubObject]: + return SubObject(self) + class AllFieldsSerializer(serializers.ModelSerializer): field_decimal_uncoerced = serializers.DecimalField( @@ -204,6 +213,16 @@ def get_field_method_object(self, obj) -> dict: field_sub_object_cached_nested_calculated = serializers.ReadOnlyField(source='sub_object_cached.nested.calculated') field_sub_object_cached_model_int = serializers.ReadOnlyField(source='sub_object_cached.model_instance.field_int') + # typing.Optional + field_optional_sub_object_calculated = serializers.ReadOnlyField( + source='optional_sub_object.calculated', + allow_null=True, + ) + field_sub_object_optional_int = serializers.ReadOnlyField( + source='sub_object.optional_int', + allow_null=True, + ) + class Meta: fields = '__all__' model = AllFields diff --git a/tests/test_fields.yml b/tests/test_fields.yml index 0ecb6013..4c478a7e 100644 --- a/tests/test_fields.yml +++ b/tests/test_fields.yml @@ -192,6 +192,14 @@ components: field_sub_object_cached_model_int: type: integer readOnly: true + field_optional_sub_object_calculated: + type: integer + readOnly: true + nullable: true + field_sub_object_optional_int: + type: integer + nullable: true + readOnly: true field_int: type: integer field_float: @@ -300,6 +308,7 @@ components: - field_model_cached_property_float - field_model_property_float - field_o2o + - field_optional_sub_object_calculated - field_posint - field_possmallint - field_read_only_model_function_basic @@ -318,6 +327,7 @@ components: - field_sub_object_calculated - field_sub_object_model_int - field_sub_object_nested_calculated + - field_sub_object_optional_int - field_text - field_time - field_url diff --git a/tests/test_warnings.py b/tests/test_warnings.py index 34954aec..0af636c7 100644 --- a/tests/test_warnings.py +++ b/tests/test_warnings.py @@ -1,3 +1,4 @@ +from typing import Union from unittest import mock import pytest @@ -209,6 +210,37 @@ def get(self, request): assert 'XAPIView: XSerializer: unable to resolve type hint for function "get_y"' in stderr +def test_unable_to_traverse_union_type_hint(capsys): + class Foo: + foo_value: int = 1 + + class Bar: + pass + + class FailingFieldSourceTraversalModel3(models.Model): + @property + def foo_or_bar(self) -> Union[Foo, Bar]: + pass # pragma: no cover + + class XSerializer(serializers.ModelSerializer): + foo_value = serializers.ReadOnlyField(source='foo_or_bar.foo_value') + + class Meta: + model = FailingFieldSourceTraversalModel3 + fields = '__all__' + + class XAPIView(APIView): + @extend_schema(responses=XSerializer) + def get(self, request): + pass # pragma: no cover + + generate_schema('foo_value', view=XAPIView) + stderr = capsys.readouterr().err + assert 'could not traverse Union type' in stderr + assert 'Foo' in stderr + assert 'Bar' in stderr + + def test_operation_id_collision_resolution(capsys): @extend_schema(responses=OpenApiTypes.FLOAT) @api_view(['GET'])