diff --git a/src/dispatch/any.py b/src/dispatch/any.py index f4d8d33..92dda7d 100644 --- a/src/dispatch/any.py +++ b/src/dispatch/any.py @@ -8,6 +8,7 @@ import google.protobuf.duration_pb2 import google.protobuf.empty_pb2 import google.protobuf.message +import google.protobuf.struct_pb2 import google.protobuf.timestamp_pb2 import google.protobuf.wrappers_pb2 from google.protobuf import descriptor_pool, message_factory @@ -44,6 +45,12 @@ def marshal_any(value: Any) -> google.protobuf.any_pb2.Any: nanos = value.microseconds * 1000 value = google.protobuf.duration_pb2.Duration(seconds=seconds, nanos=nanos) + if isinstance(value, list) or isinstance(value, dict): + try: + value = as_struct_value(value) + except ValueError: + pass # fallthrough + if not isinstance(value, google.protobuf.message.Message): value = pickled_pb.Pickled(pickled_value=pickle.dumps(value)) @@ -64,24 +71,34 @@ def unmarshal_any(any: google.protobuf.any_pb2.Any) -> Any: if isinstance(proto, pickled_pb.Pickled): return pickle.loads(proto.pickled_value) + elif isinstance(proto, google.protobuf.empty_pb2.Empty): return None + elif isinstance(proto, google.protobuf.wrappers_pb2.BoolValue): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.Int32Value): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.Int64Value): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.UInt32Value): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.UInt64Value): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.FloatValue): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.DoubleValue): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.StringValue): return proto.value + elif isinstance(proto, google.protobuf.wrappers_pb2.BytesValue): try: # Assume it's the legacy container for pickled values. @@ -89,9 +106,65 @@ def unmarshal_any(any: google.protobuf.any_pb2.Any) -> Any: except Exception as e: # Otherwise, return the literal bytes. return proto.value + elif isinstance(proto, google.protobuf.timestamp_pb2.Timestamp): return proto.ToDatetime(tzinfo=UTC) + elif isinstance(proto, google.protobuf.duration_pb2.Duration): return proto.ToTimedelta() + elif isinstance(proto, google.protobuf.struct_pb2.Value): + return from_struct_value(proto) + return proto + + +def as_struct_value(value: Any) -> google.protobuf.struct_pb2.Value: + if value is None: + null_value = google.protobuf.struct_pb2.NullValue.NULL_VALUE + return google.protobuf.struct_pb2.Value(null_value=null_value) + + elif isinstance(value, bool): + return google.protobuf.struct_pb2.Value(bool_value=value) + + elif isinstance(value, int) or isinstance(value, float): + return google.protobuf.struct_pb2.Value(number_value=float(value)) + + elif isinstance(value, str): + return google.protobuf.struct_pb2.Value(string_value=value) + + elif isinstance(value, list): + list_value = google.protobuf.struct_pb2.ListValue( + values=[as_struct_value(v) for v in value] + ) + return google.protobuf.struct_pb2.Value(list_value=list_value) + + elif isinstance(value, dict): + for key in value.keys(): + if not isinstance(key, str): + raise ValueError("unsupported object key") + + struct_value = google.protobuf.struct_pb2.Struct( + fields={k: as_struct_value(v) for k, v in value.items()} + ) + return google.protobuf.struct_pb2.Value(struct_value=struct_value) + + raise ValueError("unsupported value") + + +def from_struct_value(value: google.protobuf.struct_pb2.Value) -> Any: + if value.HasField("null_value"): + return None + elif value.HasField("bool_value"): + return value.bool_value + elif value.HasField("number_value"): + return value.number_value + elif value.HasField("string_value"): + return value.string_value + elif value.HasField("list_value"): + + return [from_struct_value(v) for v in value.list_value.values] + elif value.HasField("struct_value"): + return {k: from_struct_value(v) for k, v in value.struct_value.fields.items()} + else: + raise RuntimeError(f"invalid struct_pb2.Value: {value}") diff --git a/tests/dispatch/test_any.py b/tests/dispatch/test_any.py index 9db4436..28fd2e5 100644 --- a/tests/dispatch/test_any.py +++ b/tests/dispatch/test_any.py @@ -94,3 +94,18 @@ def test_unmarshal_protobuf_message(): ) assert message == unmarshal_any(boxed) + + +def test_unmarshal_json_like(): + value = { + "null": None, + "bool": True, + "int": 11, + "float": 3.14, + "string": "foo", + "list": [None, "abc", 1.23], + "object": {"a": ["b", "c"]}, + } + boxed = marshal_any(value) + assert "type.googleapis.com/google.protobuf.Value" == boxed.type_url + assert value == unmarshal_any(boxed)