diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 4ce9c79c3e5..0b56fa7fb67 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -8,6 +8,7 @@ import re import uuid from collections.abc import Mapping +from enum import Enum from django.conf import settings from django.core.exceptions import ObjectDoesNotExist @@ -17,7 +18,6 @@ MinValueValidator, ProhibitNullCharactersValidator, RegexValidator, URLValidator, ip_address_validators ) -from django.db.models import IntegerChoices, TextChoices from django.forms import FilePathField as DjangoFilePathField from django.forms import ImageField as DjangoImageField from django.utils import timezone @@ -1401,11 +1401,8 @@ def __init__(self, choices, **kwargs): def to_internal_value(self, data): if data == '' and self.allow_blank: return '' - - if isinstance(data, (IntegerChoices, TextChoices)) and str(data) != \ - str(data.value): + if isinstance(data, Enum) and str(data) != str(data.value): data = data.value - try: return self.choice_strings_to_values[str(data)] except KeyError: @@ -1414,11 +1411,8 @@ def to_internal_value(self, data): def to_representation(self, value): if value in ('', None): return value - - if isinstance(value, (IntegerChoices, TextChoices)) and str(value) != \ - str(value.value): + if isinstance(value, Enum) and str(value) != str(value.value): value = value.value - return self.choice_strings_to_values.get(str(value), value) def iter_options(self): @@ -1442,8 +1436,7 @@ def _set_choices(self, choices): # Allows us to deal with eg. integer choices while supporting either # integer or string input, but still get the correct datatype out. self.choice_strings_to_values = { - str(key.value) if isinstance(key, (IntegerChoices, TextChoices)) - and str(key) != str(key.value) else str(key): key for key in self.choices + str(key.value) if isinstance(key, Enum) and str(key) != str(key.value) else str(key): key for key in self.choices } choices = property(_get_choices, _set_choices) @@ -1829,6 +1822,7 @@ class HiddenField(Field): constraint on a pair of fields, as we need some way to include the date in the validated data. """ + def __init__(self, **kwargs): assert 'default' in kwargs, 'default is a required argument.' kwargs['write_only'] = True @@ -1858,6 +1852,7 @@ class ExampleSerializer(Serializer): def get_extra_info(self, obj): return ... # Calculate some data to return. """ + def __init__(self, method_name=None, **kwargs): self.method_name = method_name kwargs['source'] = '*' diff --git a/tests/test_fields.py b/tests/test_fields.py index 03584431e54..52383671519 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1875,6 +1875,33 @@ def test_edit_choices(self): field.run_validation(2) assert exc_info.value.detail == ['"2" is not a valid choice.'] + def test_enum_integer_choices(self): + from enum import IntEnum + + class ChoiceCase(IntEnum): + first = auto() + second = auto() + # Enum validate + choices = [ + (ChoiceCase.first, "1"), + (ChoiceCase.second, "2") + ] + + field = serializers.ChoiceField(choices=choices) + assert field.run_validation(1) == 1 + assert field.run_validation(ChoiceCase.first) == 1 + assert field.run_validation("1") == 1 + # Enum.value validate + choices = [ + (ChoiceCase.first.value, "1"), + (ChoiceCase.second.value, "2") + ] + + field = serializers.ChoiceField(choices=choices) + assert field.run_validation(1) == 1 + assert field.run_validation(ChoiceCase.first) == 1 + assert field.run_validation("1") == 1 + def test_integer_choices(self): class ChoiceCase(IntegerChoices): first = auto()