Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Translate RPC timeout error message. #1313

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions truss-chains/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,49 @@ def test_numpy_chain(mode):
response = service.run_remote({})
assert response.status_code == 200
print(response.json())


@pytest.mark.asyncio
async def test_timeout():
with ensure_kill_all():
chain_root = TEST_ROOT / "timeout" / "timeout_chain.py"
with framework.import_target(chain_root, "TimeoutChain") as entrypoint:
options = definitions.PushOptionsLocalDocker(
chain_name="integration-test", use_local_chains_src=True
)
service = deployment_client.push(entrypoint, options)

url = service.run_remote_url.replace("host.docker.internal", "localhost")
time.sleep(1.0) # Wait for models to be ready.

# Async.
response = requests.post(url, json={"use_sync": False})
# print(response.content)

assert response.status_code == 500
error = definitions.RemoteErrorDetail.model_validate(response.json()["error"])
error_str = error.format()
error_regex = r"""
Chainlet-Traceback \(most recent call last\):
File \".*?/timeout_chain\.py\", line \d+, in run_remote
result = await self\._dep.run_remote\(\)
TimeoutError: Timeout calling remote Chainlet `Dependency` \(0.5 seconds limit\)\.
"""
assert re.match(error_regex.strip(), error_str.strip(), re.MULTILINE), error_str

# Sync:
sync_response = requests.post(url, json={"use_sync": True})
assert sync_response.status_code == 500
sync_error = definitions.RemoteErrorDetail.model_validate(
sync_response.json()["error"]
)
sync_error_str = sync_error.format()
sync_error_regex = r"""
Chainlet-Traceback \(most recent call last\):
File \".*?/timeout_chain\.py\", line \d+, in run_remote
result = self\._dep_sync.run_remote\(\)
TimeoutError: Timeout calling remote Chainlet `DependencySync` \(0.5 seconds limit\)\.
"""
assert re.match(
sync_error_regex.strip(), sync_error_str.strip(), re.MULTILINE
), sync_error_str
35 changes: 35 additions & 0 deletions truss-chains/tests/timeout/timeout_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import asyncio
import time

import truss_chains as chains


class Dependency(chains.ChainletBase):
async def run_remote(self) -> bool:
await asyncio.sleep(1)
return True


class DependencySync(chains.ChainletBase):
def run_remote(self) -> bool:
time.sleep(1)
return True


@chains.mark_entrypoint # ("My Chain Name")
class TimeoutChain(chains.ChainletBase):
def __init__(
self,
dep=chains.depends(Dependency, timeout_sec=0.5),
dep_sync=chains.depends(DependencySync, timeout_sec=0.5),
):
self._dep = dep
self._dep_sync = dep_sync

async def run_remote(self, use_sync: bool) -> None:
if use_sync:
result = self._dep_sync.run_remote()
print(result)
else:
result = await self._dep.run_remote()
print(result)
4 changes: 2 additions & 2 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def get_asset_spec(self) -> AssetSpec:
return self.assets.get_spec()


DEFAULT_TIMEOUT_SEC = 600
DEFAULT_TIMEOUT_SEC = 600.0


class RPCOptions(SafeModel):
Expand All @@ -415,7 +415,7 @@ class RPCOptions(SafeModel):
"""

retries: int = 1
timeout_sec: int = DEFAULT_TIMEOUT_SEC
timeout_sec: float = DEFAULT_TIMEOUT_SEC
use_binary: bool = False


Expand Down
2 changes: 1 addition & 1 deletion truss-chains/truss_chains/public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def depends_context() -> definitions.DeploymentContext:
def depends(
chainlet_cls: Type[framework.ChainletT],
retries: int = 1,
timeout_sec: int = definitions.DEFAULT_TIMEOUT_SEC,
timeout_sec: float = definitions.DEFAULT_TIMEOUT_SEC,
use_binary: bool = False,
) -> framework.ChainletT:
"""Sets a "symbolic marker" to indicate to the framework that a chainlet is a
Expand Down
32 changes: 29 additions & 3 deletions truss-chains/truss_chains/remote_chainlet/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,16 @@ def _rpc() -> bytes:
utils.response_raise_errors(response, self.name)
return response.content

response_bytes = retry(_rpc)
try:
response_bytes = retry(_rpc)
except httpx.ReadTimeout:
msg = (
f"Timeout calling remote Chainlet `{self.name}` "
f"({self._service_descriptor.options.timeout_sec} seconds limit)."
)
logging.warning(msg)
raise TimeoutError(msg) from None # Prune error stack trace (TMI).

if output_model:
return self._response_to_pydantic(response_bytes, output_model)
return self._response_to_json(response_bytes)
Expand Down Expand Up @@ -357,7 +366,16 @@ async def _rpc() -> bytes:
await utils.async_response_raise_errors(response, self.name)
return await response.read()

response_bytes: bytes = await retry(_rpc)
try:
response_bytes: bytes = await retry(_rpc)
except asyncio.TimeoutError:
msg = (
f"Timeout calling remote Chainlet `{self.name}` "
f"({self._service_descriptor.options.timeout_sec} seconds limit)."
)
logging.warning(msg)
raise TimeoutError(msg) from None # Prune error stack trace (TMI).

if output_model:
return self._response_to_pydantic(response_bytes, output_model)
return self._response_to_json(response_bytes)
Expand All @@ -375,7 +393,15 @@ async def _rpc() -> AsyncIterator[bytes]:
await utils.async_response_raise_errors(response, self.name)
return response.content.iter_any()

return await retry(_rpc)
try:
return await retry(_rpc)
except asyncio.TimeoutError:
msg = (
f"Timeout calling remote Chainlet `{self.name}` "
f"({self._service_descriptor.options.timeout_sec} seconds limit)."
)
logging.warning(msg)
raise TimeoutError(msg) from None # Prune error stack trace (TMI).


StubT = TypeVar("StubT", bound=StubBase)
Expand Down