diff --git a/HISTORY.rst b/HISTORY.rst index d71540d6..1ce31e78 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -5,6 +5,7 @@ History ------------------- * Add PEP 563 (string annotations) for dataclasses. (`#195 `_) +* Fix handling of dictionaries with string Enum keys for bson, orjson, and tomlkit. 1.9.0 (2021-12-06) ------------------ diff --git a/src/cattr/preconf/bson.py b/src/cattr/preconf/bson.py index e05ad681..994b6714 100644 --- a/src/cattr/preconf/bson.py +++ b/src/cattr/preconf/bson.py @@ -12,12 +12,17 @@ def configure_converter(converter: GenConverter): Configure the converter for use with the bson library. * sets are serialized as lists - * mapping keys are coerced into strings when unstructuring + * non-string mapping keys are coerced into strings when unstructuring """ def gen_unstructure_mapping(cl: Any, unstructure_to=None): + key_handler = str + if (args := getattr(cl, "__args__", None)) and issubclass( + args[0], str + ): + key_handler = None return converter.gen_unstructure_mapping( - cl, unstructure_to=unstructure_to, key_handler=str + cl, unstructure_to=unstructure_to, key_handler=key_handler ) converter._unstructure_func.register_func_list( diff --git a/src/cattr/preconf/orjson.py b/src/cattr/preconf/orjson.py index 174b27a1..1c39ae95 100644 --- a/src/cattr/preconf/orjson.py +++ b/src/cattr/preconf/orjson.py @@ -1,6 +1,7 @@ """Preconfigured converters for orjson.""" from base64 import b85decode, b85encode from datetime import datetime +from enum import Enum from typing import Any from .._compat import Set, is_mapping @@ -14,6 +15,7 @@ def configure_converter(converter: GenConverter): * bytes are serialized as base85 strings * datetimes are serialized as ISO 8601 * sets are serialized as lists + * string enum mapping keys have special handling * mapping keys are coerced into strings when unstructuring """ converter.register_unstructure_hook( @@ -27,8 +29,18 @@ def configure_converter(converter: GenConverter): ) def gen_unstructure_mapping(cl: Any, unstructure_to=None): + key_handler = str + if ( + (args := getattr(cl, "__args__", None)) + and issubclass(args[0], str) + and issubclass(args[0], Enum) + ): + + def key_handler(v): + return v.value + return converter.gen_unstructure_mapping( - cl, unstructure_to=unstructure_to, key_handler=str + cl, unstructure_to=unstructure_to, key_handler=key_handler ) converter._unstructure_func.register_func_list( diff --git a/src/cattr/preconf/tomlkit.py b/src/cattr/preconf/tomlkit.py index 6e428469..cdc652d7 100644 --- a/src/cattr/preconf/tomlkit.py +++ b/src/cattr/preconf/tomlkit.py @@ -23,8 +23,13 @@ def configure_converter(converter: GenConverter): ) def gen_unstructure_mapping(cl: Any, unstructure_to=None): + key_handler = str + if (args := getattr(cl, "__args__", None)) and issubclass( + args[0], str + ): + key_handler = None return converter.gen_unstructure_mapping( - cl, unstructure_to=unstructure_to, key_handler=str + cl, unstructure_to=unstructure_to, key_handler=key_handler ) converter._unstructure_func.register_func_list( diff --git a/tests/test_preconf.py b/tests/test_preconf.py index 70241e69..6d9b448d 100644 --- a/tests/test_preconf.py +++ b/tests/test_preconf.py @@ -71,6 +71,7 @@ class AStringEnum(str, Enum): an_int_enum: AnIntEnum a_str_enum: AStringEnum a_datetime: datetime + a_string_enum_dict: Dict[AStringEnum, int] @composite @@ -161,6 +162,12 @@ def everythings( Everything.AnIntEnum.A, Everything.AStringEnum.A, draw(dts), + draw( + dictionaries( + just(Everything.AStringEnum.A), + integers(min_value=min_int, max_value=max_int), + ) + ), ) @@ -204,7 +211,7 @@ def test_orjson(everything: Everything): from orjson import loads as orjson_loads converter = orjson_make_converter() - raw = orjson_dumps(converter.unstructure(everything)) + raw = orjson_dumps(r := converter.unstructure(everything)) assert converter.structure(orjson_loads(raw), Everything) == everything