From 16a1cc9bb2b4bba82d78f329e5a89b44a5523ac8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 4 Aug 2024 11:31:51 -0700 Subject: [PATCH] [misc][distributed] improve libcudart.so finding (#7127) --- .../device_communicators/cuda_wrapper.py | 44 +++++++++---------- .../custom_all_reduce_utils.py | 4 +- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/vllm/distributed/device_communicators/cuda_wrapper.py b/vllm/distributed/device_communicators/cuda_wrapper.py index 5cac3c1d57bc..9c7f41a1f9d6 100644 --- a/vllm/distributed/device_communicators/cuda_wrapper.py +++ b/vllm/distributed/device_communicators/cuda_wrapper.py @@ -4,9 +4,6 @@ """ import ctypes -import glob -import os -import sys from dataclasses import dataclass from typing import Any, Dict, List, Optional @@ -36,24 +33,25 @@ class Function: argtypes: List[Any] -def get_pytorch_default_cudart_library_path() -> str: - # code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa - lib_folder = "cuda_runtime" - lib_name = "libcudart.so.*[0-9]" - lib_path = None - for path in sys.path: - nvidia_path = os.path.join(path, "nvidia") - if not os.path.exists(nvidia_path): - continue - candidate_lib_paths = glob.glob( - os.path.join(nvidia_path, lib_folder, "lib", lib_name)) - if candidate_lib_paths and not lib_path: - lib_path = candidate_lib_paths[0] - if lib_path: - break - if not lib_path: - raise ValueError(f"{lib_name} not found in the system path {sys.path}") - return lib_path +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. + """ # noqa + found = False + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found = True + break + if not found: + # the library is not loaded in the current process + return None + start = line.index("/") + path = line[start:].strip() + return path class CudaRTLibrary: @@ -100,7 +98,9 @@ class CudaRTLibrary: def __init__(self, so_file: Optional[str] = None): if so_file is None: - so_file = get_pytorch_default_cudart_library_path() + so_file = find_loaded_library("libcudart.so") + assert so_file is not None, \ + "libcudart.so is not loaded in the current process" if so_file not in CudaRTLibrary.path_to_library_cache: lib = ctypes.CDLL(so_file) CudaRTLibrary.path_to_library_cache[so_file] = lib diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py index d27d7ee9a249..37ae94c671e3 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -145,6 +145,7 @@ def can_actually_p2p( p_tgt.start() p_src.join() p_tgt.join() + assert p_src.exitcode == 0 and p_tgt.exitcode == 0 result: List[bool] = [] for src, tgt in zip(batch_src, batch_tgt): a = result_queue.get() @@ -221,7 +222,8 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: # wrap raised exception to provide more information raise RuntimeError( f"Error happened when batch testing " - f"peer-to-peer access from {batch_src} to {batch_tgt}") from e + f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" + f"{returned.stderr.decode()}") from e result = pickle.loads(returned.stdout) for _i, _j, r in zip(batch_src, batch_tgt, result): cache[f"{_i}->{_j}"] = r