From e9abbe47429a26b0e96bbc6106aa87d0dd45e0b7 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 10 Apr 2024 08:38:30 +1000 Subject: [PATCH] Don't pickle coroutine state twice --- src/dispatch/proto.py | 13 ++++--------- src/dispatch/scheduler.py | 2 +- tests/test_fastapi.py | 18 +++++++++++------- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index c3fe286f..c5f9fdb8 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -72,11 +72,7 @@ def __init__(self, req: function_pb.RunRequest): else: self._input = _pb_any_unpack(req.input) else: - state_bytes = req.poll_result.coroutine_state - if len(state_bytes) > 0: - self._coroutine_state = pickle.loads(state_bytes) - else: - self._coroutine_state = None + self._coroutine_state = req.poll_result.coroutine_state self._call_results = [ CallResult._from_proto(r) for r in req.poll_result.results ] @@ -143,7 +139,7 @@ def from_input_arguments(cls, function: str, *args, **kwargs): def from_poll_results( cls, function: str, - coroutine_state: Any, + coroutine_state: Optional[bytes], call_results: List[CallResult], error: Optional[Error] = None, ): @@ -220,7 +216,7 @@ def exit( @classmethod def poll( cls, - state: Any, + coroutine_state: Optional[bytes] = None, calls: Optional[List[Call]] = None, min_results: int = 1, max_results: int = 10, @@ -229,14 +225,13 @@ def poll( """Suspend the function with a set of Calls, instructing the orchestrator to resume the function with the provided state when call results are ready.""" - state_bytes = pickle.dumps(state) max_wait = ( duration_pb2.Duration(seconds=max_wait_seconds) if max_wait_seconds is not None else None ) poll = poll_pb.Poll( - coroutine_state=state_bytes, + coroutine_state=coroutine_state, min_results=min_results, max_results=max_results, max_wait=max_wait, diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index 90ee5b68..ab812200 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -558,7 +558,7 @@ def _run(self, input: Input) -> Output: len(serialized_state), ) return Output.poll( - state=serialized_state, + coroutine_state=serialized_state, calls=pending_calls, min_results=max(1, min(state.outstanding_calls, self.poll_min_results)), max_results=max(1, min(state.outstanding_calls, self.poll_max_results)), diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 39eef0d0..1bbac55c 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -1,6 +1,7 @@ import base64 import os import pickle +import struct import unittest from typing import Any from unittest import mock @@ -282,7 +283,7 @@ def test_error_on_access_input_in_second_call(self): @self.dispatch.primitive_function def my_function(input: Input) -> Output: if input.is_first_call: - return Output.poll(state=42) + return Output.poll(coroutine_state=b"42") try: print(input.input) except ValueError: @@ -294,7 +295,7 @@ def my_function(input: Input) -> Output: return Output.value("not reached") resp = self.execute(my_function, input="cool stuff") - self.assertEqual(42, pickle.loads(resp.poll.coroutine_state)) + self.assertEqual(b"42", resp.poll.coroutine_state) resp = self.execute(my_function, state=resp.poll.coroutine_state) self.assertEqual("ValueError", resp.exit.result.error.type) @@ -337,11 +338,12 @@ def coroutine3(input: Input) -> Output: if input.is_first_call: counter = input.input else: - counter = input.coroutine_state + (counter,) = struct.unpack("@i", input.coroutine_state) counter -= 1 if counter <= 0: return Output.value("done") - return Output.poll(state=counter) + coroutine_state = struct.pack("@i", counter) + return Output.poll(coroutine_state=coroutine_state) # first call resp = self.execute(coroutine3, input=4) @@ -375,9 +377,10 @@ def coroutine_main(input: Input) -> Output: if input.is_first_call: text: str = input.input return Output.poll( - state=text, calls=[coro_compute_len._build_primitive_call(text)] + coroutine_state=text.encode(), + calls=[coro_compute_len._build_primitive_call(text)], ) - text = input.coroutine_state + text = input.coroutine_state.decode() length = input.call_results[0].output return Output.value(f"length={length} text='{text}'") @@ -415,7 +418,8 @@ def coroutine_main(input: Input) -> Output: if input.is_first_call: text: str = input.input return Output.poll( - state=text, calls=[coro_compute_len._build_primitive_call(text)] + coroutine_state=text.encode(), + calls=[coro_compute_len._build_primitive_call(text)], ) error = input.call_results[0].error if error is not None: