diff --git a/truss-chains/tests/test_e2e.py b/truss-chains/tests/test_e2e.py index 122fe1b05..86da35eff 100644 --- a/truss-chains/tests/test_e2e.py +++ b/truss-chains/tests/test_e2e.py @@ -224,3 +224,49 @@ def test_numpy_chain(mode): response = service.run_remote({}) assert response.status_code == 200 print(response.json()) + + +@pytest.mark.asyncio +async def test_timeout(): + with ensure_kill_all(): + chain_root = TEST_ROOT / "timeout" / "timeout_chain.py" + with framework.import_target(chain_root, "TimeoutChain") as entrypoint: + options = definitions.PushOptionsLocalDocker( + chain_name="integration-test", use_local_chains_src=True + ) + service = deployment_client.push(entrypoint, options) + + url = service.run_remote_url.replace("host.docker.internal", "localhost") + time.sleep(1.0) # Wait for models to be ready. + + # Async. + response = requests.post(url, json={"use_sync": False}) + # print(response.content) + + assert response.status_code == 500 + error = definitions.RemoteErrorDetail.model_validate(response.json()["error"]) + error_str = error.format() + error_regex = r""" +Chainlet-Traceback \(most recent call last\): + File \".*?/timeout_chain\.py\", line \d+, in run_remote + result = await self\._dep.run_remote\(\) +TimeoutError: Timeout calling remote Chainlet `Dependency` \(0.5 seconds limit\)\. + """ + assert re.match(error_regex.strip(), error_str.strip(), re.MULTILINE), error_str + + # Sync: + sync_response = requests.post(url, json={"use_sync": True}) + assert sync_response.status_code == 500 + sync_error = definitions.RemoteErrorDetail.model_validate( + sync_response.json()["error"] + ) + sync_error_str = sync_error.format() + sync_error_regex = r""" +Chainlet-Traceback \(most recent call last\): + File \".*?/timeout_chain\.py\", line \d+, in run_remote + result = self\._dep_sync.run_remote\(\) +TimeoutError: Timeout calling remote Chainlet `DependencySync` \(0.5 seconds limit\)\. + """ + assert re.match( + sync_error_regex.strip(), sync_error_str.strip(), re.MULTILINE + ), sync_error_str diff --git a/truss-chains/tests/timeout/timeout_chain.py b/truss-chains/tests/timeout/timeout_chain.py new file mode 100644 index 000000000..5a4a61519 --- /dev/null +++ b/truss-chains/tests/timeout/timeout_chain.py @@ -0,0 +1,35 @@ +import asyncio +import time + +import truss_chains as chains + + +class Dependency(chains.ChainletBase): + async def run_remote(self) -> bool: + await asyncio.sleep(1) + return True + + +class DependencySync(chains.ChainletBase): + def run_remote(self) -> bool: + time.sleep(1) + return True + + +@chains.mark_entrypoint # ("My Chain Name") +class TimeoutChain(chains.ChainletBase): + def __init__( + self, + dep=chains.depends(Dependency, timeout_sec=0.5), + dep_sync=chains.depends(DependencySync, timeout_sec=0.5), + ): + self._dep = dep + self._dep_sync = dep_sync + + async def run_remote(self, use_sync: bool) -> None: + if use_sync: + result = self._dep_sync.run_remote() + print(result) + else: + result = await self._dep.run_remote() + print(result) diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py index 3bb6a27e3..8e6a9ae97 100644 --- a/truss-chains/truss_chains/definitions.py +++ b/truss-chains/truss_chains/definitions.py @@ -396,7 +396,7 @@ def get_asset_spec(self) -> AssetSpec: return self.assets.get_spec() -DEFAULT_TIMEOUT_SEC = 600 +DEFAULT_TIMEOUT_SEC = 600.0 class RPCOptions(SafeModel): @@ -415,7 +415,7 @@ class RPCOptions(SafeModel): """ retries: int = 1 - timeout_sec: int = DEFAULT_TIMEOUT_SEC + timeout_sec: float = DEFAULT_TIMEOUT_SEC use_binary: bool = False diff --git a/truss-chains/truss_chains/public_api.py b/truss-chains/truss_chains/public_api.py index 90e27f58e..ecb1be35a 100644 --- a/truss-chains/truss_chains/public_api.py +++ b/truss-chains/truss_chains/public_api.py @@ -46,7 +46,7 @@ def depends_context() -> definitions.DeploymentContext: def depends( chainlet_cls: Type[framework.ChainletT], retries: int = 1, - timeout_sec: int = definitions.DEFAULT_TIMEOUT_SEC, + timeout_sec: float = definitions.DEFAULT_TIMEOUT_SEC, use_binary: bool = False, ) -> framework.ChainletT: """Sets a "symbolic marker" to indicate to the framework that a chainlet is a diff --git a/truss-chains/truss_chains/remote_chainlet/stub.py b/truss-chains/truss_chains/remote_chainlet/stub.py index 95a3c43d8..f556b6c36 100644 --- a/truss-chains/truss_chains/remote_chainlet/stub.py +++ b/truss-chains/truss_chains/remote_chainlet/stub.py @@ -329,7 +329,16 @@ def _rpc() -> bytes: utils.response_raise_errors(response, self.name) return response.content - response_bytes = retry(_rpc) + try: + response_bytes = retry(_rpc) + except httpx.ReadTimeout: + msg = ( + f"Timeout calling remote Chainlet `{self.name}` " + f"({self._service_descriptor.options.timeout_sec} seconds limit)." + ) + logging.warning(msg) + raise TimeoutError(msg) from None # Prune error stack trace (TMI). + if output_model: return self._response_to_pydantic(response_bytes, output_model) return self._response_to_json(response_bytes) @@ -357,7 +366,16 @@ async def _rpc() -> bytes: await utils.async_response_raise_errors(response, self.name) return await response.read() - response_bytes: bytes = await retry(_rpc) + try: + response_bytes: bytes = await retry(_rpc) + except asyncio.TimeoutError: + msg = ( + f"Timeout calling remote Chainlet `{self.name}` " + f"({self._service_descriptor.options.timeout_sec} seconds limit)." + ) + logging.warning(msg) + raise TimeoutError(msg) from None # Prune error stack trace (TMI). + if output_model: return self._response_to_pydantic(response_bytes, output_model) return self._response_to_json(response_bytes) @@ -375,7 +393,15 @@ async def _rpc() -> AsyncIterator[bytes]: await utils.async_response_raise_errors(response, self.name) return response.content.iter_any() - return await retry(_rpc) + try: + return await retry(_rpc) + except asyncio.TimeoutError: + msg = ( + f"Timeout calling remote Chainlet `{self.name}` " + f"({self._service_descriptor.options.timeout_sec} seconds limit)." + ) + logging.warning(msg) + raise TimeoutError(msg) from None # Prune error stack trace (TMI). StubT = TypeVar("StubT", bound=StubBase)