diff --git a/graph_adapter_tests/h_ray/test_h_ray.py b/graph_adapter_tests/h_ray/test_h_ray.py index 0bb790bc..574170a9 100644 --- a/graph_adapter_tests/h_ray/test_h_ray.py +++ b/graph_adapter_tests/h_ray/test_h_ray.py @@ -14,7 +14,7 @@ @pytest.fixture(scope="module") def init(): - ray.init(local_mode=True) # need local mode, else it can't seem to find the h_ray module. + ray.init() yield "initialized" ray.shutdown() diff --git a/hamilton/experimental/h_ray.py b/hamilton/experimental/h_ray.py index b36ee5a4..ff98feab 100644 --- a/hamilton/experimental/h_ray.py +++ b/hamilton/experimental/h_ray.py @@ -1,4 +1,5 @@ import functools +import inspect import logging import typing @@ -10,6 +11,22 @@ logger = logging.getLogger(__name__) +def raify(fn): + """Makes the function into something ray-friendly. + This is necessary due to https://github.com/ray-project/ray/issues/28146. + + @param fn: Function to make ray-friendly + @return: The ray-friendly version + """ + if isinstance(fn, functools.partial): + + def new_fn(*args, **kwargs): + return fn(*args, **kwargs) + + return new_fn + return fn + + class RayGraphAdapter(base.HamiltonGraphAdapter, base.ResultMixin): """Class representing what's required to make Hamilton run on Ray @@ -60,11 +77,7 @@ def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> :param kwargs: the arguments that should be passed to it. :return: returns a ray object reference. """ - if isinstance(node.callable, functools.partial): - return functools.partial( - ray.remote(node.callable.func).remote, *node.callable.args, **node.callable.keywords - )(**kwargs) - return ray.remote(node.callable).remote(**kwargs) + return ray.remote(raify(node.callable)).remote(**kwargs) def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any: """Builds the result and brings it back to this running process. @@ -139,13 +152,7 @@ def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type) return node_type == input_type def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any: - """Function that is called as we walk the graph to determine how to execute a hamilton function. - - :param node: the node from the graph. - :param kwargs: the arguments that should be passed to it. - :return: returns a ray object reference. - """ - return workflow.step(node.callable).step(**kwargs) + return ray.remote(raify(node.callable)).bind(**kwargs) def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any: """Builds the result and brings it back to this running process. @@ -157,8 +164,8 @@ def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any: for k, v in outputs.items(): logger.debug(f"Got output {k}, with type [{type(v)}].") # need to wrap our result builder in a remote call and then pass in what we want to build from. - remote_combine = workflow.step(self.result_builder.build_result).step(**outputs) - result = remote_combine.run( - workflow_id=self.workflow_id + remote_combine = ray.remote(self.result_builder.build_result).bind(**outputs) + result = workflow.run( + remote_combine, workflow_id=self.workflow_id ) # this materializes the object locally return result