[Core][1/N] Support send/recv in PyNCCL Groups (vllm-project#4988)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Authored by andoorve on May 23, 2024 (commit 5eda2ea, parent 2ba80be)

Showing 5 changed files with 170 additions and 17 deletions.
tests/distributed/test_pynccl.py: 69 additions, 6 deletions
@@ -3,6 +3,7 @@

import pytest
import torch
import torch.distributed

from vllm.distributed.communication_op import ( # noqa
graph_capture, tensor_model_parallel_all_reduce)
@@ -68,7 +69,7 @@ def test_pynccl():


@worker_fn_wrapper
-def multiple_tp_worker_fn():
+def multiple_allreduce_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
groups = [
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
@@ -92,14 +93,14 @@ def multiple_tp_worker_fn():

@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
-def test_pynccl_multiple_tp():
+def test_pynccl_multiple_allreduce():
# this tests pynccl for multiple tp groups, in a standalone way
# i.e. call `pynccl_comm.all_reduce` directly
-distributed_run(multiple_tp_worker_fn, 4)
+distributed_run(multiple_allreduce_worker_fn, 4)


@worker_fn_wrapper
-def multiple_tp_with_vllm_worker_fn():
+def multiple_allreduce_with_vllm_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
ensure_model_parallel_initialized(2, 2)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
@@ -118,10 +119,10 @@ def multiple_tp_with_vllm_worker_fn():

@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
-def test_pynccl_multiple_tp_with_vllm():
+def test_pynccl_multiple_allreduce_with_vllm():
# this tests pynccl for multiple tp groups, together with vllm
# i.e. call `tensor_model_parallel_all_reduce`
-distributed_run(multiple_tp_with_vllm_worker_fn, 4)
+distributed_run(multiple_allreduce_with_vllm_worker_fn, 4)


@worker_fn_wrapper
@@ -151,6 +152,68 @@ def test_pynccl_with_cudagraph():
distributed_run(worker_fn_with_cudagraph, 2)


@worker_fn_wrapper
def send_recv_worker_fn():
pynccl_comm = PyNcclCommunicator()
if pynccl_comm.rank == 0:
tensor = torch.ones(16, 1024, 1024,
dtype=torch.float32).cuda(pynccl_comm.rank)
else:
tensor = torch.empty(16, 1024, 1024,
dtype=torch.float32).cuda(pynccl_comm.rank)
with pynccl_comm.change_state(enable=True):
if pynccl_comm.rank == 0:
pynccl_comm.send(tensor)
else:
pynccl_comm.recv(tensor)
result = tensor.mean().cpu().item()
assert result == 1


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_send_recv():
distributed_run(send_recv_worker_fn, 2)


@worker_fn_wrapper
def multiple_send_recv_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
groups = [
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
torch.distributed.new_group(ranks=[1, 3], backend="gloo")
]
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
pynccl_comm = PyNcclCommunicator(group=group, device=device)
if torch.distributed.get_rank() == 0:
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
elif torch.distributed.get_rank() == 1:
tensor = 2 * torch.ones(
16, 1024, 1024, dtype=torch.float32, device=device)
else:
tensor = torch.empty(16,
1024,
1024,
dtype=torch.float32,
device=device)
with pynccl_comm.change_state(enable=True):
if torch.distributed.get_rank() in [0, 1]:
pynccl_comm.send(tensor)
else:
pynccl_comm.recv(tensor)
result = tensor.mean().cpu().item()
if torch.distributed.get_rank() in [0, 2]:
assert result == 1
else:
assert result == 2


@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_send_recv():
distributed_run(multiple_send_recv_worker_fn, 4)


def test_ncclGetUniqueId():
lib = NCCLLibrary()
unique_id = lib.ncclGetUniqueId()
vllm/distributed/communication_op.py: 12 additions, 6 deletions
@@ -6,7 +6,7 @@
import torch
from torch.distributed import ProcessGroup

-from .parallel_state import (get_cpu_world_group,
+from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -54,13 +54,19 @@ def graph_capture():
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or pynccl if it is disabled or not supported.
-pynccl_comm = get_tp_pynccl_communicator()
-if pynccl_comm is None:
-maybe_pynccl_context = nullcontext()
+tp_pynccl_comm = get_tp_pynccl_communicator()
+pp_pynccl_comm = get_pp_pynccl_communicator()
+if not tp_pynccl_comm:
+maybe_tp_pynccl_context = nullcontext()
else:
-maybe_pynccl_context = pynccl_comm.change_state(
+maybe_tp_pynccl_context = tp_pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream())
-with maybe_pynccl_context:
+if not pp_pynccl_comm:
+maybe_pp_pynccl_context = nullcontext()
+else:
+maybe_pp_pynccl_context = pp_pynccl_comm.change_state(
+enable=True, stream=torch.cuda.current_stream())
+with maybe_tp_pynccl_context, maybe_pp_pynccl_context:
yield graph_capture_context


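
A hedged usage sketch (not part of this diff): with the change above, graph_capture() now enables the pynccl communicators for both the tensor-parallel group and the new pipeline-parallel group, so collectives and send/recv issued while capturing a CUDA graph go through pynccl. The call pattern below mirrors how the tests above consume it; it assumes torch.distributed and vLLM's model-parallel groups are already initialized.

import torch
from vllm.distributed.communication_op import (
    graph_capture, tensor_model_parallel_all_reduce)

with graph_capture():
    # Inside this block both communicators run with enable=True on the
    # current CUDA stream, so the all-reduce below is graph-capturable.
    data = torch.ones(16, 1024, 1024, dtype=torch.float32, device="cuda")
    data = tensor_model_parallel_all_reduce(data)
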
vllm/distributed/device_communicators/pynccl.py: 34 additions, 0 deletions
@@ -126,6 +126,40 @@ def all_reduce(self,
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))

def send(self,
tensor: torch.Tensor,
dst: Optional[int] = None,
stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
if dst is None:
dst = (self.rank + 1) % self.world_size
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
self.comm, cudaStream_t(stream.cuda_stream))

def recv(self,
tensor: torch.Tensor,
src: Optional[int] = None,
stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
if src is None:
src = (self.rank - 1) % self.world_size
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))

@contextmanager
def change_state(self,
enable: Optional[bool] = None,
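
A minimal two-rank sketch of the new point-to-point API, mirroring send_recv_worker_fn in the tests above (it assumes a 2-process torch.distributed world is already initialized, e.g. via distributed_run):

import torch
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

comm = PyNcclCommunicator()
if comm.rank == 0:
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
else:
    tensor = torch.empty(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
with comm.change_state(enable=True):
    if comm.rank == 0:
        comm.send(tensor)  # dst defaults to (rank + 1) % world_size
    else:
        comm.recv(tensor)  # src defaults to (rank - 1) % world_size
assert tensor.mean().cpu().item() == 1
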
vllm/distributed/device_communicators/pynccl_wrapper.py: 26 additions, 0 deletions
@@ -151,6 +151,22 @@ class NCCLLibrary:
ncclRedOp_t, ncclComm_t, cudaStream_t
]),

# ncclResult_t ncclSend(
# const void* sendbuff, size_t count, ncclDataType_t datatype,
# int dest, ncclComm_t comm, cudaStream_t stream);
Function("ncclSend", ncclResult_t, [
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
ncclComm_t, cudaStream_t
]),

# ncclResult_t ncclRecv(
# void* recvbuff, size_t count, ncclDataType_t datatype,
# int src, ncclComm_t comm, cudaStream_t stream);
Function("ncclRecv", ncclResult_t, [
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
ncclComm_t, cudaStream_t
]),

# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
@@ -248,6 +264,16 @@ def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
datatype, op, comm,
stream))

def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
dest, comm, stream))

def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
comm, stream))

def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))

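
The ncclSend/ncclRecv entries above register the same signatures as the NCCL C API and are normally reached through PyNcclCommunicator.send/recv rather than called directly. As a hedged illustration only, a thin helper around the raw wrapper could look like the sketch below; the helper name is hypothetical, and the caller must supply an initialized ncclComm_t and CUDA stream, exactly as pynccl.py does.

import torch
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum)

def raw_nccl_send(lib: NCCLLibrary, tensor: torch.Tensor, peer: int,
                  comm: ncclComm_t, stream: torch.cuda.Stream) -> None:
    # Mirrors the body of PyNcclCommunicator.send in pynccl.py above.
    lib.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                 ncclDataTypeEnum.from_torch(tensor.dtype), peer, comm,
                 cudaStream_t(stream.cuda_stream))
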
vllm/distributed/parallel_state.py: 29 additions, 5 deletions
@@ -22,6 +22,8 @@
_TP_CA_COMMUNICATOR = None
# Pipeline model parallel group that the current rank belongs to.
_PP_DEVICE_GROUP: Optional[ProcessGroup] = None
_PP_CPU_GROUP: Optional[ProcessGroup] = None
_PP_PYNCCL_COMMUNICATOR = None

# when people blindly call `torch.distributed.all_reduce` etc,
# it will use this group. It is initialized with the `backend`
@@ -55,6 +57,11 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable


def get_pp_pynccl_communicator():
global _PP_PYNCCL_COMMUNICATOR
return _PP_PYNCCL_COMMUNICATOR


def get_tp_pynccl_communicator():
global _TP_PYNCCL_COMMUNICATOR
return _TP_PYNCCL_COMMUNICATOR
@@ -180,10 +187,11 @@ def initialize_model_parallel(
_TP_CPU_GROUP = cpu_group

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
-_TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
-group=_TP_CPU_GROUP,
-device=_LOCAL_RANK,
-)
+if tensor_model_parallel_size > 1:
+_TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
+group=_TP_CPU_GROUP,
+device=_LOCAL_RANK,
+)

# Initialize a custom fast all-reduce implementation.
if _ENABLE_CUSTOM_ALL_REDUCE:
@@ -195,17 +203,26 @@
)

# Build the pipeline model-parallel groups.
-global _PP_DEVICE_GROUP
+global _PP_DEVICE_GROUP, _PP_CPU_GROUP
+global _PP_PYNCCL_COMMUNICATOR
global _PP_GLOBAL_RANKS
assert _PP_DEVICE_GROUP is None, (
"pipeline model parallel group is already initialized")
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group = torch.distributed.new_group(ranks, backend=backend)
+cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if rank in ranks:
_PP_DEVICE_GROUP = group
+_PP_CPU_GROUP = cpu_group
_PP_GLOBAL_RANKS = ranks

+if pipeline_model_parallel_size > 1:
+_PP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
+group=_PP_CPU_GROUP,
+device=_LOCAL_RANK,
+)


def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
@@ -267,6 +284,13 @@ def get_pipeline_model_parallel_group():
return _PP_DEVICE_GROUP


def get_pipeline_model_parallel_cpu_group():
"""Get the pipeline model parallel cpu group the caller rank belongs to."""
assert _PP_CPU_GROUP is not None, (
"pipeline model parallel cpu group is not initialized")
return _PP_CPU_GROUP


def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return torch.distributed.get_world_size(
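
This commit only wires up the pipeline-parallel PyNCCL communicator; the actual pipeline-stage communication is expected in later parts of the series (hence [1/N]). As a purely hypothetical sketch of the intended direction, a stage could forward activations to the next pipeline rank roughly like this (get_pp_pynccl_communicator comes from this diff; the helper name and hidden_states argument are illustrative only):

import torch
from vllm.distributed.parallel_state import get_pp_pynccl_communicator

def send_to_next_pp_rank(hidden_states: torch.Tensor) -> None:
    pp_comm = get_pp_pynccl_communicator()
    if pp_comm is None:  # pipeline_model_parallel_size == 1
        return
    with pp_comm.change_state(enable=True):
        pp_comm.send(hidden_states)  # dst defaults to the next rank in the PP group
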
