Skip to content

Commit

Permalink
Merge pull request #350 from spookylukey/fix_optional_as_nullable
Browse files Browse the repository at this point in the history
Fixed traversing of 'Optional' type annotations
  • Loading branch information
tfranzel authored Mar 30, 2021
2 parents 2a9cf74 + 5b17e65 commit 7ac1814
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 3 deletions.
25 changes: 22 additions & 3 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,11 +391,11 @@ def _follow_field_source(model, path: List[str]):
else:
if isinstance(field_or_property, (property, cached_property)) or callable(field_or_property):
if isinstance(field_or_property, property):
target_model = typing.get_type_hints(field_or_property.fget).get('return')
target_model = _follow_return_type(field_or_property.fget)
elif isinstance(field_or_property, cached_property):
target_model = typing.get_type_hints(field_or_property.func).get('return')
target_model = _follow_return_type(field_or_property.func)
else:
target_model = typing.get_type_hints(field_or_property).get('return')
target_model = _follow_return_type(field_or_property)
if not target_model:
raise UnableToProceedError(
f'could not follow field source through intermediate property "{path[0]}" '
Expand All @@ -408,6 +408,25 @@ def _follow_field_source(model, path: List[str]):
return _follow_field_source(target_model, path[1:])


def _follow_return_type(a_callable):
target_type = typing.get_type_hints(a_callable).get('return')
if target_type is None:
return target_type
origin, args = _get_type_hint_origin(target_type)
if origin is typing.Union:
type_args = [arg for arg in args if arg is not type(None)] # noqa: E721
if len(type_args) > 1:
warn(
f'could not traverse Union type, because we don\'t know which type to choose '
f'from {type_args}. Consider terminating "source" on a custom property '
f'that indicates the expected Optional/Union type. Defaulting to "string"'
)
return target_type
# Optional:
return type_args[0]
return target_type


def follow_field_source(model, path):
"""
a model traversal chain "foreignkey.foreignkey.value" can either end with an actual model field
Expand Down
19 changes: 19 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import uuid
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Optional

import pytest
from django import __version__ as DJANGO_VERSION
Expand Down Expand Up @@ -48,6 +49,10 @@ def nested(self) -> 'SubObject':
def model_instance(self) -> 'AllFields':
return self._instance

@property
def optional_int(self) -> Optional[int]:
return 1


class AllFields(models.Model):
# basics
Expand Down Expand Up @@ -127,6 +132,10 @@ def sub_object(self) -> SubObject:
def sub_object_cached(self) -> SubObject:
return SubObject(self)

@property
def optional_sub_object(self) -> Optional[SubObject]:
return SubObject(self)


class AllFieldsSerializer(serializers.ModelSerializer):
field_decimal_uncoerced = serializers.DecimalField(
Expand Down Expand Up @@ -204,6 +213,16 @@ def get_field_method_object(self, obj) -> dict:
field_sub_object_cached_nested_calculated = serializers.ReadOnlyField(source='sub_object_cached.nested.calculated')
field_sub_object_cached_model_int = serializers.ReadOnlyField(source='sub_object_cached.model_instance.field_int')

# typing.Optional
field_optional_sub_object_calculated = serializers.ReadOnlyField(
source='optional_sub_object.calculated',
allow_null=True,
)
field_sub_object_optional_int = serializers.ReadOnlyField(
source='sub_object.optional_int',
allow_null=True,
)

class Meta:
fields = '__all__'
model = AllFields
Expand Down
10 changes: 10 additions & 0 deletions tests/test_fields.yml
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,14 @@ components:
field_sub_object_cached_model_int:
type: integer
readOnly: true
field_optional_sub_object_calculated:
type: integer
readOnly: true
nullable: true
field_sub_object_optional_int:
type: integer
nullable: true
readOnly: true
field_int:
type: integer
field_float:
Expand Down Expand Up @@ -300,6 +308,7 @@ components:
- field_model_cached_property_float
- field_model_property_float
- field_o2o
- field_optional_sub_object_calculated
- field_posint
- field_possmallint
- field_read_only_model_function_basic
Expand All @@ -318,6 +327,7 @@ components:
- field_sub_object_calculated
- field_sub_object_model_int
- field_sub_object_nested_calculated
- field_sub_object_optional_int
- field_text
- field_time
- field_url
Expand Down
32 changes: 32 additions & 0 deletions tests/test_warnings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Union
from unittest import mock

import pytest
Expand Down Expand Up @@ -209,6 +210,37 @@ def get(self, request):
assert 'XAPIView: XSerializer: unable to resolve type hint for function "get_y"' in stderr


def test_unable_to_traverse_union_type_hint(capsys):
class Foo:
foo_value: int = 1

class Bar:
pass

class FailingFieldSourceTraversalModel3(models.Model):
@property
def foo_or_bar(self) -> Union[Foo, Bar]:
pass # pragma: no cover

class XSerializer(serializers.ModelSerializer):
foo_value = serializers.ReadOnlyField(source='foo_or_bar.foo_value')

class Meta:
model = FailingFieldSourceTraversalModel3
fields = '__all__'

class XAPIView(APIView):
@extend_schema(responses=XSerializer)
def get(self, request):
pass # pragma: no cover

generate_schema('foo_value', view=XAPIView)
stderr = capsys.readouterr().err
assert 'could not traverse Union type' in stderr
assert 'Foo' in stderr
assert 'Bar' in stderr


def test_operation_id_collision_resolution(capsys):
@extend_schema(responses=OpenApiTypes.FLOAT)
@api_view(['GET'])
Expand Down

0 comments on commit 7ac1814

Please sign in to comment.