Skip to content

Commit

Permalink
fix: Fix Timestamp, Duration and FieldMask marshaling in REST transpo…
Browse files Browse the repository at this point in the history
…rt (#334)

* fix: Fix Timestamp, Duration and FieldMask marshaling in REST transport

This fixes #333
  • Loading branch information
vam-google authored Aug 5, 2022
1 parent f85f470 commit a2e7300
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 0 deletions.
5 changes: 5 additions & 0 deletions proto/marshal/marshal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from google.protobuf import message
from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
from google.protobuf import field_mask_pb2
from google.protobuf import struct_pb2
from google.protobuf import wrappers_pb2

Expand All @@ -31,6 +32,7 @@
from proto.marshal.rules import dates
from proto.marshal.rules import struct
from proto.marshal.rules import wrappers
from proto.marshal.rules import field_mask
from proto.primitives import ProtoType


Expand Down Expand Up @@ -126,6 +128,9 @@ def reset(self):
self.register(timestamp_pb2.Timestamp, dates.TimestampRule())
self.register(duration_pb2.Duration, dates.DurationRule())

# Register FieldMask wrappers.
self.register(field_mask_pb2.FieldMask, field_mask.FieldMaskRule())

# Register nullable primitive wrappers.
self.register(wrappers_pb2.BoolValue, wrappers.BoolValueRule())
self.register(wrappers_pb2.BytesValue, wrappers.BytesValueRule())
Expand Down
8 changes: 8 additions & 0 deletions proto/marshal/rules/dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def to_proto(self, value) -> timestamp_pb2.Timestamp:
seconds=int(value.timestamp()),
nanos=value.microsecond * 1000,
)
if isinstance(value, str):
timestamp_value = timestamp_pb2.Timestamp()
timestamp_value.FromJsonString(value=value)
return timestamp_value
return value


Expand Down Expand Up @@ -74,4 +78,8 @@ def to_proto(self, value) -> duration_pb2.Duration:
seconds=value.days * 86400 + value.seconds,
nanos=value.microseconds * 1000,
)
if isinstance(value, str):
duration_value = duration_pb2.Duration()
duration_value.FromJsonString(value=value)
return duration_value
return value
36 changes: 36 additions & 0 deletions proto/marshal/rules/field_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.protobuf import field_mask_pb2


class FieldMaskRule:
"""A marshal between FieldMask and strings.
See https://github.com/googleapis/proto-plus-python/issues/333
and
https://developers.google.com/protocol-buffers/docs/proto3#json
for more details.
"""

def to_python(self, value, *, absent: bool = None):
return value

def to_proto(self, value):
if isinstance(value, str):
field_mask_value = field_mask_pb2.FieldMask()
field_mask_value.FromJsonString(value=value)
return field_mask_value

return value
100 changes: 100 additions & 0 deletions tests/test_marshal_field_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.protobuf import field_mask_pb2

import proto
from proto.marshal.marshal import BaseMarshal


def test_field_mask_read():
class Foo(proto.Message):
mask = proto.Field(
proto.MESSAGE,
number=1,
message=field_mask_pb2.FieldMask,
)

foo = Foo(mask=field_mask_pb2.FieldMask(paths=["f.b.d", "f.c"]))

assert isinstance(foo.mask, field_mask_pb2.FieldMask)
assert foo.mask.paths == ["f.b.d", "f.c"]


def test_field_mask_write_string():
class Foo(proto.Message):
mask = proto.Field(
proto.MESSAGE,
number=1,
message=field_mask_pb2.FieldMask,
)

foo = Foo()
foo.mask = "f.b.d,f.c"

assert isinstance(foo.mask, field_mask_pb2.FieldMask)
assert foo.mask.paths == ["f.b.d", "f.c"]


def test_field_mask_write_pb2():
class Foo(proto.Message):
mask = proto.Field(
proto.MESSAGE,
number=1,
message=field_mask_pb2.FieldMask,
)

foo = Foo()
foo.mask = field_mask_pb2.FieldMask(paths=["f.b.d", "f.c"])

assert isinstance(foo.mask, field_mask_pb2.FieldMask)
assert foo.mask.paths == ["f.b.d", "f.c"]


def test_field_mask_absence():
class Foo(proto.Message):
mask = proto.Field(
proto.MESSAGE,
number=1,
message=field_mask_pb2.FieldMask,
)

foo = Foo()
assert not foo.mask.paths


def test_timestamp_del():
class Foo(proto.Message):
mask = proto.Field(
proto.MESSAGE,
number=1,
message=field_mask_pb2.FieldMask,
)

foo = Foo()
foo.mask = field_mask_pb2.FieldMask(paths=["f.b.d", "f.c"])

del foo.mask
assert not foo.mask.paths


def test_timestamp_to_python_idempotent():
# This path can never run in the current configuration because proto
# values are the only thing ever saved, and `to_python` is a read method.
#
# However, we test idempotency for consistency with `to_proto` and
# general resiliency.
marshal = BaseMarshal()
py_value = "f.b.d,f.c"
assert marshal.to_python(field_mask_pb2.FieldMask, py_value) is py_value
34 changes: 34 additions & 0 deletions tests/test_marshal_types_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,24 @@ class Foo(proto.Message):
assert Foo.pb(foo).event_time.seconds == 1335020400


def test_timestamp_write_string():
class Foo(proto.Message):
event_time = proto.Field(
proto.MESSAGE,
number=1,
message=timestamp_pb2.Timestamp,
)

foo = Foo()
foo.event_time = "2012-04-21T15:00:00Z"
assert isinstance(foo.event_time, DatetimeWithNanoseconds)
assert isinstance(Foo.pb(foo).event_time, timestamp_pb2.Timestamp)
assert foo.event_time.year == 2012
assert foo.event_time.month == 4
assert foo.event_time.hour == 15
assert Foo.pb(foo).event_time.seconds == 1335020400


def test_timestamp_rmw_nanos():
class Foo(proto.Message):
event_time = proto.Field(
Expand Down Expand Up @@ -207,6 +225,22 @@ class Foo(proto.Message):
assert Foo.pb(foo).ttl.seconds == 120


def test_duration_write_string():
class Foo(proto.Message):
ttl = proto.Field(
proto.MESSAGE,
number=1,
message=duration_pb2.Duration,
)

foo = Foo()
foo.ttl = "120s"
assert isinstance(foo.ttl, timedelta)
assert isinstance(Foo.pb(foo).ttl, duration_pb2.Duration)
assert foo.ttl.seconds == 120
assert Foo.pb(foo).ttl.seconds == 120


def test_duration_del():
class Foo(proto.Message):
ttl = proto.Field(
Expand Down

0 comments on commit a2e7300

Please sign in to comment.