Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: to_dict returns wrong enum fields when numbering is not consecutive #102

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 58 additions & 39 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import struct
import sys
import warnings
from abc import ABC
from base64 import b64decode, b64encode
from datetime import datetime, timedelta, timezone
Expand All @@ -21,6 +22,8 @@
get_type_hints,
)

import typing

from ._types import T
from .casing import camel_case, safe_snake_case, snake_case
from .grpc.grpclib_client import ServiceStub
Expand Down Expand Up @@ -251,7 +254,7 @@ def map_field(
)


class Enum(int, enum.Enum):
class Enum(enum.IntEnum):
"""Protocol buffers enumeration base class. Acts like `enum.IntEnum`."""

@classmethod
Expand Down Expand Up @@ -635,9 +638,13 @@ def __bytes__(self) -> bytes:

@classmethod
def _type_hint(cls, field_name: str) -> Type:
return cls._type_hints()[field_name]

@classmethod
def _type_hints(cls) -> Dict[str, Type]:
module = inspect.getmodule(cls)
type_hints = get_type_hints(cls, vars(module))
return type_hints[field_name]
return type_hints

@classmethod
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
Expand Down Expand Up @@ -789,55 +796,67 @@ def to_dict(
`False`.
"""
output: Dict[str, Any] = {}
field_types = self._type_hints()
for field_name, meta in self._betterproto.meta_by_field_name.items():
v = getattr(self, field_name)
field_type = field_types[field_name]
field_is_repeated = type(field_type) is type(typing.List)
value = getattr(self, field_name)
cased_name = casing(field_name).rstrip("_") # type: ignore
if meta.proto_type == "message":
if isinstance(v, datetime):
if v != DATETIME_ZERO or include_default_values:
output[cased_name] = _Timestamp.timestamp_to_json(v)
elif isinstance(v, timedelta):
if v != timedelta(0) or include_default_values:
output[cased_name] = _Duration.delta_to_json(v)
if meta.proto_type == TYPE_MESSAGE:
if isinstance(value, datetime):
if value != DATETIME_ZERO or include_default_values:
output[cased_name] = _Timestamp.timestamp_to_json(value)
elif isinstance(value, timedelta):
if value != timedelta(0) or include_default_values:
output[cased_name] = _Duration.delta_to_json(value)
elif meta.wraps:
if v is not None or include_default_values:
output[cased_name] = v
elif isinstance(v, list):
if value is not None or include_default_values:
output[cased_name] = value
elif field_is_repeated:
# Convert each item.
v = [i.to_dict(casing, include_default_values) for i in v]
if v or include_default_values:
output[cased_name] = v
value = [i.to_dict(casing, include_default_values) for i in value]
if value or include_default_values:
output[cased_name] = value
else:
if v._serialized_on_wire or include_default_values:
output[cased_name] = v.to_dict(casing, include_default_values)
elif meta.proto_type == "map":
for k in v:
if hasattr(v[k], "to_dict"):
v[k] = v[k].to_dict(casing, include_default_values)

if v or include_default_values:
output[cased_name] = v
elif v != self._get_field_default(field_name) or include_default_values:
if value._serialized_on_wire or include_default_values:
output[cased_name] = value.to_dict(
casing, include_default_values
)
elif meta.proto_type == TYPE_MAP:
for k in value:
if hasattr(value[k], "to_dict"):
value[k] = value[k].to_dict(casing, include_default_values)

if value or include_default_values:
output[cased_name] = value
elif value != self._get_field_default(field_name) or include_default_values:
if meta.proto_type in INT_64_TYPES:
if isinstance(v, list):
output[cased_name] = [str(n) for n in v]
if field_is_repeated:
output[cased_name] = [str(n) for n in value]
else:
output[cased_name] = str(v)
output[cased_name] = str(value)
elif meta.proto_type == TYPE_BYTES:
if isinstance(v, list):
output[cased_name] = [b64encode(b).decode("utf8") for b in v]
if field_is_repeated:
output[cased_name] = [
b64encode(b).decode("utf8") for b in value
]
else:
output[cased_name] = b64encode(v).decode("utf8")
output[cased_name] = b64encode(value).decode("utf8")
elif meta.proto_type == TYPE_ENUM:
enum_values = list(
self._betterproto.cls_by_field[field_name]
) # type: ignore
if isinstance(v, list):
output[cased_name] = [enum_values[e].name for e in v]
if field_is_repeated:
enum_class: Type[Enum] = field_type.__args__[0]
if isinstance(value, typing.Iterable) and not isinstance(
value, str
):
output[cased_name] = [enum_class(el).name for el in value]
else:
# transparently upgrade single value to repeated
output[cased_name] = [enum_class(value).name]
else:
output[cased_name] = enum_values[v].name
enum_class: Type[Enum] = field_type # noqa
output[cased_name] = enum_class(value).name
else:
output[cased_name] = v
output[cased_name] = value
return output

def from_dict(self: T, value: dict) -> T:
Expand Down
2 changes: 1 addition & 1 deletion tests/inputs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"namespace_keywords", # 70
"namespace_builtin_types", # 53
"googletypes_struct", # 9
"googletypes_value", # 9,
"googletypes_value", # 9
"import_capitalized_package",
"example", # This is the example in the readme. Not a test.
}
Expand Down
9 changes: 9 additions & 0 deletions tests/inputs/enum/enum.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"choice": "FOUR",
"choices": [
"ZERO",
"ONE",
"THREE",
"FOUR"
]
}
15 changes: 15 additions & 0 deletions tests/inputs/enum/enum.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
syntax = "proto3";

// Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values
message Test {
Choice choice = 1;
repeated Choice choices = 2;
}

enum Choice {
ZERO = 0;
ONE = 1;
// TWO = 2;
FOUR = 4;
THREE = 3;
}
84 changes: 84 additions & 0 deletions tests/inputs/enum/test_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from tests.output_betterproto.enum import (
Test,
Choice,
)


def test_enum_set_and_get():
assert Test(choice=Choice.ZERO).choice == Choice.ZERO
assert Test(choice=Choice.ONE).choice == Choice.ONE
assert Test(choice=Choice.THREE).choice == Choice.THREE
assert Test(choice=Choice.FOUR).choice == Choice.FOUR


def test_enum_set_with_int():
assert Test(choice=0).choice == Choice.ZERO
assert Test(choice=1).choice == Choice.ONE
assert Test(choice=3).choice == Choice.THREE
assert Test(choice=4).choice == Choice.FOUR


def test_enum_is_comparable_with_int():
assert Test(choice=Choice.ZERO).choice == 0
assert Test(choice=Choice.ONE).choice == 1
assert Test(choice=Choice.THREE).choice == 3
assert Test(choice=Choice.FOUR).choice == 4


def test_enum_to_dict():
assert (
"choice" not in Test(choice=Choice.ZERO).to_dict()
), "Default enum value is not serialized"
assert (
Test(choice=Choice.ZERO).to_dict(include_default_values=True)["choice"]
== "ZERO"
)
assert Test(choice=Choice.ONE).to_dict()["choice"] == "ONE"
assert Test(choice=Choice.THREE).to_dict()["choice"] == "THREE"
assert Test(choice=Choice.FOUR).to_dict()["choice"] == "FOUR"


def test_repeated_enum_is_comparable_with_int():
assert Test(choices=[Choice.ZERO]).choices == [0]
assert Test(choices=[Choice.ONE]).choices == [1]
assert Test(choices=[Choice.THREE]).choices == [3]
assert Test(choices=[Choice.FOUR]).choices == [4]


def test_repeated_enum_set_and_get():
assert Test(choices=[Choice.ZERO]).choices == [Choice.ZERO]
assert Test(choices=[Choice.ONE]).choices == [Choice.ONE]
assert Test(choices=[Choice.THREE]).choices == [Choice.THREE]
assert Test(choices=[Choice.FOUR]).choices == [Choice.FOUR]


def test_repeated_enum_to_dict():
assert Test(choices=[Choice.ZERO]).to_dict()["choices"] == ["ZERO"]
assert Test(choices=[Choice.ONE]).to_dict()["choices"] == ["ONE"]
assert Test(choices=[Choice.THREE]).to_dict()["choices"] == ["THREE"]
assert Test(choices=[Choice.FOUR]).to_dict()["choices"] == ["FOUR"]

all_enums_dict = Test(
choices=[Choice.ZERO, Choice.ONE, Choice.THREE, Choice.FOUR]
).to_dict()
assert (all_enums_dict["choices"]) == ["ZERO", "ONE", "THREE", "FOUR"]


def test_repeated_enum_with_single_value_to_dict():
assert Test(choices=Choice.ONE).to_dict()["choices"] == ["ONE"]
assert Test(choices=1).to_dict()["choices"] == ["ONE"]


def test_repeated_enum_with_non_list_iterables_to_dict():
assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"]
assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"]
assert Test(choices=(Choice.ONE, Choice.THREE)).to_dict()["choices"] == [
"ONE",
"THREE",
]

def enum_generator():
yield Choice.ONE
yield Choice.THREE

assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]
3 changes: 0 additions & 3 deletions tests/inputs/enums/enums.json

This file was deleted.

14 changes: 0 additions & 14 deletions tests/inputs/enums/enums.proto

This file was deleted.