[Core][Distributed] refactor pynccl to hold multiple communicators #4591

Merged · 44 commits · May 10, 2024

Changes from 30 commits

Commits (44)
9c6130a  add cache for loading the same library multiple times (youkaichao, May 3, 2024)
1493243  refactor code (youkaichao, May 3, 2024)
cadcd02  fix import (youkaichao, May 3, 2024)
7918798  remove pynccl_utils.init_process_group (youkaichao, May 3, 2024)
5924038  remove pynccl_utils.is_initialized (youkaichao, May 3, 2024)
813b047  remove pynccl_utils.destroy_process_group (youkaichao, May 3, 2024)
b244e6c  remove pynccl_utils.get_world_size (youkaichao, May 3, 2024)
7e15c98  remove pynccl_utils.get_nccl_backend (youkaichao, May 3, 2024)
e610f64  remove is_pynccl_enabled_for_all_reduce (youkaichao, May 3, 2024)
8480995  remove _ENABLE_PYNCCL_FOR_ALL_REDUCE (youkaichao, May 3, 2024)
5ed6f07  remove set_pynccl_stream (youkaichao, May 3, 2024)
8134287  remove pynccl utils (youkaichao, May 3, 2024)
e65e9ef  fix state (youkaichao, May 3, 2024)
c8b6fc0  fix test (youkaichao, May 3, 2024)
c7a2f0c  fix import (youkaichao, May 3, 2024)
75a8d11  move warmup into pynccl (youkaichao, May 4, 2024)
59c064e  add device (youkaichao, May 4, 2024)
16aeef1  fix device for allreduce warmup (youkaichao, May 4, 2024)
4710fc3  improve ways of discovering default local rank (youkaichao, May 4, 2024)
c8542ec  make sure warmup happens in stream (youkaichao, May 4, 2024)
b2d2661  add disable (youkaichao, May 4, 2024)
67d1d9a  do not init when world size is 1 (youkaichao, May 4, 2024)
c86199c  fix initial state of pynccl allreduce (youkaichao, May 4, 2024)
0030a31  add comments (youkaichao, May 4, 2024)
49f6d91  add context manager (youkaichao, May 4, 2024)
38b148b  refactor logic of available (youkaichao, May 4, 2024)
d241480  non-intrusive code (youkaichao, May 4, 2024)
d7209f1  clean up pynccl enable or disable (youkaichao, May 4, 2024)
7b55026  fix isort (youkaichao, May 4, 2024)
ee734b1  fix stream attribute (youkaichao, May 4, 2024)
0516956  fix import (youkaichao, May 9, 2024)
9f63bf8  rename to PyNcclCommunicator and pynccl_comm (youkaichao, May 9, 2024)
e9aa766  rename use_pynccl_allreduce (youkaichao, May 9, 2024)
0f64301  fix lint (youkaichao, May 9, 2024)
a64962e  fix lint (youkaichao, May 9, 2024)
d2f83ba  fix lint (youkaichao, May 9, 2024)
12f309b  fix dependency on custom_all_reduce (youkaichao, May 9, 2024)
68e448c  fix lint (youkaichao, May 9, 2024)
ad6f840  use _PP_DEVICE_GROUP (youkaichao, May 9, 2024)
e2153b2  use _PP_GLOBAL_RANKS (youkaichao, May 9, 2024)
80aca94  fix lint (youkaichao, May 9, 2024)
c1b1cdb  use change_state rather than enable (youkaichao, May 9, 2024)
c4e3b0f  Merge branch 'main' into bind_pynccl_to_group (youkaichao, May 9, 2024)
70a7e26  add get_tp_pynccl_communicator (youkaichao, May 9, 2024)
51 changes: 27 additions & 24 deletions tests/distributed/test_pynccl.py
@@ -3,13 +3,12 @@
import pytest
import torch

import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetUniqueId)
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group,
init_distributed_environment, with_pynccl_for_all_reduce)
from vllm.distributed.communication_op import (
tensor_model_parallel_all_reduce, use_pynccl_allreduce)
from vllm.distributed.device_communicators.pynccl import NCCLCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.utils import update_environment_variables


@@ -41,6 +40,10 @@ def worker_fn_wrapper(fn):
# and update the environment variables in the function
def wrapped_fn(env):
update_environment_variables(env)
import os
local_rank = os.environ['LOCAL_RANK']
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
init_distributed_environment()
fn()

@@ -51,7 +54,8 @@ def wrapped_fn(env):
def worker_fn():
comm = NCCLCommunicator()
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
comm.all_reduce(tensor)
with comm.enable():
comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == comm.world_size

@@ -72,16 +76,17 @@ def multiple_tp_worker_fn():
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
comm = NCCLCommunicator(group=group, device=device)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
# two groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
comm.all_reduce(tensor)
comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == 4
else:
comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == 2
with comm.enable():
# two groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
comm.all_reduce(tensor)
comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == 4
else:
comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == 2


@pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -95,12 +100,9 @@ def test_pynccl_multiple_tp():
@worker_fn_wrapper
def multiple_tp_with_vllm_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
torch.cuda.set_device(torch.distributed.get_rank())
ensure_model_parallel_initialized(2, 2)
pynccl_utils.init_process_group(
group=get_tensor_model_parallel_cpu_group())
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
with with_pynccl_for_all_reduce():
with use_pynccl_allreduce():
# two tp groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
tensor = tensor_model_parallel_all_reduce(tensor)
@@ -129,7 +131,7 @@ def worker_fn_with_cudagraph():
# run something in the default stream to initialize torch engine
a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
torch.cuda.synchronize()
with torch.cuda.graph(graph, stream=comm.stream):
with torch.cuda.graph(graph, stream=comm.stream), comm.enable():
# operation during the graph capture is recorded but not executed
# see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
comm.all_reduce(a)
@@ -147,7 +149,8 @@ def test_pynccl_with_cudagraph():


def test_ncclGetUniqueId():
unique_id = ncclGetUniqueId()
lib = NCCLLibrary()
unique_id = lib.ncclGetUniqueId()
# `list(unique_id.internal)` is something like this:
# [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
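For readers skimming the diff above, here is a condensed, standalone sketch of the per-group pattern the updated tests exercise: one communicator bound to each process group, performing its all-reduce only inside its enable() context. It uses the names as they stand in this 30-commit view (NCCLCommunicator, enable(); later commits rename these to PyNcclCommunicator and change_state), and the launch and group setup (torchrun, gloo process groups) are assumptions rather than part of the diff.

```python
# Hypothetical 4-GPU example; run with: torchrun --nproc-per-node=4 demo.py
import os

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.pynccl import NCCLCommunicator


def main() -> None:
    # torchrun provides RANK/WORLD_SIZE/LOCAL_RANK/MASTER_ADDR/MASTER_PORT.
    dist.init_process_group(backend="gloo")
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # Two independent 2-GPU groups; every rank must call new_group for both.
    group_01 = dist.new_group(ranks=[0, 1], backend="gloo")
    group_23 = dist.new_group(ranks=[2, 3], backend="gloo")
    group = group_01 if dist.get_rank() in (0, 1) else group_23

    # One communicator per group, bound at construction time; the NCCL
    # all-reduce runs only while the communicator is enabled.
    comm = NCCLCommunicator(group=group, device=device)
    tensor = torch.ones(16, dtype=torch.float32, device=device)
    with comm.enable():
        comm.all_reduce(tensor)
    # Each group reduces independently, so every rank sees a sum of 2.
    assert tensor.mean().item() == 2

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```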
24 changes: 19 additions & 5 deletions vllm/distributed/communication_op.py
@@ -1,4 +1,5 @@
from collections import namedtuple
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
@@ -7,8 +8,20 @@
from .parallel_state import (get_cpu_world_group,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
is_pynccl_enabled_for_all_reduce)
get_tensor_model_parallel_world_size)


@contextmanager
def use_pynccl_allreduce():
from vllm.distributed.device_communicators import custom_all_reduce
if not custom_all_reduce.is_initialized():
from vllm.distributed.parallel_state import _TP_PYNCCL_COMMUNICATOR
assert _TP_PYNCCL_COMMUNICATOR is not None
with _TP_PYNCCL_COMMUNICATOR.enable(
stream=torch.cuda.current_stream()):
yield
else:
yield


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -23,18 +36,19 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
TLDR: always assume this function modifies its input, but use the return
value as the output.
"""
from vllm.distributed.device_communicators import pynccl_utils
from vllm.distributed.device_communicators.custom_all_reduce import (
custom_all_reduce)
from vllm.distributed.parallel_state import _TP_PYNCCL_COMMUNICATOR
Collaborator: Dumb question: why don't we import these at the top of the module?

Collaborator: Also, if the variable is used outside the module where it's defined, the variable name should not start with _. If you want to keep it private, why don't we use a getter like get_tp_pynccl_communicator instead?

Member (Author):
> Why don't we import these at the top of the module?

Because this variable is None when the module gets initialized. Importing it at the top of the module would capture that None; it would not track the variable in vllm.distributed.parallel_state as it is updated.

> the variable is used outside the module where it's defined

I actually want to merge parallel_state and communication_op. communication_op needs to access many private variables in parallel_state. That's another story, though.

Collaborator: Why don't we use a get method then?

Member (Author): A get method is doable. However, if the majority of callers of these getter methods come from communication_op, it's better to simply merge the two files, imo.

Collaborator: Maybe. But if we do not merge the files in this PR, I think using the getter method is a good idea. The current code has two problems: 1) it must be lazily imported for the reason you mentioned above, and 2) the variable name starts with _. Using a getter method is a simple solution to both problems.

Member (Author): I don't want this PR to be super complicated. I will merge the files in the next few days; let's not add new user-facing getter functions for just a few days and then remove them 🤣

Collaborator: Honestly, I think in this case we should add the getter function even if there's a chance it will be deleted after just a few days. We don't have a concrete timeline for the next PR you are thinking of; it could be delayed or de-prioritized. Adding the getter function doesn't hurt us.

Member (Author): Added in 70a7e26.


# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
return input_
out = custom_all_reduce(input_)
if out is not None:
return out
if is_pynccl_enabled_for_all_reduce():
pynccl_utils.all_reduce(input_)
if _TP_PYNCCL_COMMUNICATOR is not None and \
not _TP_PYNCCL_COMMUNICATOR.disabled:
_TP_PYNCCL_COMMUNICATOR.all_reduce(input_)
else:
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())
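As a usage note on the hunk above: a minimal sketch of how a caller would combine the new use_pynccl_allreduce context manager with tensor_model_parallel_all_reduce in this 30-commit view (the helper is renamed in a later commit, e9aa766). It assumes vLLM's distributed state is already initialized in the calling worker with tensor parallel size greater than 1.

```python
import torch

from vllm.distributed.communication_op import (
    tensor_model_parallel_all_reduce, use_pynccl_allreduce)


def reduce_partial_output(x: torch.Tensor) -> torch.Tensor:
    # Inside the context, the TP pynccl communicator is enabled on the
    # current CUDA stream; if custom all-reduce is initialized, the context
    # is a pass-through and custom all-reduce handles the call instead.
    with use_pynccl_allreduce():
        # Per the docstring above, the input may be modified in place, so
        # always take the returned tensor as the result.
        return tensor_model_parallel_all_reduce(x)
```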
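The review thread earlier in this diff settles on a getter, added in commit 70a7e26 (add get_tp_pynccl_communicator). Below is a minimal sketch of that pattern; the module internals are hypothetical and mirror the names in this diff rather than the final merged code.

```python
from typing import Optional

from vllm.distributed.device_communicators.pynccl import NCCLCommunicator

# Rebound during distributed initialization; it stays None for single-GPU
# runs. Importing the bare name from another module at import time would
# capture this initial None, which is why callers go through a getter.
_TP_PYNCCL_COMMUNICATOR: Optional[NCCLCommunicator] = None


def get_tp_pynccl_communicator() -> Optional[NCCLCommunicator]:
    # A function call always reads the current binding of the module global.
    return _TP_PYNCCL_COMMUNICATOR
```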