Skip to content

Commit

Permalink
account for functools.partial wrapped type hints #451
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Jul 15, 2021
1 parent cd8f73c commit eb02752
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 13 deletions.
8 changes: 3 additions & 5 deletions drf_spectacular/contrib/django_filters.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import typing

from django.db import models

from drf_spectacular.drainage import warn
from drf_spectacular.extensions import OpenApiFilterExtension
from drf_spectacular.plumbing import (
build_array_type, build_basic_type, build_parameter_type, follow_field_source, get_view_model,
is_basic_type,
build_array_type, build_basic_type, build_parameter_type, follow_field_source, get_type_hints,
get_view_model, is_basic_type,
)
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter
Expand Down Expand Up @@ -167,7 +165,7 @@ def _build_filter_method_type(self, filterset_class, filter_field):
filter_method = getattr(filterset_class, filter_field.method)

try:
filter_method_hints = typing.get_type_hints(filter_method)
filter_method_hints = get_type_hints(filter_method)
except: # noqa: E722
filter_method_hints = {}

Expand Down
5 changes: 5 additions & 0 deletions drf_spectacular/drainage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import functools
import sys
from collections import defaultdict
from typing import DefaultDict
Expand Down Expand Up @@ -70,6 +71,8 @@ def _get_current_trace():


def has_override(obj, prop):
if isinstance(obj, functools.partial):
obj = obj.func
if not hasattr(obj, '_spectacular_annotation'):
return False
if prop not in obj._spectacular_annotation:
Expand All @@ -78,6 +81,8 @@ def has_override(obj, prop):


def get_override(obj, prop, default=None):
if isinstance(obj, functools.partial):
obj = obj.func
if not has_override(obj, prop):
return default
return obj._spectacular_annotation[prop]
Expand Down
8 changes: 4 additions & 4 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
ComponentRegistry, ResolvedComponent, UnableToProceedError, append_meta, build_array_type,
build_basic_type, build_choice_field, build_examples_list, build_media_type_object,
build_object_type, build_parameter_type, error, follow_field_source, force_instance, get_doc,
get_view_model, is_basic_type, is_field, is_list_serializer, is_patched_serializer,
is_serializer, is_trivial_string_variation, resolve_regex_path_parameter, resolve_type_hint,
safe_ref, warn,
get_type_hints, get_view_model, is_basic_type, is_field, is_list_serializer,
is_patched_serializer, is_serializer, is_trivial_string_variation, resolve_regex_path_parameter,
resolve_type_hint, safe_ref, warn,
)
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes, build_generic_type
Expand Down Expand Up @@ -828,7 +828,7 @@ def _map_field_validators(self, field, schema):
schema['minimum'] = -schema['maximum']

def _map_response_type_hint(self, method):
hint = get_override(method, 'field') or typing.get_type_hints(method).get('return')
hint = get_override(method, 'field') or get_type_hints(method).get('return')

if is_serializer(hint) or is_field(hint):
return self._map_serializer_field(force_instance(hint), 'response')
Expand Down
16 changes: 12 additions & 4 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections
import functools
import hashlib
import inspect
import json
Expand Down Expand Up @@ -164,6 +165,13 @@ def safe_index(lst, item):
return ''


def get_type_hints(obj):
""" unpack wrapped partial object and use actual func object """
if isinstance(obj, functools.partial):
obj = obj.func
return typing.get_type_hints(obj)


def build_basic_type(obj):
"""
resolve either enum or actual type and yield schema template for modification
Expand Down Expand Up @@ -409,7 +417,7 @@ def _follow_field_source(model, path: List[str]):


def _follow_return_type(a_callable):
target_type = typing.get_type_hints(a_callable).get('return')
target_type = get_type_hints(a_callable).get('return')
if target_type is None:
return target_type
origin, args = _get_type_hint_origin(target_type)
Expand Down Expand Up @@ -929,8 +937,8 @@ def resolve_type_hint(hint):
return build_basic_type(hint)
elif origin is None and inspect.isclass(hint) and issubclass(hint, tuple):
# a convoluted way to catch NamedTuple. suggestions welcome.
if typing.get_type_hints(hint):
properties = {k: resolve_type_hint(v) for k, v in typing.get_type_hints(hint).items()}
if get_type_hints(hint):
properties = {k: resolve_type_hint(v) for k, v in get_type_hints(hint).items()}
else:
properties = {k: build_basic_type(OpenApiTypes.ANY) for k in hint._fields}
return build_object_type(properties=properties, required=properties.keys())
Expand Down Expand Up @@ -962,7 +970,7 @@ def resolve_type_hint(hint):
elif hasattr(typing, 'TypedDict') and isinstance(hint, typing._TypedDictMeta):
return build_object_type(
properties={
k: resolve_type_hint(v) for k, v in typing.get_type_hints(hint).items()
k: resolve_type_hint(v) for k, v in get_type_hints(hint).items()
}
)
elif origin is typing.Union:
Expand Down
28 changes: 28 additions & 0 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import typing
import uuid
from decimal import Decimal
from functools import partialmethod
from unittest import mock

import pytest
Expand Down Expand Up @@ -2024,3 +2025,30 @@ def get(self, request):
assert get_response_schema(schema['paths']['/x/']['get'])['properties']['results'] == {
'type': 'array', 'items': {'type': 'object', 'additionalProperties': {}}
}


def test_serializer_method_field_with_functools_partial():
class XSerializer(serializers.Serializer):
foo = serializers.SerializerMethodField()
bar = serializers.SerializerMethodField()

@extend_schema_field(OpenApiTypes.DATE)
def _private_method_foo(self, field, extra_param):
return 'foo' # pragma: no cover

def _private_method_bar(self, field, extra_param) -> int:
return 1 # pragma: no cover

get_foo = partialmethod(_private_method_foo, extra_param='foo')
get_bar = partialmethod(_private_method_bar, extra_param='bar')

@extend_schema(request=XSerializer, responses=XSerializer)
@api_view(['POST'])
def view_func(request, format=None):
pass # pragma: no cover

schema = generate_schema('/x/', view_function=view_func)
assert schema['components']['schemas']['X']['properties'] == {
'foo': {'type': 'string', 'format': 'date', 'readOnly': True},
'bar': {'type': 'integer', 'readOnly': True}
}

0 comments on commit eb02752

Please sign in to comment.