Skip to content

Commit

Permalink
Merge pull request #12 from ulasozguler/master
Browse files Browse the repository at this point in the history
Added `include_default_values` parameter to `to_dict` function
  • Loading branch information
danielgtaylor authored Jan 31, 2020
2 parents 559b883 + c0170f4 commit c78851b
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 10 deletions.
26 changes: 16 additions & 10 deletions betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,11 +704,16 @@ def parse(self: T, data: bytes) -> T:
def FromString(cls: Type[T], data: bytes) -> T:
return cls().parse(data)

def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
def to_dict(self, casing: Casing = Casing.CAMEL, include_default_values: bool = False) -> dict:
"""
Returns a dict representation of this message instance which can be
used to serialize to e.g. JSON. Defaults to camel casing for
compatibility but can be set to other modes.
`include_default_values` can be set to `True` to include default
values of fields. E.g. an `int32` type field with `0` value will
not be in returned dict if `include_default_values` is set to
`False`.
"""
output: Dict[str, Any] = {}
for field in dataclasses.fields(self):
Expand All @@ -717,28 +722,29 @@ def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
cased_name = casing(field.name).rstrip("_") # type: ignore
if meta.proto_type == "message":
if isinstance(v, datetime):
if v != DATETIME_ZERO:
if v != DATETIME_ZERO or include_default_values:
output[cased_name] = _Timestamp.timestamp_to_json(v)
elif isinstance(v, timedelta):
if v != timedelta(0):
if v != timedelta(0) or include_default_values:
output[cased_name] = _Duration.delta_to_json(v)
elif meta.wraps:
if v is not None:
if v is not None or include_default_values:
output[cased_name] = v
elif isinstance(v, list):
# Convert each item.
v = [i.to_dict(casing) for i in v]
v = [i.to_dict(casing, include_default_values) for i in v]
output[cased_name] = v
elif v._serialized_on_wire:
output[cased_name] = v.to_dict(casing)
else:
if v._serialized_on_wire or include_default_values:
output[cased_name] = v.to_dict(casing, include_default_values)
elif meta.proto_type == "map":
for k in v:
if hasattr(v[k], "to_dict"):
v[k] = v[k].to_dict(casing)
v[k] = v[k].to_dict(casing, include_default_values)

if v:
if v or include_default_values:
output[cased_name] = v
elif v != self._get_field_default(field, meta):
elif v != self._get_field_default(field, meta) or include_default_values:
if meta.proto_type in INT_64_TYPES:
if isinstance(v, list):
output[cased_name] = [str(n) for n in v]
Expand Down
92 changes: 92 additions & 0 deletions betterproto/tests/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,95 @@ class Request(betterproto.Message):
# Differentiate between not passed and the zero-value.
assert Request().parse(b"").flag == None
assert Request().parse(b"\n\x00").flag == False


def test_to_dict_default_values():
@dataclass
class TestMessage(betterproto.Message):
some_int: int = betterproto.int32_field(1)
some_double: float = betterproto.double_field(2)
some_str: str = betterproto.string_field(3)
some_bool: bool = betterproto.bool_field(4)

# Empty dict
test = TestMessage().from_dict({})

assert test.to_dict(include_default_values=True) == {
'someInt': 0,
'someDouble': 0.0,
'someStr': '',
'someBool': False
}

# All default values
test = TestMessage().from_dict({
'someInt': 0,
'someDouble': 0.0,
'someStr': '',
'someBool': False
})

assert test.to_dict(include_default_values=True) == {
'someInt': 0,
'someDouble': 0.0,
'someStr': '',
'someBool': False
}

# Some default and some other values
@dataclass
class TestMessage2(betterproto.Message):
some_int: int = betterproto.int32_field(1)
some_double: float = betterproto.double_field(2)
some_str: str = betterproto.string_field(3)
some_bool: bool = betterproto.bool_field(4)
some_default_int: int = betterproto.int32_field(5)
some_default_double: float = betterproto.double_field(6)
some_default_str: str = betterproto.string_field(7)
some_default_bool: bool = betterproto.bool_field(8)

test = TestMessage2().from_dict({
'someInt': 2,
'someDouble': 1.2,
'someStr': 'hello',
'someBool': True,
'someDefaultInt': 0,
'someDefaultDouble': 0.0,
'someDefaultStr': '',
'someDefaultBool': False
})

assert test.to_dict(include_default_values=True) == {
'someInt': 2,
'someDouble': 1.2,
'someStr': 'hello',
'someBool': True,
'someDefaultInt': 0,
'someDefaultDouble': 0.0,
'someDefaultStr': '',
'someDefaultBool': False
}

# Nested messages
@dataclass
class TestChildMessage(betterproto.Message):
some_other_int: int = betterproto.int32_field(1)

@dataclass
class TestParentMessage(betterproto.Message):
some_int: int = betterproto.int32_field(1)
some_double: float = betterproto.double_field(2)
some_message: TestChildMessage = betterproto.message_field(3)

test = TestParentMessage().from_dict({
'someInt': 0,
'someDouble': 1.2,
})

assert test.to_dict(include_default_values=True) == {
'someInt': 0,
'someDouble': 1.2,
'someMessage': {
'someOtherInt': 0
}
}

0 comments on commit c78851b

Please sign in to comment.