diff --git a/dashboard/state_aggregator.py b/dashboard/state_aggregator.py index 3c640b30249a..99b08ca0e7e6 100644 --- a/dashboard/state_aggregator.py +++ b/dashboard/state_aggregator.py @@ -386,6 +386,7 @@ def _to_task_state(task_attempt: dict) -> dict: "language", "required_resources", "runtime_env_info", + "parent_task_id", ], ), (task_attempt, ["task_id", "attempt_number", "job_id"]), diff --git a/python/ray/experimental/state/common.py b/python/ray/experimental/state/common.py index 415fcdeef86b..7ded472f40c9 100644 --- a/python/ray/experimental/state/common.py +++ b/python/ray/experimental/state/common.py @@ -512,6 +512,8 @@ class TaskState(StateSchema): required_resources: dict = state_column(detail=True, filterable=False) #: The runtime environment information for the task. runtime_env_info: str = state_column(detail=True, filterable=False) + #: The parent task id. + parent_task_id: str = state_column(filterable=True) @dataclass(init=True) diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py index bf56a6326bd3..f8ef7ffb882d 100644 --- a/python/ray/tests/test_state_api.py +++ b/python/ray/tests/test_state_api.py @@ -2004,6 +2004,39 @@ def verify(): print(list_tasks()) +def test_parent_task_id(shutdown_only): + """Test parent task id set up properly""" + ray.init(num_cpus=2) + + @ray.remote + def child(): + pass + + @ray.remote + def parent(): + ray.get(child.remote()) + + ray.get(parent.remote()) + + def verify(): + tasks = list_tasks() + assert len(tasks) == 2, "Expect 2 tasks to finished" + parent_task_id = None + child_parent_task_id = None + for task in tasks: + if task["func_or_class_name"] == "parent": + parent_task_id = task["task_id"] + elif task["func_or_class_name"] == "child": + child_parent_task_id = task["parent_task_id"] + + assert ( + parent_task_id == child_parent_task_id + ), "Child should have the parent task id" + return True + + wait_for_condition(verify) + + def test_list_get_task_multiple_attempt_all_failed(shutdown_only): ray.init(num_cpus=2) job_id = ray.get_runtime_context().get_job_id()