Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Wait for upstream object refs before starting task #81

Merged
merged 4 commits into from
Apr 27, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions prefect_ray/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ async def submit(
"The task runner must be started before submitting work."
)

call_kwargs = self._exchange_prefect_for_ray_futures(call.keywords)
call_kwargs, upstream_ray_obj_refs = self._exchange_prefect_for_ray_futures(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may be able to just extract upstream ray object refs in this step then do the exchange with the resolved objects in _run_prefect_task — seems feasible to do as a follow-up though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

true. however, we need to exchange the prefect futures with something else as they are not serializable and cannot be passed to a ray remote function. I think we would need in any case some sort of placeholder object which is then exchanged by the resolved object in _run_prefect_task. but a ray object ref is already a good candidate for such a placeholder and saves us from yet another indirection.

call.keywords
)

remote_options = RemoteOptionsContext.get().current_remote_options
# Ray does not support the submission of async functions and we must create a
Expand All @@ -154,18 +156,22 @@ async def submit(
ray_decorator = ray.remote(**remote_options)
else:
ray_decorator = ray.remote

self._ray_refs[key] = ray_decorator(self._run_prefect_task).remote(
sync_compatible(call.func), **call_kwargs
sync_compatible(call.func), *upstream_ray_obj_refs, **call_kwargs
)

def _exchange_prefect_for_ray_futures(self, kwargs_prefect_futures):
"""Exchanges Prefect futures for Ray futures."""

upstream_ray_obj_refs = []

def exchange_prefect_for_ray_future(expr):
"""Exchanges Prefect future for Ray future."""
if isinstance(expr, PrefectFuture):
ray_future = self._ray_refs.get(expr.key)
if ray_future is not None:
upstream_ray_obj_refs.append(ray_future)
return ray_future
return expr

Expand All @@ -175,11 +181,17 @@ def exchange_prefect_for_ray_future(expr):
return_data=True,
)

return kwargs_ray_futures
return kwargs_ray_futures, upstream_ray_obj_refs

@staticmethod
def _run_prefect_task(func, *args, **kwargs):
"""Resolves Ray futures before calling the actual Prefect task function."""
def _run_prefect_task(func, *upstream_ray_obj_refs, **kwargs):
"""Resolves Ray futures before calling the actual Prefect task function.

Passing upstream_ray_obj_refs directly as args enables Ray to wait for
upstream tasks before running this remote function.
This variable is otherwise unused as the ray object refs are also
contained in kwargs.
"""

def resolve_ray_future(expr):
"""Resolves Ray future."""
Expand All @@ -189,7 +201,7 @@ def resolve_ray_future(expr):

kwargs = visit_collection(kwargs, visit_fn=resolve_ray_future, return_data=True)

return func(*args, **kwargs)
return func(**kwargs)

async def wait(self, key: UUID, timeout: float = None) -> Optional[State]:
ref = self._get_ray_ref(key)
Expand Down