Skip to content

Commit

Permalink
Don't pickle proto messages before wrapping as google.protobuf.Any
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Jun 25, 2024
1 parent bb6ec79 commit 516d006
Showing 1 changed file with 13 additions and 27 deletions.
40 changes: 13 additions & 27 deletions src/dispatch/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, req: function_pb.RunRequest):

self._has_input = req.HasField("input")
if self._has_input:
self._input = _pb_any_unpack(req.input)
self._input = _any_unpickle(req.input)
else:
if req.poll_result.coroutine_state:
raise IncompatibleStateError # coroutine_state is deprecated
Expand Down Expand Up @@ -141,7 +141,7 @@ def from_input_arguments(cls, function: str, *args, **kwargs):
return Input(
req=function_pb.RunRequest(
function=function,
input=_pb_any_pickle(input),
input=_any_pickle(input),
)
)

Expand All @@ -157,7 +157,7 @@ def from_poll_results(
req=function_pb.RunRequest(
function=function,
poll_result=poll_pb.PollResult(
typed_coroutine_state=_pb_any_pickle(coroutine_state),
typed_coroutine_state=_any_pickle(coroutine_state),
results=[result._as_proto() for result in call_results],
error=error._as_proto() if error else None,
),
Expand Down Expand Up @@ -241,7 +241,7 @@ def poll(
else None
)
poll = poll_pb.Poll(
typed_coroutine_state=_pb_any_pickle(coroutine_state),
typed_coroutine_state=_any_pickle(coroutine_state),
min_results=min_results,
max_results=max_results,
max_wait=max_wait,
Expand Down Expand Up @@ -279,7 +279,7 @@ class Call:
correlation_id: Optional[int] = None

def _as_proto(self) -> call_pb.Call:
input_bytes = _pb_any_pickle(self.input)
input_bytes = _any_pickle(self.input)
return call_pb.Call(
correlation_id=self.correlation_id,
endpoint=self.endpoint,
Expand All @@ -301,7 +301,7 @@ def _as_proto(self) -> call_pb.CallResult:
output_any = None
error_proto = None
if self.output is not None:
output_any = _pb_any_pickle(self.output)
output_any = _any_pickle(self.output)
if self.error is not None:
error_proto = self.error._as_proto()

Expand Down Expand Up @@ -440,31 +440,17 @@ def _as_proto(self) -> error_pb.Error:
)


def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
if any.Is(pickled_pb.Pickled.DESCRIPTOR):
p = pickled_pb.Pickled()
any.Unpack(p)
return pickle.loads(p.pickled_value)

elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): # legacy container
b = google.protobuf.wrappers_pb2.BytesValue()
any.Unpack(b)
return pickle.loads(b.value)

elif not any.type_url and not any.value:
return None

raise InvalidArgumentError(f"unsupported pickled value container: {any.type_url}")


def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any:
p = pickled_pb.Pickled(pickled_value=pickle.dumps(value))
def _any_pickle(value: Any) -> google.protobuf.any_pb2.Any:
any = google.protobuf.any_pb2.Any()
any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/")
if isinstance(value, google.protobuf.message.Message):
any.Pack(value)
else:
p = pickled_pb.Pickled(pickled_value=pickle.dumps(value))
any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/")
return any


def _pb_any_unpack(any: google.protobuf.any_pb2.Any) -> Any:
def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
if any.Is(pickled_pb.Pickled.DESCRIPTOR):
p = pickled_pb.Pickled()
any.Unpack(p)
Expand Down

0 comments on commit 516d006

Please sign in to comment.