[distributed][misc] add specialized method for cuda platform #7249

Merged: 1 commit, Aug 7, 2024

This PR moves the NVML-based NVLink detection helpers out of the platform-agnostic vllm/utils.py and onto CudaPlatform as a specialized static method, so callers must explicitly narrow to the CUDA platform before using it.
vllm/distributed/device_communicators/custom_all_reduce.py (8 changes: 6 additions & 2 deletions)

```diff
@@ -11,7 +11,8 @@
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
-from vllm.utils import cuda_device_count_stateless, is_full_nvlink
+from vllm.platforms import current_platform
+from vllm.utils import cuda_device_count_stateless

 try:
     assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
@@ -113,7 +114,10 @@ def __init__(self,
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        full_nvlink = is_full_nvlink(physical_device_ids)
+        assert current_platform.is_cuda()
+        from vllm.platforms.cuda import CudaPlatform
+        cuda_platform: CudaPlatform = current_platform
+        full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"
```
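The narrowing pattern in the second hunk is worth spelling out: `current_platform` is resolved once at import time, and generic code only sees the base `Platform` interface, so a CUDA-only helper such as `is_full_nvlink` is reached by first asserting the platform and then annotating the narrowed type. Below is a minimal, self-contained sketch of the idea, with toy classes standing in for the actual vllm.platforms module:

```python
from typing import List


class Platform:
    """Base interface: only platform-agnostic queries live here."""

    def is_cuda(self) -> bool:
        return False


class CudaPlatform(Platform):
    def is_cuda(self) -> bool:
        return True

    @staticmethod
    def is_full_nvlink(physical_device_ids: List[int]) -> bool:
        # Stubbed for the sketch; the real method queries NVML pairwise.
        return True


# In vLLM this is chosen once, at import time, from the detected hardware.
current_platform: Platform = CudaPlatform()


def setup_custom_allreduce(physical_device_ids: List[int]) -> bool:
    # The base type has no is_full_nvlink, so narrow first: the assert
    # documents (and at runtime enforces) that we are on CUDA, and the
    # annotated alias points type checkers at the specialized class.
    assert current_platform.is_cuda()
    cuda_platform: CudaPlatform = current_platform  # manual narrowing
    return cuda_platform.is_full_nvlink(physical_device_ids)


print(setup_custom_allreduce([0, 1]))  # True with the stub above
```

The payoff of this design is that vllm/utils.py no longer needs a conditional pynvml import: NVML-specific code is quarantined in the CUDA platform class, and non-NVIDIA builds never touch it.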
vllm/platforms/cuda.py (37 changes: 36 additions & 1 deletion)

```diff
@@ -4,12 +4,21 @@

 import os
 from functools import lru_cache, wraps
-from typing import Tuple
+from typing import List, Tuple

 import pynvml

+from vllm.logger import init_logger
+
 from .interface import Platform, PlatformEnum

+logger = init_logger(__name__)
+
+# NVML utils
+# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
+# all the related functions work on real physical device ids.
+# the major benefit of using NVML is that it will not initialize CUDA
+

 def with_nvml_context(fn):

@@ -47,3 +56,29 @@ class CudaPlatform(Platform):
     def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
         physical_device_id = device_id_to_physical_device_id(device_id)
         return get_physical_device_capability(physical_device_id)
+
+    @staticmethod
+    @with_nvml_context
+    def is_full_nvlink(physical_device_ids: List[int]) -> bool:
+        """
+        query if the set of gpus are fully connected by nvlink (1 hop)
+        """
+        handles = [
+            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
+        ]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i < j:
+                    try:
+                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                            handle, peer_handle,
+                            pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                            return False
+                    except pynvml.NVMLError as error:
+                        logger.error(
+                            "NVLink detection failed. This is normal if your"
+                            " machine has no NVLink equipped.",
+                            exc_info=error)
+                        return False
+        return True
```
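As the moved comment notes, NVML ignores `CUDA_VISIBLE_DEVICES` and never initializes a CUDA context, which is why the method takes physical device ids. A standalone usage sketch of the same NVML calls, assuming the pynvml package, an NVIDIA driver, and at least two GPUs, probing a single pair instead of folding all pairs into one bool:

```python
import pynvml

pynvml.nvmlInit()
try:
    if pynvml.nvmlDeviceGetCount() >= 2:
        h0 = pynvml.nvmlDeviceGetHandleByIndex(0)  # physical index 0
        h1 = pynvml.nvmlDeviceGetHandleByIndex(1)  # physical index 1
        status = pynvml.nvmlDeviceGetP2PStatus(
            h0, h1, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
        # NVML_P2P_STATUS_OK means GPU 0 and GPU 1 are one NVLink hop apart.
        print("NVLink P2P OK:", status == pynvml.NVML_P2P_STATUS_OK)
finally:
    # Always pair nvmlInit with nvmlShutdown; this is what the
    # with_nvml_context decorator automates for CudaPlatform methods.
    pynvml.nvmlShutdown()
```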
vllm/utils.py (50 changes: 0 additions & 50 deletions)

```diff
@@ -1034,56 +1034,6 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)


-# NVML utils
-# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
-# all the related functions work on real physical device ids.
-# the major benefit of using NVML is that it will not initialize CUDA
-
-try:
-    import pynvml
-except ImportError:
-    # For non-NV devices
-    pynvml = None
-
-
-def with_nvml_context(fn):
-
-    @wraps(fn)
-    def wrapper(*args, **kwargs):
-        if pynvml is not None:
-            pynvml.nvmlInit()
-        try:
-            return fn(*args, **kwargs)
-        finally:
-            if pynvml is not None:
-                pynvml.nvmlShutdown()
-
-    return wrapper
-
-
-@with_nvml_context
-def is_full_nvlink(device_ids: List[int]) -> bool:
-    """
-    query if the set of gpus are fully connected by nvlink (1 hop)
-    """
-    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
-    for i, handle in enumerate(handles):
-        for j, peer_handle in enumerate(handles):
-            if i < j:
-                try:
-                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
-                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
-                        return False
-                except pynvml.NVMLError as error:
-                    logger.error(
-                        "NVLink detection failed. This is normal if your"
-                        " machine has no NVLink equipped.",
-                        exc_info=error)
-                    return False
-    return True
-
-
 #From: https://stackoverflow.com/a/4104188/2749989
 def run_once(f):
```
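Note that the deleted utils.py decorator guarded every NVML call with `if pynvml is not None`, because utils.py must import on non-NVIDIA machines. The cuda.py copy (its body is elided from the diff above) can drop those guards, since cuda.py imports pynvml unconditionally. A sketch of that simplified decorator, assuming it matches the deleted version minus the None checks:

```python
from functools import wraps

import pynvml


def with_nvml_context(fn):
    """Run fn inside an NVML session, guaranteeing shutdown on any exit."""

    @wraps(fn)
    def wrapper(*args, **kwargs):
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper
```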