From 5d70c6cad81d846ee2d1b73c0400d02787ec9eff Mon Sep 17 00:00:00 2001
From: Hollow Man
Date: Fri, 28 Feb 2025 13:53:29 +0200
Subject: [PATCH] Fix test cases and add timeout tests in gloo_group_isolation

Signed-off-by: Hollow Man
---
Notes (text between "---" and the diff is not part of the commit message):
* Worker.init_gloo_group gains the missing `self`, takes world_size before
  rank (the positional order it passes to col.init_collective_group), and
  forwards a new gloo_timeout argument (default 30000).
* Worker.get_gloo_timeout returns g._gloo_context.getTimeout() for the
  named group; the test compares that value with the integer timeout
  given at init, so the helper is annotated `-> int`.
* test_two_groups_in_one_cluster now gives each group its own timeout
  (40000 / 60000) and asserts that each one is read back unchanged.

 .../test_gloo_group_isolation.py              | 25 ++++++++++++++++---
 1 file changed, 21 insertions(+), 4 deletions(-)

diff --git a/python/ray/util/collective/tests/single_node_cpu_tests/test_gloo_group_isolation.py b/python/ray/util/collective/tests/single_node_cpu_tests/test_gloo_group_isolation.py
index 43276fc0faaa2..fbe8ab3b7c974 100644
--- a/python/ray/util/collective/tests/single_node_cpu_tests/test_gloo_group_isolation.py
+++ b/python/ray/util/collective/tests/single_node_cpu_tests/test_gloo_group_isolation.py
@@ -1,4 +1,5 @@
 from python.ray.util.collective.types import Backend
+from python.ray.util.collective.collective_group.gloo_collective_group import GLOOGroup
 import ray
 import ray.util.collective as col
 import time
@@ -9,18 +10,34 @@ class Worker:
     def __init__(self):
         pass
 
-    def init_gloo_group(rank: int, world_size: int, group_name: str):
-        col.init_collective_group(world_size, rank, Backend.GLOO, group_name)
+    def init_gloo_group(
+        self, world_size: int, rank: int, group_name: str, gloo_timeout: int = 30000
+    ):
+        col.init_collective_group(
+            world_size, rank, Backend.GLOO, group_name, gloo_timeout
+        )
         return True
 
+    def get_gloo_timeout(self, group_name: str) -> int:
+        g = col.get_group_handle(group_name)
+        # Check if the group is initialized correctly
+        assert isinstance(g, GLOOGroup)
+        return g._gloo_context.getTimeout()
+
 
 def test_two_groups_in_one_cluster(ray_start_regular_shared):
+    name1 = "name_1"
+    name2 = "name_2"
+    time1 = 40000
+    time2 = 60000
     w1 = Worker.remote()
-    ret1 = w1.init_gloo_group.remote(1, 0, "name_1")
+    ret1 = w1.init_gloo_group.remote(1, 0, name1, time1)
     w2 = Worker.remote()
-    ret2 = w2.init_gloo_group.remote(1, 0, "name_2")
+    ret2 = w2.init_gloo_group.remote(1, 0, name2, time2)
     assert ray.get(ret1)
     assert ray.get(ret2)
+    assert ray.get(w1.get_gloo_timeout.remote(name1)) == time1
+    assert ray.get(w2.get_gloo_timeout.remote(name2)) == time2
 
 
 def test_failure_when_initializing(shutdown_only):