diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index a4f30808d32e1..479dc95a8b667 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -11,7 +11,8 @@ gpu_p2p_access_check) from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger -from vllm.utils import cuda_device_count_stateless, is_full_nvlink +from vllm.platforms import current_platform +from vllm.utils import cuda_device_count_stateless try: assert ops.is_custom_op_supported("_C_custom_ar::meta_size") @@ -113,7 +114,10 @@ def __init__(self, # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported # this checks hardware and driver support for NVLink - full_nvlink = is_full_nvlink(physical_device_ids) + assert current_platform.is_cuda() + from vllm.platforms.cuda import CudaPlatform + cuda_platform: CudaPlatform = current_platform + full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids) if world_size > 2 and not full_nvlink: logger.warning( "Custom allreduce is disabled because it's not supported on" diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 02ba227460e3f..a7e760cc16408 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -4,12 +4,21 @@ import os from functools import lru_cache, wraps -from typing import Tuple +from typing import List, Tuple import pynvml +from vllm.logger import init_logger + from .interface import Platform, PlatformEnum +logger = init_logger(__name__) + +# NVML utils +# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, +# all the related functions work on real physical device ids.
+ the major benefit of using NVML is that it will not initialize CUDA + def with_nvml_context(fn): @@ -47,3 +56,29 @@ class CudaPlatform(Platform): def get_device_capability(device_id: int = 0) -> Tuple[int, int]: physical_device_id = device_id_to_physical_device_id(device_id) return get_physical_device_capability(physical_device_id) + + @staticmethod + @with_nvml_context + def is_full_nvlink(physical_device_ids: List[int]) -> bool: + """ + query if the set of gpus are fully connected by nvlink (1 hop) + """ + handles = [ + pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids + ] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + p2p_status = pynvml.nvmlDeviceGetP2PStatus( + handle, peer_handle, + pynvml.NVML_P2P_CAPS_INDEX_NVLINK) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: + return False + except pynvml.NVMLError as error: + logger.error( + "NVLink detection failed. This is normal if your" + " machine has no NVLink equipped.", + exc_info=error) + return False + return True diff --git a/vllm/utils.py b/vllm/utils.py index 61e3bb0bfc333..08aa889b5e447 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1034,56 +1034,6 @@ def cuda_device_count_stateless() -> int: return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) -# NVML utils -# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, -# all the related functions work on real physical device ids.
-# the major benefit of using NVML is that it will not initialize CUDA - -try: - import pynvml -except ImportError: - # For non-NV devices - pynvml = None - - -def with_nvml_context(fn): - - @wraps(fn) - def wrapper(*args, **kwargs): - if pynvml is not None: - pynvml.nvmlInit() - try: - return fn(*args, **kwargs) - finally: - if pynvml is not None: - pynvml.nvmlShutdown() - - return wrapper - - -@with_nvml_context -def is_full_nvlink(device_ids: List[int]) -> bool: - """ - query if the set of gpus are fully connected by nvlink (1 hop) - """ - handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids] - for i, handle in enumerate(handles): - for j, peer_handle in enumerate(handles): - if i < j: - try: - p2p_status = pynvml.nvmlDeviceGetP2PStatus( - handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK) - if p2p_status != pynvml.NVML_P2P_STATUS_OK: - return False - except pynvml.NVMLError as error: - logger.error( - "NVLink detection failed. This is normal if your" - " machine has no NVLink equipped.", - exc_info=error) - return False - return True - - #From: https://stackoverflow.com/a/4104188/2749989 def run_once(f):