[core][collective] Support customizing gloo timeout #50223

Merged · 3 commits · Mar 1, 2025
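This change threads a new gloo_timeout argument (in milliseconds, default 30000) from the public init_collective_group and create_collective_group APIs down to the pygloo context used by GLOOGroup. A minimal sketch of the imperative path, assuming a local Ray runtime with at least two CPUs; the Worker actor and the 60000 ms value are illustrative, not part of the PR:

import ray
import ray.util.collective as col
from ray.util.collective.types import Backend

ray.init()

@ray.remote
class Worker:
    def setup(self, world_size: int, rank: int):
        # Allow the gloo rendezvous up to 60000 ms instead of the default 30000 ms.
        col.init_collective_group(
            world_size, rank, Backend.GLOO, "default", gloo_timeout=60000
        )
        return True

workers = [Worker.remote() for _ in range(2)]
ray.get([w.setup.remote(2, i) for i, w in enumerate(workers)])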
2 changes: 2 additions & 0 deletions python/ray/util/collective/__init__.py
@@ -22,6 +22,7 @@
send_multigpu,
recv,
recv_multigpu,
get_group_handle,
)

__all__ = [
@@ -48,4 +49,5 @@
"send_multigpu",
"recv",
"recv_multigpu",
"get_group_handle",
]
70 changes: 46 additions & 24 deletions python/ray/util/collective/collective.py
@@ -49,7 +49,9 @@ def __init__(self):
self._name_group_map = {}
self._group_name_map = {}

def create_collective_group(self, backend, world_size, rank, group_name):
def create_collective_group(
self, backend, world_size, rank, group_name, gloo_timeout
):
"""The entry to create new collective groups in the manager.

Put the registration and the group information into the manager
@@ -66,6 +68,7 @@ def create_collective_group(self, backend, world_size, rank, group_name):
group_name,
store_type="ray_internal_kv",
device_type="tcp",
gloo_timeout=gloo_timeout,
)
self._name_group_map[group_name] = g
self._group_name_map[g] = group_name
@@ -118,7 +121,11 @@ def is_group_initialized(group_name):


def init_collective_group(
world_size: int, rank: int, backend=types.Backend.NCCL, group_name: str = "default"
world_size: int,
rank: int,
backend=types.Backend.NCCL,
group_name: str = "default",
gloo_timeout: int = 30000,
):
"""Initialize a collective group inside an actor process.

@@ -145,7 +152,9 @@ def init_collective_group(
assert world_size > 0
assert rank >= 0
assert rank < world_size
_group_mgr.create_collective_group(backend, world_size, rank, group_name)
_group_mgr.create_collective_group(
backend, world_size, rank, group_name, gloo_timeout
)


def create_collective_group(
@@ -154,6 +163,7 @@ def create_collective_group(
ranks: List[int],
backend=types.Backend.NCCL,
group_name: str = "default",
gloo_timeout: int = 30000,
):
"""Declare a list of actors as a collective group.

@@ -209,7 +219,7 @@ def create_collective_group(
actors_id = [a._ray_actor_id for a in actors]
# TODO (Dacheng): how do we recycle this name actor?
info = Info.options(name=name, lifetime="detached").remote()
ray.get([info.set_info.remote(actors_id, world_size, ranks, backend)])
ray.get([info.set_info.remote(actors_id, world_size, ranks, backend, gloo_timeout)])


# TODO (we need a declarative destroy() API here.)
@@ -267,7 +277,7 @@ def allreduce(tensor, group_name: str = "default", op=types.ReduceOp.SUM):
None
"""
_check_single_tensor_input(tensor)
g = _check_and_get_group(group_name)
g = get_group_handle(group_name)
opts = types.AllReduceOptions
opts.reduceOp = op
g.allreduce([tensor], opts)
@@ -289,7 +299,7 @@ def allreduce_multigpu(
if not types.cupy_available():
raise RuntimeError("Multigpu calls requires NCCL and Cupy.")
_check_tensor_list_input(tensor_list)
g = _check_and_get_group(group_name)
g = get_group_handle(group_name)
opts = types.AllReduceOptions
opts.reduceOp = op
g.allreduce(tensor_list, opts)
@@ -304,7 +314,7 @@ def barrier(group_name: str = "default"):
Returns:
None
"""
g = _check_and_get_group(group_name)
g = get_group_handle(group_name)
g.barrier()


@@ -323,7 +333,7 @@ def reduce(
None
"""
_check_single_tensor_input(tensor)
g = _check_and_get_group(group_name)
g = get_group_handle(group_name)

# check dst rank
_check_rank_valid(g, dst_rank)
@@ -358,7 +368,7 @@ def reduce_multigpu(
if not types.cupy_available():
raise RuntimeError("Multigpu calls requires NCCL and Cupy.")
_check_tensor_list_input(tensor_list)
g = _check_and_get_group(group_name)
g = get_group_handle(group_name)

# check dst rank
_check_rank_valid(g, dst_rank)
@@ -382,7 +392,7 @@ def broadcast(tensor, src_rank: int = 0, group_name: str = "default"):
None
"""
_check_single_tensor_input(tensor)
g = _check_and_get_group(group_name)
g = get_group_handle(group_name)

# check src rank
_check_rank_valid(g, src_rank)
@@ -409,7 +419,7 @@ def broadcast_multigpu(
if not types.cupy_available():
raise RuntimeError("Multigpu calls requires NCCL and Cupy.")
_check_tensor_list_input(tensor_list)
g = _check_and_get_group(group_name)
g = get_group_handle(group_name)

# check src rank
_check_rank_valid(g, src_rank)
@@ -433,7 +443,7 @@ def allgather(tensor_list: list, tensor, group_name: str = "default"):
"""
_check_single_tensor_input(tensor)
_check_tensor_list_input(tensor_list)
g = _check_and_get_group(group_name)
g = get_group_handle(group_name)
if len(tensor_list) != g.world_size:
# Typically CLL lib requires len(tensor_list) >= world_size;
# Here we make it more strict: len(tensor_list) == world_size.
@@ -464,7 +474,7 @@ def allgather_multigpu(
raise RuntimeError("Multigpu calls requires NCCL and Cupy.")
_check_tensor_lists_input(output_tensor_lists)
_check_tensor_list_input(input_tensor_list)
g = _check_and_get_group(group_name)
g = get_group_handle(group_name)
opts = types.AllGatherOptions()
g.allgather(output_tensor_lists, input_tensor_list, opts)

@@ -488,7 +498,7 @@ def reducescatter(
"""
_check_single_tensor_input(tensor)
_check_tensor_list_input(tensor_list)
g = _check_and_get_group(group_name)
g = get_group_handle(group_name)
if len(tensor_list) != g.world_size:
raise RuntimeError(
"The length of the tensor list operands to reducescatter "
@@ -522,7 +532,7 @@ def reducescatter_multigpu(
raise RuntimeError("Multigpu calls requires NCCL and Cupy.")
_check_tensor_lists_input(input_tensor_lists)
_check_tensor_list_input(output_tensor_list)
g = _check_and_get_group(group_name)
g = get_group_handle(group_name)
opts = types.ReduceScatterOptions()
opts.reduceOp = op
g.reducescatter(output_tensor_list, input_tensor_lists, opts)
@@ -540,7 +550,7 @@ def send(tensor, dst_rank: int, group_name: str = "default"):
None
"""
_check_single_tensor_input(tensor)
g = _check_and_get_group(group_name)
g = get_group_handle(group_name)
_check_rank_valid(g, dst_rank)
if dst_rank == g.rank:
raise RuntimeError("The destination rank '{}' is self.".format(dst_rank))
@@ -575,7 +585,7 @@ def send_multigpu(
if not types.cupy_available():
raise RuntimeError("send_multigpu call requires NCCL.")
_check_single_tensor_input(tensor)
g = _check_and_get_group(group_name)
g = get_group_handle(group_name)
_check_rank_valid(g, dst_rank)
if dst_rank == g.rank:
raise RuntimeError(
@@ -603,7 +613,7 @@ def recv(tensor, src_rank: int, group_name: str = "default"):
None
"""
_check_single_tensor_input(tensor)
g = _check_and_get_group(group_name)
g = get_group_handle(group_name)
_check_rank_valid(g, src_rank)
if src_rank == g.rank:
raise RuntimeError("The destination rank '{}' is self.".format(src_rank))
@@ -636,7 +646,7 @@ def recv_multigpu(
if not types.cupy_available():
raise RuntimeError("recv_multigpu call requires NCCL.")
_check_single_tensor_input(tensor)
g = _check_and_get_group(group_name)
g = get_group_handle(group_name)
_check_rank_valid(g, src_rank)
if src_rank == g.rank:
raise RuntimeError(
@@ -668,8 +678,15 @@ def synchronize(gpu_id: int):
cp.cuda.Device(gpu_id).synchronize()


def _check_and_get_group(group_name):
"""Check the existence and return the group handle."""
def get_group_handle(group_name: str = "default"):
"""Check if the group is initialized and return the group handle.

Args:
group_name: the name of the collective group.

Returns:
The collective group handle.
"""
_check_inside_actor()
global _group_mgr
if not is_group_initialized(group_name):
@@ -679,11 +696,15 @@ def _check_and_get_group(group_name):
# get and create the group.
name = "info_" + group_name
mgr = ray.get_actor(name=name)
ids, world_size, rank, backend = ray.get(mgr.get_info.remote())
ids, world_size, rank, backend, gloo_timeout = ray.get(
mgr.get_info.remote()
)
worker = ray._private.worker.global_worker
id_ = worker.core_worker.get_actor_id()
r = rank[ids.index(id_)]
_group_mgr.create_collective_group(backend, world_size, r, group_name)
_group_mgr.create_collective_group(
backend, world_size, r, group_name, gloo_timeout
)
except ValueError as exc:
# check if this group is initialized using options()
if (
@@ -693,8 +714,9 @@
rank = int(os.environ["collective_rank"])
world_size = int(os.environ["collective_world_size"])
backend = os.environ["collective_backend"]
gloo_timeout = os.getenv("collective_gloo_timeout", 30000)
_group_mgr.create_collective_group(
backend, world_size, rank, group_name
backend, world_size, rank, group_name, gloo_timeout
)
else:
raise RuntimeError(
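get_group_handle is the renamed, now-public _check_and_get_group: it lazily builds the group inside the calling actor (pulling world size, rank, backend, and now gloo_timeout from the named Info actor or from environment variables) and returns the group object. A small sketch of reading back the configured timeout through the handle, mirroring the new test further down; the Trainer actor and the 45000 ms value are illustrative, and _gloo_context/getTimeout are the same internals the test relies on:

import ray
import ray.util.collective as col
from ray.util.collective.types import Backend

@ray.remote
class Trainer:
    def init(self):
        # One-rank gloo group with a 45000 ms timeout.
        col.init_collective_group(1, 0, Backend.GLOO, "grp", gloo_timeout=45000)
        return True

    def timeout_ms(self):
        g = col.get_group_handle("grp")      # returns the GLOOGroup instance
        return g._gloo_context.getTimeout()  # value applied via setTimeout in GLOOGroup

t = Trainer.remote()
assert ray.get(t.init.remote())
assert ray.get(t.timeout_ms.remote()) == 45000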
python/ray/util/collective/collective_group/gloo_collective_group.py
@@ -189,6 +189,7 @@ def __init__(
group_name,
store_type="ray_internal_kv",
device_type="tcp",
gloo_timeout=30000,
):
"""Init an GLOO collective group.

@@ -200,9 +201,12 @@
"file", "hash".
device_type: The device type to transport.
Optional: "tcp", "uv".
gloo_timeout: The timeout for GLOO rendezvous in ms.
Optional: int, default: 30000.
"""
super(GLOOGroup, self).__init__(world_size, rank, group_name)
self._gloo_context = gloo_util.create_gloo_context(self.rank, self.world_size)
self._gloo_context.setTimeout(gloo_timeout)
self._rendezvous = Rendezvous(
self.group_name, self._gloo_context, store_type, device_type
)
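The timeout is applied with setTimeout on the freshly created pygloo context, before the Rendezvous object is built, and the docstring pins the unit to milliseconds. Purely as an illustration of choosing a value for the public APIs, it could be derived from a timedelta:

from datetime import timedelta

# 2 minutes expressed in the milliseconds that gloo_timeout expects.
gloo_timeout = int(timedelta(minutes=2).total_seconds() * 1000)  # 120000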
@@ -1,4 +1,5 @@
from python.ray.util.collective.types import Backend
from python.ray.util.collective.collective_group.gloo_collective_group import GLOOGroup
import ray
import ray.util.collective as col
import time
@@ -9,18 +10,34 @@ class Worker:
def __init__(self):
pass

def init_gloo_group(rank: int, world_size: int, group_name: str):
col.init_collective_group(world_size, rank, Backend.GLOO, group_name)
def init_gloo_group(
self, world_size: int, rank: int, group_name: str, gloo_timeout: int = 30000
):
col.init_collective_group(
world_size, rank, Backend.GLOO, group_name, gloo_timeout
)
return True

def get_gloo_timeout(self, group_name: str) -> int:
g = col.get_group_handle(group_name)
# Check if the group is initialized correctly
assert isinstance(g, GLOOGroup)
return g._gloo_context.getTimeout()


def test_two_groups_in_one_cluster(ray_start_regular_shared):
name1 = "name_1"
name2 = "name_2"
time1 = 40000
time2 = 60000
w1 = Worker.remote()
ret1 = w1.init_gloo_group.remote(1, 0, "name_1")
ret1 = w1.init_gloo_group.remote(1, 0, name1, time1)
w2 = Worker.remote()
ret2 = w2.init_gloo_group.remote(1, 0, "name_2")
ret2 = w2.init_gloo_group.remote(1, 0, name2, time2)
assert ray.get(ret1)
assert ray.get(ret2)
assert ray.get(w1.get_gloo_timeout.remote(name1)) == time1
assert ray.get(w2.get_gloo_timeout.remote(name2)) == time2


def test_failure_when_initializing(shutdown_only):
6 changes: 4 additions & 2 deletions python/ray/util/collective/util.py
@@ -55,14 +55,16 @@ def __init__(self):
self.world_size = -1
self.rank = -1
self.backend = None
self.gloo_timeout = 30000

def set_info(self, ids, world_size, rank, backend):
def set_info(self, ids, world_size, rank, backend, gloo_timeout):
"""Store collective information."""
self.ids = ids
self.world_size = world_size
self.rank = rank
self.backend = backend
self.gloo_timeout = gloo_timeout

def get_info(self):
"""Get previously stored collective information."""
return self.ids, self.world_size, self.rank, self.backend
return self.ids, self.world_size, self.rank, self.backend, self.gloo_timeout
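The detached Info actor now stores gloo_timeout alongside the ids, world size, ranks, and backend, which is how the declarative create_collective_group path propagates the timeout: each actor's first collective call reaches get_group_handle, reads the stored tuple, and creates its local group with the same timeout. A sketch of that path, assuming a local Ray runtime with at least two CPUs; the Summer actor and numpy buffer are illustrative:

import numpy as np
import ray
import ray.util.collective as col
from ray.util.collective.types import Backend

@ray.remote
class Summer:
    def __init__(self):
        self.buf = np.ones(4, dtype=np.float32)

    def reduce(self):
        # The gloo group is created lazily here, using the 120000 ms timeout
        # that create_collective_group stored in the Info actor.
        col.allreduce(self.buf, group_name="sum_group")
        return self.buf

workers = [Summer.remote() for _ in range(2)]
col.create_collective_group(
    workers,
    world_size=2,
    ranks=[0, 1],
    backend=Backend.GLOO,
    group_name="sum_group",
    gloo_timeout=120000,
)
print(ray.get([w.reduce.remote() for w in workers]))  # two arrays of 2.0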