Skip to content

Commit

Permalink
Fixes danielgtaylor#93 to_dict returns wrong enum fields when numberi…
Browse files Browse the repository at this point in the history
…ng is not consecutive
  • Loading branch information
boukeversteegh committed Jul 11, 2020
1 parent 42e197f commit a945246
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 57 deletions.
98 changes: 59 additions & 39 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import struct
import sys
import warnings
from abc import ABC
from base64 import b64decode, b64encode
from datetime import datetime, timedelta, timezone
Expand All @@ -21,6 +22,8 @@
get_type_hints,
)

import typing

from ._types import T
from .casing import camel_case, safe_snake_case, snake_case
from .grpc.grpclib_client import ServiceStub
Expand Down Expand Up @@ -251,7 +254,7 @@ def map_field(
)


class Enum(int, enum.Enum):
class Enum(enum.IntEnum):
"""Protocol buffers enumeration base class. Acts like `enum.IntEnum`."""

@classmethod
Expand Down Expand Up @@ -635,9 +638,13 @@ def __bytes__(self) -> bytes:

@classmethod
def _type_hint(cls, field_name: str) -> Type:
return cls._type_hints()[field_name]

@classmethod
def _type_hints(cls) -> Dict[str, Type]:
module = inspect.getmodule(cls)
type_hints = get_type_hints(cls, vars(module))
return type_hints[field_name]
return type_hints

@classmethod
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
Expand Down Expand Up @@ -789,55 +796,68 @@ def to_dict(
`False`.
"""
output: Dict[str, Any] = {}
field_types = self._type_hints()
for field_name, meta in self._betterproto.meta_by_field_name.items():
v = getattr(self, field_name)
field_type = field_types[field_name]
field_is_repeated = type(field_type) is type(typing.List)
value = getattr(self, field_name)
cased_name = casing(field_name).rstrip("_") # type: ignore
if meta.proto_type == "message":
if isinstance(v, datetime):
if v != DATETIME_ZERO or include_default_values:
output[cased_name] = _Timestamp.timestamp_to_json(v)
elif isinstance(v, timedelta):
if v != timedelta(0) or include_default_values:
output[cased_name] = _Duration.delta_to_json(v)
if meta.proto_type == TYPE_MESSAGE:
if isinstance(value, datetime):
if value != DATETIME_ZERO or include_default_values:
output[cased_name] = _Timestamp.timestamp_to_json(value)
elif isinstance(value, timedelta):
if value != timedelta(0) or include_default_values:
output[cased_name] = _Duration.delta_to_json(value)
elif meta.wraps:
if v is not None or include_default_values:
output[cased_name] = v
elif isinstance(v, list):
if value is not None or include_default_values:
output[cased_name] = value
elif field_is_repeated:
# Convert each item.
v = [i.to_dict(casing, include_default_values) for i in v]
if v or include_default_values:
output[cased_name] = v
value = [i.to_dict(casing, include_default_values) for i in value]
if value or include_default_values:
output[cased_name] = value
else:
if v._serialized_on_wire or include_default_values:
output[cased_name] = v.to_dict(casing, include_default_values)
elif meta.proto_type == "map":
for k in v:
if hasattr(v[k], "to_dict"):
v[k] = v[k].to_dict(casing, include_default_values)

if v or include_default_values:
output[cased_name] = v
elif v != self._get_field_default(field_name) or include_default_values:
if value._serialized_on_wire or include_default_values:
output[cased_name] = value.to_dict(
casing, include_default_values
)
elif meta.proto_type == TYPE_MAP:
for k in value:
if hasattr(value[k], "to_dict"):
value[k] = value[k].to_dict(casing, include_default_values)

if value or include_default_values:
output[cased_name] = value
elif value != self._get_field_default(field_name) or include_default_values:
if meta.proto_type in INT_64_TYPES:
if isinstance(v, list):
output[cased_name] = [str(n) for n in v]
if field_is_repeated:
output[cased_name] = [str(n) for n in value]
else:
output[cased_name] = str(v)
output[cased_name] = str(value)
elif meta.proto_type == TYPE_BYTES:
if isinstance(v, list):
output[cased_name] = [b64encode(b).decode("utf8") for b in v]
if field_is_repeated:
output[cased_name] = [
b64encode(b).decode("utf8") for b in value
]
else:
output[cased_name] = b64encode(v).decode("utf8")
output[cased_name] = b64encode(value).decode("utf8")
elif meta.proto_type == TYPE_ENUM:
enum_values = list(
self._betterproto.cls_by_field[field_name]
) # type: ignore
if isinstance(v, list):
output[cased_name] = [enum_values[e].name for e in v]
if field_is_repeated:
enum_class = field_type.__args__[0]
if isinstance(value, typing.Iterable):
output[cased_name] = [
enum_class(element).name for element in value
]
else:
warnings.warn(
f"Non-iterable value for repeated enum field {field_name}"
)
else:
output[cased_name] = enum_values[v].name
enum_class = field_type
output[cased_name] = enum_class(value).name
else:
output[cased_name] = v
output[cased_name] = value
return output

def from_dict(self: T, value: dict) -> T:
Expand Down
2 changes: 1 addition & 1 deletion tests/inputs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"namespace_keywords", # 70
"namespace_builtin_types", # 53
"googletypes_struct", # 9
"googletypes_value", # 9,
"googletypes_value", # 9
"import_capitalized_package",
"example", # This is the example in the readme. Not a test.
}
Expand Down
9 changes: 9 additions & 0 deletions tests/inputs/enum/enum.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"choice": "FOUR",
"choices": [
"ZERO",
"ONE",
"THREE",
"FOUR"
]
}
15 changes: 15 additions & 0 deletions tests/inputs/enum/enum.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
syntax = "proto3";

// Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values
message Test {
Choice choice = 1;
repeated Choice choices = 2;
}

enum Choice {
ZERO = 0;
ONE = 1;
// TWO = 2;
FOUR = 4;
THREE = 3;
}
64 changes: 64 additions & 0 deletions tests/inputs/enum/test_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from tests.output_betterproto.enum import (
Test,
Choice,
)


def test_enum_set_and_get():
assert Test(choice=Choice.ZERO).choice == Choice.ZERO
assert Test(choice=Choice.ONE).choice == Choice.ONE
assert Test(choice=Choice.THREE).choice == Choice.THREE
assert Test(choice=Choice.FOUR).choice == Choice.FOUR


def test_enum_set_with_int():
assert Test(choice=0).choice == Choice.ZERO
assert Test(choice=1).choice == Choice.ONE
assert Test(choice=3).choice == Choice.THREE
assert Test(choice=4).choice == Choice.FOUR


def test_enum_is_comparable_with_int():
assert Test(choice=Choice.ZERO).choice == 0
assert Test(choice=Choice.ONE).choice == 1
assert Test(choice=Choice.THREE).choice == 3
assert Test(choice=Choice.FOUR).choice == 4


def test_enum_to_dict():
assert (
"choice" not in Test(choice=Choice.ZERO).to_dict()
), "Default enum value is not serialized"
assert (
Test(choice=Choice.ZERO).to_dict(include_default_values=True)["choice"]
== "ZERO"
)
assert Test(choice=Choice.ONE).to_dict()["choice"] == "ONE"
assert Test(choice=Choice.THREE).to_dict()["choice"] == "THREE"
assert Test(choice=Choice.FOUR).to_dict()["choice"] == "FOUR"


def test_repeated_enum_is_comparable_with_int():
assert Test(choices=[Choice.ZERO]).choices == [0]
assert Test(choices=[Choice.ONE]).choices == [1]
assert Test(choices=[Choice.THREE]).choices == [3]
assert Test(choices=[Choice.FOUR]).choices == [4]


def test_repeated_enum_set_and_get():
assert Test(choices=[Choice.ZERO]).choices == [Choice.ZERO]
assert Test(choices=[Choice.ONE]).choices == [Choice.ONE]
assert Test(choices=[Choice.THREE]).choices == [Choice.THREE]
assert Test(choices=[Choice.FOUR]).choices == [Choice.FOUR]


def test_repeated_enum_to_dict():
assert Test(choices=[Choice.ZERO]).to_dict()["choices"] == ["ZERO"]
assert Test(choices=[Choice.ONE]).to_dict()["choices"] == ["ONE"]
assert Test(choices=[Choice.THREE]).to_dict()["choices"] == ["THREE"]
assert Test(choices=[Choice.FOUR]).to_dict()["choices"] == ["FOUR"]

all_enums_dict = Test(
choices=[Choice.ZERO, Choice.ONE, Choice.THREE, Choice.FOUR]
).to_dict()
assert (all_enums_dict["choices"]) == ["ZERO", "ONE", "THREE", "FOUR"]
3 changes: 0 additions & 3 deletions tests/inputs/enums/enums.json

This file was deleted.

14 changes: 0 additions & 14 deletions tests/inputs/enums/enums.proto

This file was deleted.

0 comments on commit a945246

Please sign in to comment.