diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 3b82f3e9..b6090d51 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -4,6 +4,7 @@ import json import struct import sys +import warnings from abc import ABC from base64 import b64decode, b64encode from datetime import datetime, timedelta, timezone @@ -21,6 +22,8 @@ get_type_hints, ) +import typing + from ._types import T from .casing import camel_case, safe_snake_case, snake_case from .grpc.grpclib_client import ServiceStub @@ -251,7 +254,7 @@ def map_field( ) -class Enum(int, enum.Enum): +class Enum(enum.IntEnum): """Protocol buffers enumeration base class. Acts like `enum.IntEnum`.""" @classmethod @@ -635,9 +638,13 @@ def __bytes__(self) -> bytes: @classmethod def _type_hint(cls, field_name: str) -> Type: + return cls._type_hints()[field_name] + + @classmethod + def _type_hints(cls) -> Dict[str, Type]: module = inspect.getmodule(cls) type_hints = get_type_hints(cls, vars(module)) - return type_hints[field_name] + return type_hints @classmethod def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type: @@ -789,55 +796,67 @@ def to_dict( `False`. """ output: Dict[str, Any] = {} + field_types = self._type_hints() for field_name, meta in self._betterproto.meta_by_field_name.items(): - v = getattr(self, field_name) + field_type = field_types[field_name] + field_is_repeated = type(field_type) is type(typing.List) + value = getattr(self, field_name) cased_name = casing(field_name).rstrip("_") # type: ignore - if meta.proto_type == "message": - if isinstance(v, datetime): - if v != DATETIME_ZERO or include_default_values: - output[cased_name] = _Timestamp.timestamp_to_json(v) - elif isinstance(v, timedelta): - if v != timedelta(0) or include_default_values: - output[cased_name] = _Duration.delta_to_json(v) + if meta.proto_type == TYPE_MESSAGE: + if isinstance(value, datetime): + if value != DATETIME_ZERO or include_default_values: + output[cased_name] = _Timestamp.timestamp_to_json(value) + elif isinstance(value, timedelta): + if value != timedelta(0) or include_default_values: + output[cased_name] = _Duration.delta_to_json(value) elif meta.wraps: - if v is not None or include_default_values: - output[cased_name] = v - elif isinstance(v, list): + if value is not None or include_default_values: + output[cased_name] = value + elif field_is_repeated: # Convert each item. - v = [i.to_dict(casing, include_default_values) for i in v] - if v or include_default_values: - output[cased_name] = v + value = [i.to_dict(casing, include_default_values) for i in value] + if value or include_default_values: + output[cased_name] = value else: - if v._serialized_on_wire or include_default_values: - output[cased_name] = v.to_dict(casing, include_default_values) - elif meta.proto_type == "map": - for k in v: - if hasattr(v[k], "to_dict"): - v[k] = v[k].to_dict(casing, include_default_values) - - if v or include_default_values: - output[cased_name] = v - elif v != self._get_field_default(field_name) or include_default_values: + if value._serialized_on_wire or include_default_values: + output[cased_name] = value.to_dict( + casing, include_default_values + ) + elif meta.proto_type == TYPE_MAP: + for k in value: + if hasattr(value[k], "to_dict"): + value[k] = value[k].to_dict(casing, include_default_values) + + if value or include_default_values: + output[cased_name] = value + elif value != self._get_field_default(field_name) or include_default_values: if meta.proto_type in INT_64_TYPES: - if isinstance(v, list): - output[cased_name] = [str(n) for n in v] + if field_is_repeated: + output[cased_name] = [str(n) for n in value] else: - output[cased_name] = str(v) + output[cased_name] = str(value) elif meta.proto_type == TYPE_BYTES: - if isinstance(v, list): - output[cased_name] = [b64encode(b).decode("utf8") for b in v] + if field_is_repeated: + output[cased_name] = [ + b64encode(b).decode("utf8") for b in value + ] else: - output[cased_name] = b64encode(v).decode("utf8") + output[cased_name] = b64encode(value).decode("utf8") elif meta.proto_type == TYPE_ENUM: - enum_values = list( - self._betterproto.cls_by_field[field_name] - ) # type: ignore - if isinstance(v, list): - output[cased_name] = [enum_values[e].name for e in v] + if field_is_repeated: + enum_class: Type[Enum] = field_type.__args__[0] + if isinstance(value, typing.Iterable) and not isinstance( + value, str + ): + output[cased_name] = [enum_class(el).name for el in value] + else: + # transparently upgrade single value to repeated + output[cased_name] = [enum_class(value).name] else: - output[cased_name] = enum_values[v].name + enum_class: Type[Enum] = field_type # noqa + output[cased_name] = enum_class(value).name else: - output[cased_name] = v + output[cased_name] = value return output def from_dict(self: T, value: dict) -> T: diff --git a/tests/inputs/config.py b/tests/inputs/config.py index 38e9603f..7d146673 100644 --- a/tests/inputs/config.py +++ b/tests/inputs/config.py @@ -5,7 +5,7 @@ "namespace_keywords", # 70 "namespace_builtin_types", # 53 "googletypes_struct", # 9 - "googletypes_value", # 9, + "googletypes_value", # 9 "import_capitalized_package", "example", # This is the example in the readme. Not a test. } diff --git a/tests/inputs/enum/enum.json b/tests/inputs/enum/enum.json new file mode 100644 index 00000000..d68f1c50 --- /dev/null +++ b/tests/inputs/enum/enum.json @@ -0,0 +1,9 @@ +{ + "choice": "FOUR", + "choices": [ + "ZERO", + "ONE", + "THREE", + "FOUR" + ] +} diff --git a/tests/inputs/enum/enum.proto b/tests/inputs/enum/enum.proto new file mode 100644 index 00000000..a2dfe437 --- /dev/null +++ b/tests/inputs/enum/enum.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +// Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values +message Test { + Choice choice = 1; + repeated Choice choices = 2; +} + +enum Choice { + ZERO = 0; + ONE = 1; + // TWO = 2; + FOUR = 4; + THREE = 3; +} diff --git a/tests/inputs/enum/test_enum.py b/tests/inputs/enum/test_enum.py new file mode 100644 index 00000000..3005c43a --- /dev/null +++ b/tests/inputs/enum/test_enum.py @@ -0,0 +1,84 @@ +from tests.output_betterproto.enum import ( + Test, + Choice, +) + + +def test_enum_set_and_get(): + assert Test(choice=Choice.ZERO).choice == Choice.ZERO + assert Test(choice=Choice.ONE).choice == Choice.ONE + assert Test(choice=Choice.THREE).choice == Choice.THREE + assert Test(choice=Choice.FOUR).choice == Choice.FOUR + + +def test_enum_set_with_int(): + assert Test(choice=0).choice == Choice.ZERO + assert Test(choice=1).choice == Choice.ONE + assert Test(choice=3).choice == Choice.THREE + assert Test(choice=4).choice == Choice.FOUR + + +def test_enum_is_comparable_with_int(): + assert Test(choice=Choice.ZERO).choice == 0 + assert Test(choice=Choice.ONE).choice == 1 + assert Test(choice=Choice.THREE).choice == 3 + assert Test(choice=Choice.FOUR).choice == 4 + + +def test_enum_to_dict(): + assert ( + "choice" not in Test(choice=Choice.ZERO).to_dict() + ), "Default enum value is not serialized" + assert ( + Test(choice=Choice.ZERO).to_dict(include_default_values=True)["choice"] + == "ZERO" + ) + assert Test(choice=Choice.ONE).to_dict()["choice"] == "ONE" + assert Test(choice=Choice.THREE).to_dict()["choice"] == "THREE" + assert Test(choice=Choice.FOUR).to_dict()["choice"] == "FOUR" + + +def test_repeated_enum_is_comparable_with_int(): + assert Test(choices=[Choice.ZERO]).choices == [0] + assert Test(choices=[Choice.ONE]).choices == [1] + assert Test(choices=[Choice.THREE]).choices == [3] + assert Test(choices=[Choice.FOUR]).choices == [4] + + +def test_repeated_enum_set_and_get(): + assert Test(choices=[Choice.ZERO]).choices == [Choice.ZERO] + assert Test(choices=[Choice.ONE]).choices == [Choice.ONE] + assert Test(choices=[Choice.THREE]).choices == [Choice.THREE] + assert Test(choices=[Choice.FOUR]).choices == [Choice.FOUR] + + +def test_repeated_enum_to_dict(): + assert Test(choices=[Choice.ZERO]).to_dict()["choices"] == ["ZERO"] + assert Test(choices=[Choice.ONE]).to_dict()["choices"] == ["ONE"] + assert Test(choices=[Choice.THREE]).to_dict()["choices"] == ["THREE"] + assert Test(choices=[Choice.FOUR]).to_dict()["choices"] == ["FOUR"] + + all_enums_dict = Test( + choices=[Choice.ZERO, Choice.ONE, Choice.THREE, Choice.FOUR] + ).to_dict() + assert (all_enums_dict["choices"]) == ["ZERO", "ONE", "THREE", "FOUR"] + + +def test_repeated_enum_with_single_value_to_dict(): + assert Test(choices=Choice.ONE).to_dict()["choices"] == ["ONE"] + assert Test(choices=1).to_dict()["choices"] == ["ONE"] + + +def test_repeated_enum_with_non_list_iterables_to_dict(): + assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"] + assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"] + assert Test(choices=(Choice.ONE, Choice.THREE)).to_dict()["choices"] == [ + "ONE", + "THREE", + ] + + def enum_generator(): + yield Choice.ONE + yield Choice.THREE + + assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"] diff --git a/tests/inputs/enums/enums.json b/tests/inputs/enums/enums.json deleted file mode 100644 index a4d009c8..00000000 --- a/tests/inputs/enums/enums.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "greeting": "HEY" -} diff --git a/tests/inputs/enums/enums.proto b/tests/inputs/enums/enums.proto deleted file mode 100644 index 421f78ab..00000000 --- a/tests/inputs/enums/enums.proto +++ /dev/null @@ -1,14 +0,0 @@ -syntax = "proto3"; - -// Enum for the different greeting types -enum Greeting { - HI = 0; - HEY = 1; - // Formal greeting - HELLO = 2; -} - -message Test { - // Greeting enum example - Greeting greeting = 1; -}