Skip to content

Commit

Permalink
Fix choices in ChoiceField to support IntEnum
Browse files Browse the repository at this point in the history
Python support Enum in version 3.4, but changed __str__ to int.__str__ until version 3.11 to better support the replacement of existing constants use-case.
[https://docs.python.org/3/library/enum.html#enum.IntEnum](https://docs.python.org/3/library/enum.html#enum.IntEnum)

rest_frame work support Python 3.6+, this commit will support the Enum in choices of Field.
  • Loading branch information
b7wch committed Jul 13, 2023
1 parent 4f7e9ed commit 6086601
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 11 deletions.
17 changes: 6 additions & 11 deletions rest_framework/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import re
import uuid
from collections.abc import Mapping
from enum import Enum

from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
Expand All @@ -17,7 +18,6 @@
MinValueValidator, ProhibitNullCharactersValidator, RegexValidator,
URLValidator, ip_address_validators
)
from django.db.models import IntegerChoices, TextChoices
from django.forms import FilePathField as DjangoFilePathField
from django.forms import ImageField as DjangoImageField
from django.utils import timezone
Expand Down Expand Up @@ -1401,11 +1401,8 @@ def __init__(self, choices, **kwargs):
def to_internal_value(self, data):
if data == '' and self.allow_blank:
return ''

if isinstance(data, (IntegerChoices, TextChoices)) and str(data) != \
str(data.value):
if isinstance(data, Enum) and str(data) != str(data.value):
data = data.value

try:
return self.choice_strings_to_values[str(data)]
except KeyError:
Expand All @@ -1414,11 +1411,8 @@ def to_internal_value(self, data):
def to_representation(self, value):
if value in ('', None):
return value

if isinstance(value, (IntegerChoices, TextChoices)) and str(value) != \
str(value.value):
if isinstance(value, Enum) and str(value) != str(value.value):
value = value.value

return self.choice_strings_to_values.get(str(value), value)

def iter_options(self):
Expand All @@ -1442,8 +1436,7 @@ def _set_choices(self, choices):
# Allows us to deal with eg. integer choices while supporting either
# integer or string input, but still get the correct datatype out.
self.choice_strings_to_values = {
str(key.value) if isinstance(key, (IntegerChoices, TextChoices))
and str(key) != str(key.value) else str(key): key for key in self.choices
str(key.value) if isinstance(key, Enum) and str(key) != str(key.value) else str(key): key for key in self.choices
}

choices = property(_get_choices, _set_choices)
Expand Down Expand Up @@ -1829,6 +1822,7 @@ class HiddenField(Field):
constraint on a pair of fields, as we need some way to include the date in
the validated data.
"""

def __init__(self, **kwargs):
assert 'default' in kwargs, 'default is a required argument.'
kwargs['write_only'] = True
Expand Down Expand Up @@ -1858,6 +1852,7 @@ class ExampleSerializer(Serializer):
def get_extra_info(self, obj):
return ... # Calculate some data to return.
"""

def __init__(self, method_name=None, **kwargs):
self.method_name = method_name
kwargs['source'] = '*'
Expand Down
27 changes: 27 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -1875,6 +1875,33 @@ def test_edit_choices(self):
field.run_validation(2)
assert exc_info.value.detail == ['"2" is not a valid choice.']

def test_enum_integer_choices(self):
from enum import IntEnum

class ChoiceCase(IntEnum):
first = auto()
second = auto()
# Enum validate
choices = [
(ChoiceCase.first, "1"),
(ChoiceCase.second, "2")
]

field = serializers.ChoiceField(choices=choices)
assert field.run_validation(1) == 1
assert field.run_validation(ChoiceCase.first) == 1
assert field.run_validation("1") == 1
# Enum.value validate
choices = [
(ChoiceCase.first.value, "1"),
(ChoiceCase.second.value, "2")
]

field = serializers.ChoiceField(choices=choices)
assert field.run_validation(1) == 1
assert field.run_validation(ChoiceCase.first) == 1
assert field.run_validation("1") == 1

def test_integer_choices(self):
class ChoiceCase(IntegerChoices):
first = auto()
Expand Down

0 comments on commit 6086601

Please sign in to comment.