diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index f7379098..fd9b32de 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -589,18 +589,20 @@ def __bytes__(self) -> bytes: # Being selected in a a group means this field is the one that is # currently set in a `oneof` group, so it must be serialized even # if the value is the default zero value. - selected_in_group = False - if meta.group and self._group_current[meta.group] == field_name: - selected_in_group = True + selected_in_group = ( + meta.group and self._group_current[meta.group] == field_name + ) - serialize_empty = False - if isinstance(value, Message) and value._serialized_on_wire: - # Empty messages can still be sent on the wire if they were - # set (or received empty). - serialize_empty = True + # Empty messages can still be sent on the wire if they were + # set (or received empty). + serialize_empty = isinstance(value, Message) and value._serialized_on_wire + + include_default_value_for_oneof = self._include_default_value_for_oneof( + field_name=field_name, meta=meta + ) if value == self._get_field_default(field_name) and not ( - selected_in_group or serialize_empty + selected_in_group or serialize_empty or include_default_value_for_oneof ): # Default (zero) values are not serialized. Two exceptions are # if this is the selected oneof item or if we know we have to @@ -629,6 +631,17 @@ def __bytes__(self) -> bytes: sv = _serialize_single(2, meta.map_types[1], v) output += _serialize_single(meta.number, meta.proto_type, sk + sv) else: + # If we have an empty string and we're including the default value for + # a oneof, make sure we serialize it. This ensures that the byte string + # output isn't simply an empty string. This also ensures that round trip + # serialization will keep `which_one_of` calls consistent. + if ( + isinstance(value, str) + and value == "" + and include_default_value_for_oneof + ): + serialize_empty = True + output += _serialize_single( meta.number, meta.proto_type, @@ -732,6 +745,13 @@ def _postprocess_single( return value + def _include_default_value_for_oneof( + self, field_name: str, meta: FieldMetadata + ) -> bool: + return ( + meta.group is not None and self._group_current.get(meta.group) == field_name + ) + def parse(self: T, data: bytes) -> T: """ Parse the binary encoded Protobuf into this message instance. This @@ -810,10 +830,22 @@ def to_dict( cased_name = casing(field_name).rstrip("_") # type: ignore if meta.proto_type == TYPE_MESSAGE: if isinstance(value, datetime): - if value != DATETIME_ZERO or include_default_values: + if ( + value != DATETIME_ZERO + or include_default_values + or self._include_default_value_for_oneof( + field_name=field_name, meta=meta + ) + ): output[cased_name] = _Timestamp.timestamp_to_json(value) elif isinstance(value, timedelta): - if value != timedelta(0) or include_default_values: + if ( + value != timedelta(0) + or include_default_values + or self._include_default_value_for_oneof( + field_name=field_name, meta=meta + ) + ): output[cased_name] = _Duration.delta_to_json(value) elif meta.wraps: if value is not None or include_default_values: @@ -823,19 +855,28 @@ def to_dict( value = [i.to_dict(casing, include_default_values) for i in value] if value or include_default_values: output[cased_name] = value - else: - if value._serialized_on_wire or include_default_values: - output[cased_name] = value.to_dict( - casing, include_default_values - ) - elif meta.proto_type == TYPE_MAP: + elif ( + value._serialized_on_wire + or include_default_values + or self._include_default_value_for_oneof( + field_name=field_name, meta=meta + ) + ): + output[cased_name] = value.to_dict(casing, include_default_values,) + elif meta.proto_type == "map": for k in value: if hasattr(value[k], "to_dict"): value[k] = value[k].to_dict(casing, include_default_values) if value or include_default_values: output[cased_name] = value - elif value != self._get_field_default(field_name) or include_default_values: + elif ( + value != self._get_field_default(field_name) + or include_default_values + or self._include_default_value_for_oneof( + field_name=field_name, meta=meta + ) + ): if meta.proto_type in INT_64_TYPES: if field_is_repeated: output[cased_name] = [str(n) for n in value] @@ -894,6 +935,8 @@ def from_dict(self: T, value: dict) -> T: elif meta.wraps: setattr(self, field_name, value[key]) else: + # NOTE: `from_dict` mutates the underlying message, so no + # assignment here is necessary. v.from_dict(value[key]) elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: v = getattr(self, field_name) @@ -919,8 +962,8 @@ def from_dict(self: T, value: dict) -> T: elif isinstance(v, str): v = enum_cls.from_string(v) - if v is not None: - setattr(self, field_name, v) + if v is not None: + setattr(self, field_name, v) return self def to_json(self, indent: Union[None, int, str] = None) -> str: diff --git a/tests/inputs/google_impl_behavior_equivalence/google_impl_behavior_equivalence.proto b/tests/inputs/google_impl_behavior_equivalence/google_impl_behavior_equivalence.proto new file mode 100644 index 00000000..31b6bd3e --- /dev/null +++ b/tests/inputs/google_impl_behavior_equivalence/google_impl_behavior_equivalence.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +message Foo{ + int64 bar = 1; +} + +message Test{ + oneof group{ + string string = 1; + int64 integer = 2; + Foo foo = 3; + } +} diff --git a/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py b/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py new file mode 100644 index 00000000..abe5d66d --- /dev/null +++ b/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py @@ -0,0 +1,55 @@ +import pytest + +from google.protobuf import json_format +import betterproto +from tests.output_betterproto.google_impl_behavior_equivalence import ( + Test, + Foo, +) +from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import ( + Test as ReferenceTest, + Foo as ReferenceFoo, +) + + +def test_oneof_serializes_similar_to_google_oneof(): + + tests = [ + (Test(string="abc"), ReferenceTest(string="abc")), + (Test(integer=2), ReferenceTest(integer=2)), + (Test(foo=Foo(bar=1)), ReferenceTest(foo=ReferenceFoo(bar=1))), + # Default values should also behave the same within oneofs + (Test(string=""), ReferenceTest(string="")), + (Test(integer=0), ReferenceTest(integer=0)), + (Test(foo=Foo(bar=0)), ReferenceTest(foo=ReferenceFoo(bar=0))), + ] + for message, message_reference in tests: + # NOTE: As of July 2020, MessageToJson inserts newlines in the output string so, + # just compare dicts + assert message.to_dict() == json_format.MessageToDict(message_reference) + + +def test_bytes_are_the_same_for_oneof(): + + message = Test(string="") + message_reference = ReferenceTest(string="") + + message_bytes = bytes(message) + message_reference_bytes = message_reference.SerializeToString() + + assert message_bytes == message_reference_bytes + + message2 = Test().parse(message_reference_bytes) + message_reference2 = ReferenceTest() + message_reference2.ParseFromString(message_reference_bytes) + + assert message == message2 + assert message_reference == message_reference2 + + # None of these fields were explicitly set BUT they should not actually be null + # themselves + assert isinstance(message.foo, Foo) + assert isinstance(message2.foo, Foo) + + assert isinstance(message_reference.foo, ReferenceFoo) + assert isinstance(message_reference2.foo, ReferenceFoo) diff --git a/tests/inputs/oneof_default_value_serialization/oneof_default_value_serialization.proto b/tests/inputs/oneof_default_value_serialization/oneof_default_value_serialization.proto new file mode 100644 index 00000000..44163c70 --- /dev/null +++ b/tests/inputs/oneof_default_value_serialization/oneof_default_value_serialization.proto @@ -0,0 +1,28 @@ +syntax = "proto3"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; +import "google/protobuf/wrappers.proto"; + +message Message{ + int64 value = 1; +} + +message NestedMessage{ + int64 id = 1; + oneof value_type{ + Message wrapped_message_value = 2; + } +} + +message Test{ + oneof value_type { + bool bool_value = 1; + int64 int64_value = 2; + google.protobuf.Timestamp timestamp_value = 3; + google.protobuf.Duration duration_value = 4; + Message wrapped_message_value = 5; + NestedMessage wrapped_nested_message_value = 6; + google.protobuf.BoolValue wrapped_bool_value = 7; + } +} \ No newline at end of file diff --git a/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py b/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py new file mode 100644 index 00000000..0c928cb8 --- /dev/null +++ b/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py @@ -0,0 +1,74 @@ +import pytest +import datetime + +import betterproto +from tests.output_betterproto.oneof_default_value_serialization import ( + Test, + Message, + NestedMessage, +) + + +def assert_round_trip_serialization_works(message: Test) -> None: + assert betterproto.which_one_of(message, "value_type") == betterproto.which_one_of( + Test().from_json(message.to_json()), "value_type" + ) + + +def test_oneof_default_value_serialization_works_for_all_values(): + """ + Serialization from message with oneof set to default -> JSON -> message should keep + default value field intact. + """ + + test_cases = [ + Test(bool_value=False), + Test(int64_value=0), + Test( + timestamp_value=datetime.datetime( + year=1970, + month=1, + day=1, + hour=0, + minute=0, + tzinfo=datetime.timezone.utc, + ) + ), + Test(duration_value=datetime.timedelta(0)), + Test(wrapped_message_value=Message(value=0)), + # NOTE: Do NOT use betterproto.BoolValue here, it will cause JSON serialization + # errors. + # TODO: Do we want to allow use of BoolValue directly within a wrapped field or + # should we simply hard fail here? + Test(wrapped_bool_value=False), + ] + for message in test_cases: + assert_round_trip_serialization_works(message) + + +def test_oneof_no_default_values_passed(): + message = Test() + assert ( + betterproto.which_one_of(message, "value_type") + == betterproto.which_one_of(Test().from_json(message.to_json()), "value_type") + == ("", None) + ) + + +def test_oneof_nested_oneof_messages_are_serialized_with_defaults(): + """ + Nested messages with oneofs should also be handled + """ + message = Test( + wrapped_nested_message_value=NestedMessage( + id=0, wrapped_message_value=Message(value=0) + ) + ) + assert ( + betterproto.which_one_of(message, "value_type") + == betterproto.which_one_of(Test().from_json(message.to_json()), "value_type") + == ( + "wrapped_nested_message_value", + NestedMessage(id=0, wrapped_message_value=Message(value=0)), + ) + ) diff --git a/tests/inputs/oneof_enum/test_oneof_enum.py b/tests/inputs/oneof_enum/test_oneof_enum.py index fe21c435..e3eca13a 100644 --- a/tests/inputs/oneof_enum/test_oneof_enum.py +++ b/tests/inputs/oneof_enum/test_oneof_enum.py @@ -9,34 +9,36 @@ from tests.util import get_test_case_json_data -@pytest.mark.xfail def test_which_one_of_returns_enum_with_default_value(): """ returns first field when it is enum and set with default value """ message = Test() message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")) - assert message.move is None + + assert message.move == Move( + x=0, y=0 + ) # Proto3 will default this as there is no null assert message.signal == Signal.PASS assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS) -@pytest.mark.xfail def test_which_one_of_returns_enum_with_non_default_value(): """ returns first field when it is enum and set with non default value """ message = Test() message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")) - assert message.move is None - assert message.signal == Signal.PASS + assert message.move == Move( + x=0, y=0 + ) # Proto3 will default this as there is no null + assert message.signal == Signal.RESIGN assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN) -@pytest.mark.xfail def test_which_one_of_returns_second_field_when_set(): message = Test() message.from_json(get_test_case_json_data("oneof_enum")) assert message.move == Move(x=2, y=3) - assert message.signal == 0 + assert message.signal == Signal.PASS assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) diff --git a/tests/test_features.py b/tests/test_features.py index 0a8cb2d7..b5b38112 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -282,3 +282,38 @@ class TestParentMessage(betterproto.Message): "someDouble": 1.2, "someMessage": {"someOtherInt": 0}, } + + +def test_oneof_default_value_set_causes_writes_wire(): + @dataclass + class Foo(betterproto.Message): + bar: int = betterproto.int32_field(1, group="group1") + baz: str = betterproto.string_field(2, group="group1") + + def _round_trip_serialization(foo: Foo) -> Foo: + return Foo().parse(bytes(foo)) + + foo1 = Foo(bar=0) + foo2 = Foo(baz="") + foo3 = Foo() + + assert bytes(foo1) == b"\x08\x00" + assert ( + betterproto.which_one_of(foo1, "group1") + == betterproto.which_one_of(_round_trip_serialization(foo1), "group1") + == ("bar", 0) + ) + + assert bytes(foo2) == b"\x12\x00" # Baz is just an empty string + assert ( + betterproto.which_one_of(foo2, "group1") + == betterproto.which_one_of(_round_trip_serialization(foo2), "group1") + == ("baz", "") + ) + + assert bytes(foo3) == b"" + assert ( + betterproto.which_one_of(foo3, "group1") + == betterproto.which_one_of(_round_trip_serialization(foo3), "group1") + == ("", None) + )