diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 4ce9c79c3e5..0b56fa7fb67 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -8,6 +8,7 @@ import re import uuid from collections.abc import Mapping +from enum import Enum from django.conf import settings from django.core.exceptions import ObjectDoesNotExist @@ -17,7 +18,6 @@ MinValueValidator, ProhibitNullCharactersValidator, RegexValidator, URLValidator, ip_address_validators ) -from django.db.models import IntegerChoices, TextChoices from django.forms import FilePathField as DjangoFilePathField from django.forms import ImageField as DjangoImageField from django.utils import timezone @@ -1401,11 +1401,8 @@ def __init__(self, choices, **kwargs): def to_internal_value(self, data): if data == '' and self.allow_blank: return '' - - if isinstance(data, (IntegerChoices, TextChoices)) and str(data) != \ - str(data.value): + if isinstance(data, Enum) and str(data) != str(data.value): data = data.value - try: return self.choice_strings_to_values[str(data)] except KeyError: @@ -1414,11 +1411,8 @@ def to_internal_value(self, data): def to_representation(self, value): if value in ('', None): return value - - if isinstance(value, (IntegerChoices, TextChoices)) and str(value) != \ - str(value.value): + if isinstance(value, Enum) and str(value) != str(value.value): value = value.value - return self.choice_strings_to_values.get(str(value), value) def iter_options(self): @@ -1442,8 +1436,7 @@ def _set_choices(self, choices): # Allows us to deal with eg. integer choices while supporting either # integer or string input, but still get the correct datatype out. self.choice_strings_to_values = { - str(key.value) if isinstance(key, (IntegerChoices, TextChoices)) - and str(key) != str(key.value) else str(key): key for key in self.choices + str(key.value) if isinstance(key, Enum) and str(key) != str(key.value) else str(key): key for key in self.choices } choices = property(_get_choices, _set_choices) @@ -1829,6 +1822,7 @@ class HiddenField(Field): constraint on a pair of fields, as we need some way to include the date in the validated data. """ + def __init__(self, **kwargs): assert 'default' in kwargs, 'default is a required argument.' kwargs['write_only'] = True @@ -1858,6 +1852,7 @@ class ExampleSerializer(Serializer): def get_extra_info(self, obj): return ... # Calculate some data to return. """ + def __init__(self, method_name=None, **kwargs): self.method_name = method_name kwargs['source'] = '*' diff --git a/tests/test_fields.py b/tests/test_fields.py index 03584431e54..6aa21eb872f 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -16,7 +16,7 @@ pytz = None from django.core.exceptions import ValidationError as DjangoValidationError -from django.db.models import IntegerChoices, TextChoices +from django.db.models import TextChoices from django.http import QueryDict from django.test import TestCase, override_settings from django.utils.timezone import activate, deactivate, override @@ -138,6 +138,7 @@ class TestEmpty: """ Tests for `required`, `allow_null`, `allow_blank`, `default`. """ + def test_required(self): """ By default a field must be included in the input. @@ -664,6 +665,7 @@ class FieldValues: """ Base class for testing valid and invalid input values. """ + def test_valid_inputs(self, *args): """ Ensure that valid values return the expected validated data. @@ -1875,8 +1877,10 @@ def test_edit_choices(self): field.run_validation(2) assert exc_info.value.detail == ['"2" is not a valid choice.'] - def test_integer_choices(self): - class ChoiceCase(IntegerChoices): + def test_enum_choices(self): + from enum import IntEnum, auto + + class ChoiceCase(IntEnum): first = auto() second = auto() # Enum validate @@ -1884,17 +1888,15 @@ class ChoiceCase(IntegerChoices): (ChoiceCase.first, "1"), (ChoiceCase.second, "2") ] - field = serializers.ChoiceField(choices=choices) assert field.run_validation(1) == 1 assert field.run_validation(ChoiceCase.first) == 1 assert field.run_validation("1") == 1 - + # Enum.value validate choices = [ (ChoiceCase.first.value, "1"), (ChoiceCase.second.value, "2") ] - field = serializers.ChoiceField(choices=choices) assert field.run_validation(1) == 1 assert field.run_validation(ChoiceCase.first) == 1