Skip to content

Commit

Permalink
Fix choices in ChoiceField to support IntEnum
Browse files Browse the repository at this point in the history
Python support Enum in version 3.4, but changed __str__ to int.__str__ until version 3.11 to better support the replacement of existing constants use-case.
[https://docs.python.org/3/library/enum.html#enum.IntEnum](https://docs.python.org/3/library/enum.html#enum.IntEnum)

rest_frame work support Python 3.6+, this commit will support the Enum in choices of Field.
  • Loading branch information
b7wch committed Jul 6, 2023
1 parent 4f7e9ed commit a4e1038
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 17 deletions.
17 changes: 6 additions & 11 deletions rest_framework/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import re
import uuid
from collections.abc import Mapping
from enum import Enum

from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
Expand All @@ -17,7 +18,6 @@
MinValueValidator, ProhibitNullCharactersValidator, RegexValidator,
URLValidator, ip_address_validators
)
from django.db.models import IntegerChoices, TextChoices
from django.forms import FilePathField as DjangoFilePathField
from django.forms import ImageField as DjangoImageField
from django.utils import timezone
Expand Down Expand Up @@ -1401,11 +1401,8 @@ def __init__(self, choices, **kwargs):
def to_internal_value(self, data):
if data == '' and self.allow_blank:
return ''

if isinstance(data, (IntegerChoices, TextChoices)) and str(data) != \
str(data.value):
if isinstance(data, Enum) and str(data) != str(data.value):
data = data.value

try:
return self.choice_strings_to_values[str(data)]
except KeyError:
Expand All @@ -1414,11 +1411,8 @@ def to_internal_value(self, data):
def to_representation(self, value):
if value in ('', None):
return value

if isinstance(value, (IntegerChoices, TextChoices)) and str(value) != \
str(value.value):
if isinstance(value, Enum) and str(value) != str(value.value):
value = value.value

return self.choice_strings_to_values.get(str(value), value)

def iter_options(self):
Expand All @@ -1442,8 +1436,7 @@ def _set_choices(self, choices):
# Allows us to deal with eg. integer choices while supporting either
# integer or string input, but still get the correct datatype out.
self.choice_strings_to_values = {
str(key.value) if isinstance(key, (IntegerChoices, TextChoices))
and str(key) != str(key.value) else str(key): key for key in self.choices
str(key.value) if isinstance(key, Enum) and str(key) != str(key.value) else str(key): key for key in self.choices
}

choices = property(_get_choices, _set_choices)
Expand Down Expand Up @@ -1829,6 +1822,7 @@ class HiddenField(Field):
constraint on a pair of fields, as we need some way to include the date in
the validated data.
"""

def __init__(self, **kwargs):
assert 'default' in kwargs, 'default is a required argument.'
kwargs['write_only'] = True
Expand Down Expand Up @@ -1858,6 +1852,7 @@ class ExampleSerializer(Serializer):
def get_extra_info(self, obj):
return ... # Calculate some data to return.
"""

def __init__(self, method_name=None, **kwargs):
self.method_name = method_name
kwargs['source'] = '*'
Expand Down
14 changes: 8 additions & 6 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
pytz = None

from django.core.exceptions import ValidationError as DjangoValidationError
from django.db.models import IntegerChoices, TextChoices
from django.db.models import TextChoices
from django.http import QueryDict
from django.test import TestCase, override_settings
from django.utils.timezone import activate, deactivate, override
Expand Down Expand Up @@ -138,6 +138,7 @@ class TestEmpty:
"""
Tests for `required`, `allow_null`, `allow_blank`, `default`.
"""

def test_required(self):
"""
By default a field must be included in the input.
Expand Down Expand Up @@ -664,6 +665,7 @@ class FieldValues:
"""
Base class for testing valid and invalid input values.
"""

def test_valid_inputs(self, *args):
"""
Ensure that valid values return the expected validated data.
Expand Down Expand Up @@ -1875,26 +1877,26 @@ def test_edit_choices(self):
field.run_validation(2)
assert exc_info.value.detail == ['"2" is not a valid choice.']

def test_integer_choices(self):
class ChoiceCase(IntegerChoices):
def test_enum_choices(self):
from enum import IntEnum, auto

class ChoiceCase(IntEnum):
first = auto()
second = auto()
# Enum validate
choices = [
(ChoiceCase.first, "1"),
(ChoiceCase.second, "2")
]

field = serializers.ChoiceField(choices=choices)
assert field.run_validation(1) == 1
assert field.run_validation(ChoiceCase.first) == 1
assert field.run_validation("1") == 1

# Enum.value validate
choices = [
(ChoiceCase.first.value, "1"),
(ChoiceCase.second.value, "2")
]

field = serializers.ChoiceField(choices=choices)
assert field.run_validation(1) == 1
assert field.run_validation(ChoiceCase.first) == 1
Expand Down

0 comments on commit a4e1038

Please sign in to comment.