diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 1d77ce4abc..f4fd265394 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -16,6 +16,7 @@ MinValueValidator, ProhibitNullCharactersValidator, RegexValidator, URLValidator, ip_address_validators ) +from django.db.models import IntegerChoices, TextChoices from django.forms import FilePathField as DjangoFilePathField from django.forms import ImageField as DjangoImageField from django.utils import timezone @@ -1397,6 +1398,10 @@ def to_internal_value(self, data): if data == '' and self.allow_blank: return '' + if isinstance(data, (IntegerChoices, TextChoices)) and str(data) != \ + str(data.value): + data = data.value + try: return self.choice_strings_to_values[str(data)] except KeyError: @@ -1405,6 +1410,11 @@ def to_internal_value(self, data): def to_representation(self, value): if value in ('', None): return value + + if isinstance(value, (IntegerChoices, TextChoices)) and str(value) != \ + str(value.value): + value = value.value + return self.choice_strings_to_values.get(str(value), value) def iter_options(self): @@ -1428,7 +1438,8 @@ def _set_choices(self, choices): # Allows us to deal with eg. integer choices while supporting either # integer or string input, but still get the correct datatype out. self.choice_strings_to_values = { - str(key): key for key in self.choices + str(key.value) if isinstance(key, (IntegerChoices, TextChoices)) + and str(key) != str(key.value) else str(key): key for key in self.choices } choices = property(_get_choices, _set_choices) diff --git a/tests/test_fields.py b/tests/test_fields.py index 5804d7b3b3..bf25b71b8d 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -5,11 +5,13 @@ import sys import uuid from decimal import ROUND_DOWN, ROUND_UP, Decimal +from enum import auto from unittest.mock import patch import pytest import pytz from django.core.exceptions import ValidationError as DjangoValidationError +from django.db.models import IntegerChoices, TextChoices from django.http import QueryDict from django.test import TestCase, override_settings from django.utils.timezone import activate, deactivate, override @@ -1824,6 +1826,54 @@ def test_edit_choices(self): field.run_validation(2) assert exc_info.value.detail == ['"2" is not a valid choice.'] + def test_integer_choices(self): + class ChoiceCase(IntegerChoices): + first = auto() + second = auto() + # Enum validate + choices = [ + (ChoiceCase.first, "1"), + (ChoiceCase.second, "2") + ] + + field = serializers.ChoiceField(choices=choices) + assert field.run_validation(1) == 1 + assert field.run_validation(ChoiceCase.first) == 1 + assert field.run_validation("1") == 1 + + choices = [ + (ChoiceCase.first.value, "1"), + (ChoiceCase.second.value, "2") + ] + + field = serializers.ChoiceField(choices=choices) + assert field.run_validation(1) == 1 + assert field.run_validation(ChoiceCase.first) == 1 + assert field.run_validation("1") == 1 + + def test_text_choices(self): + class ChoiceCase(TextChoices): + first = auto() + second = auto() + # Enum validate + choices = [ + (ChoiceCase.first, "first"), + (ChoiceCase.second, "second") + ] + + field = serializers.ChoiceField(choices=choices) + assert field.run_validation(ChoiceCase.first) == "first" + assert field.run_validation("first") == "first" + + choices = [ + (ChoiceCase.first.value, "first"), + (ChoiceCase.second.value, "second") + ] + + field = serializers.ChoiceField(choices=choices) + assert field.run_validation(ChoiceCase.first) == "first" + assert field.run_validation("first") == "first" + class TestChoiceFieldWithType(FieldValues): """