Skip to content

Commit

Permalink
Merge pull request #689 from glennmatthews/gfm-issue-688
Browse files Browse the repository at this point in the history
Avoid a TypeError when ChoiceFilter choices are a callable
  • Loading branch information
tfranzel authored Mar 23, 2022
2 parents bc507cd + acf26a4 commit 8ff0d6b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
6 changes: 5 additions & 1 deletion drf_spectacular/contrib/django_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,11 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
# enrich schema with additional info from filter_field
enum = schema.pop('enum', None)
if 'choices' in filter_field.extra:
enum = [c for c, _ in filter_field.extra['choices']]
if callable(filter_field.extra['choices']):
# choices function may utilize the DB, so refrain from actually calling it.
enum = None
else:
enum = [c for c, _ in filter_field.extra['choices']]
if enum:
schema['enum'] = sorted(enum, key=str)

Expand Down
12 changes: 10 additions & 2 deletions tests/contrib/test_django_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

try:
from django_filters.rest_framework import (
AllValuesFilter, BaseInFilter, BooleanFilter, CharFilter, DjangoFilterBackend, FilterSet,
ModelChoiceFilter, ModelMultipleChoiceFilter, MultipleChoiceFilter, NumberFilter,
AllValuesFilter, BaseInFilter, BooleanFilter, CharFilter, ChoiceFilter, DjangoFilterBackend,
FilterSet, ModelChoiceFilter, ModelMultipleChoiceFilter, MultipleChoiceFilter, NumberFilter,
NumericRangeFilter, OrderingFilter, RangeFilter, UUIDFilter,
)
except ImportError:
Expand All @@ -30,6 +30,7 @@ def init(self, **kwargs):
pass

CharFilter = NumberFilter
ChoiceFilter = NumberFilter
OrderingFilter = NumberFilter
BaseInFilter = NumberFilter
BooleanFilter = NumberFilter
Expand Down Expand Up @@ -117,6 +118,10 @@ class ProductFilter(FilterSet):
RangeFilter(field_name='price_vat')
)

def get_choices(*args, **kwargs):
return (('A', 'aaa'),)
cat_callable = ChoiceFilter(field_name="category", choices=get_choices)

class Meta:
model = Product
fields = [
Expand Down Expand Up @@ -218,6 +223,9 @@ def test_django_filters_requests(no_warnings):
response = APIClient().get('/api/products/?multi_cat=A&multi_cat=B')
assert response.status_code == 200, response.content
assert len(response.json()) == 1
response = APIClient().get('/api/products/?cat_callable=A')
assert response.status_code == 200, response.content
assert len(response.json()) == 1


@pytest.mark.contrib('django_filter')
Expand Down
5 changes: 5 additions & 0 deletions tests/contrib/test_django_filters.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ paths:
schema:
type: number
format: double
- in: query
name: cat_callable
schema:
type: string
description: some category description
- in: query
name: category
schema:
Expand Down

0 comments on commit 8ff0d6b

Please sign in to comment.