diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 4f82e4a10e..623e72e0ad 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -113,27 +113,6 @@ def get_attribute(instance, attrs): return instance -def set_value(dictionary, keys, value): - """ - Similar to Python's built in `dictionary[key] = value`, - but takes a list of nested keys instead of a single key. - - set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2} - set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2} - set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}} - """ - if not keys: - dictionary.update(value) - return - - for key in keys[:-1]: - if key not in dictionary: - dictionary[key] = {} - dictionary = dictionary[key] - - dictionary[keys[-1]] = value - - def to_choices_dict(choices): """ Convert choices into key/value dicts. diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 01bebf5fca..a3d68b03db 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -28,7 +28,7 @@ from rest_framework.compat import postgres_fields from rest_framework.exceptions import ErrorDetail, ValidationError -from rest_framework.fields import get_error_detail, set_value +from rest_framework.fields import get_error_detail from rest_framework.settings import api_settings from rest_framework.utils import html, model_meta, representation from rest_framework.utils.field_mapping import ( @@ -346,6 +346,26 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): 'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.') } + def set_value(self, dictionary, keys, value): + """ + Similar to Python's built in `dictionary[key] = value`, + but takes a list of nested keys instead of a single key. + + set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2} + set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2} + set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}} + """ + if not keys: + dictionary.update(value) + return + + for key in keys[:-1]: + if key not in dictionary: + dictionary[key] = {} + dictionary = dictionary[key] + + dictionary[keys[-1]] = value + @cached_property def fields(self): """ @@ -492,7 +512,7 @@ def to_internal_value(self, data): except SkipField: pass else: - set_value(ret, field.source_attrs, validated_value) + self.set_value(ret, field.source_attrs, validated_value) if errors: raise ValidationError(errors) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index 1d9efaa434..10fa8afb94 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -762,3 +762,24 @@ class TestSerializer(serializers.Serializer): assert (s.data | {}).__class__ == s.data.__class__ assert ({} | s.data).__class__ == s.data.__class__ + + +class TestSetValueMethod: + # Serializer.set_value() modifies the first parameter in-place. + + s = serializers.Serializer() + + def test_no_keys(self): + ret = {'a': 1} + self.s.set_value(ret, [], {'b': 2}) + assert ret == {'a': 1, 'b': 2} + + def test_one_key(self): + ret = {'a': 1} + self.s.set_value(ret, ['x'], 2) + assert ret == {'a': 1, 'x': 2} + + def test_nested_key(self): + ret = {'a': 1} + self.s.set_value(ret, ['x', 'y'], 2) + assert ret == {'a': 1, 'x': {'y': 2}}