From b8926b04f4072b14a298b6f472ea390b4e007114 Mon Sep 17 00:00:00 2001 From: b7wsh Date: Thu, 20 Apr 2023 11:29:32 +0800 Subject: [PATCH] Fix choices in ChoiceField to support IntEnum Python support Enum in version 3.4, but changed __str__ to int.__str__ until version 3.11 to better support the replacement of existing constants use-case. [https://docs.python.org/3/library/enum.html#enum.IntEnum](https://docs.python.org/3/library/enum.html#enum.IntEnum) rest_frame work support Python 3.6+, this commit will support the Enum in choices of Field. --- rest_framework/fields.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index e41b56fb01e..9b40bb62737 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -8,6 +8,7 @@ import uuid from collections import OrderedDict from collections.abc import Mapping +from enum import Enum from django.conf import settings from django.core.exceptions import ObjectDoesNotExist @@ -1397,7 +1398,8 @@ def __init__(self, choices, **kwargs): def to_internal_value(self, data): if data == '' and self.allow_blank: return '' - + if isinstance(data, Enum) and str(data) != str(data.value): + data = data.value try: return self.choice_strings_to_values[str(data)] except KeyError: @@ -1406,6 +1408,8 @@ def to_internal_value(self, data): def to_representation(self, value): if value in ('', None): return value + if isinstance(value, Enum) and str(value) != str(value.value): + value = value.value return self.choice_strings_to_values.get(str(value), value) def iter_options(self): @@ -1429,7 +1433,7 @@ def _set_choices(self, choices): # Allows us to deal with eg. integer choices while supporting either # integer or string input, but still get the correct datatype out. self.choice_strings_to_values = { - str(key): key for key in self.choices + str(key.value) if isinstance(key, Enum) and str(key) != str(key.value) else str(key): key for key in self.choices } choices = property(_get_choices, _set_choices) @@ -1815,6 +1819,7 @@ class HiddenField(Field): constraint on a pair of fields, as we need some way to include the date in the validated data. """ + def __init__(self, **kwargs): assert 'default' in kwargs, 'default is a required argument.' kwargs['write_only'] = True @@ -1844,6 +1849,7 @@ class ExampleSerializer(Serializer): def get_extra_info(self, obj): return ... # Calculate some data to return. """ + def __init__(self, method_name=None, **kwargs): self.method_name = method_name kwargs['source'] = '*'