Skip to content
This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Commit

Permalink
Fixes ray integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Aug 28, 2022
1 parent 951a4ca commit 479da70
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
2 changes: 1 addition & 1 deletion graph_adapter_tests/h_ray/test_h_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@pytest.fixture(scope="module")
def init():
ray.init(local_mode=True) # need local mode, else it can't seem to find the h_ray module.
ray.init()
yield "initialized"
ray.shutdown()

Expand Down
37 changes: 22 additions & 15 deletions hamilton/experimental/h_ray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import inspect
import logging
import typing

Expand All @@ -10,6 +11,22 @@
logger = logging.getLogger(__name__)


def raify(fn):
"""Makes the function into something ray-friendly.
This is necessary due to https://github.com/ray-project/ray/issues/28146.
@param fn: Function to make ray-friendly
@return: The ray-friendly version
"""
if isinstance(fn, functools.partial):

def new_fn(*args, **kwargs):
return fn(*args, **kwargs)

return new_fn
return fn


class RayGraphAdapter(base.HamiltonGraphAdapter, base.ResultMixin):
"""Class representing what's required to make Hamilton run on Ray
Expand Down Expand Up @@ -60,11 +77,7 @@ def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) ->
:param kwargs: the arguments that should be passed to it.
:return: returns a ray object reference.
"""
if isinstance(node.callable, functools.partial):
return functools.partial(
ray.remote(node.callable.func).remote, *node.callable.args, **node.callable.keywords
)(**kwargs)
return ray.remote(node.callable).remote(**kwargs)
return ray.remote(raify(node.callable)).remote(**kwargs)

def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any:
"""Builds the result and brings it back to this running process.
Expand Down Expand Up @@ -139,13 +152,7 @@ def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type)
return node_type == input_type

def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any:
"""Function that is called as we walk the graph to determine how to execute a hamilton function.
:param node: the node from the graph.
:param kwargs: the arguments that should be passed to it.
:return: returns a ray object reference.
"""
return workflow.step(node.callable).step(**kwargs)
return ray.remote(raify(node.callable)).bind(**kwargs)

def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any:
"""Builds the result and brings it back to this running process.
Expand All @@ -157,8 +164,8 @@ def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any:
for k, v in outputs.items():
logger.debug(f"Got output {k}, with type [{type(v)}].")
# need to wrap our result builder in a remote call and then pass in what we want to build from.
remote_combine = workflow.step(self.result_builder.build_result).step(**outputs)
result = remote_combine.run(
workflow_id=self.workflow_id
remote_combine = ray.remote(self.result_builder.build_result).bind(**outputs)
result = workflow.run(
remote_combine, workflow_id=self.workflow_id
) # this materializes the object locally
return result

0 comments on commit 479da70

Please sign in to comment.