diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 5ff39ddfbf99..f632caba9017 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -9,7 +9,7 @@
 
 import pytest
 
-from ..utils import compare_two_settings
+from ..utils import compare_two_settings, fork_new_process_for_each_test
 
 VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
 
@@ -28,6 +28,7 @@
     (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
     (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
 ])
+@fork_new_process_for_each_test
 def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
                     DIST_BACKEND):
     if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
@@ -77,6 +78,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
     "FLASH_ATTN",
     "FLASHINFER",
 ])
+@fork_new_process_for_each_test
 def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND):
     cudagraph_args = [
         # use half precision for speed and memory savings in CI environment
diff --git a/tests/utils.py b/tests/utils.py
index 1086591464d4..f3ee801ee774 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,4 +1,6 @@
+import functools
 import os
+import signal
 import subprocess
 import sys
 import time
@@ -336,3 +338,40 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
                   f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
 
         time.sleep(5)
+
+
+def fork_new_process_for_each_test(f):
+
+    @functools.wraps(f)
+    def wrapper(*args, **kwargs):
+        # Make the process the leader of its own process group
+        # to avoid sending SIGTERM to the parent process
+        os.setpgrp()
+        from _pytest.outcomes import Skipped
+        pid = os.fork()
+        if pid == 0:
+            try:
+                f(*args, **kwargs)
+            except Skipped as e:
+                # convert Skipped to exit code 0
+                print(str(e))
+                os._exit(0)
+            except Exception:
+                import traceback
+                traceback.print_exc()
+                os._exit(1)
+            else:
+                os._exit(0)
+        else:
+            pgid = os.getpgid(pid)
+            _pid, _exitcode = os.waitpid(pid, 0)
+            # ignore SIGTERM signal itself
+            old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
+            # kill all child processes
+            os.killpg(pgid, signal.SIGTERM)
+            # restore the signal handler
+            signal.signal(signal.SIGTERM, old_signal_handler)
+            assert _exitcode == 0, (f"function {f} failed when called with"
+                                    f" args {args} and kwargs {kwargs}")
+
+    return wrapper
diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
index eebdf7bf644d..79081c04ddc1 100644
--- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
@@ -3,7 +3,7 @@
 from typing import List, Optional
 
 try:
-    from ray.exceptions import ActorDiedError
+    from ray.exceptions import ActorDiedError  # type: ignore
 except ImportError:
     # For older versions of Ray
     from ray.exceptions import RayActorError as ActorDiedError  # type: ignore
diff --git a/vllm/utils.py b/vllm/utils.py
index 38e1782a51ab..358788c95f30 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -928,7 +928,8 @@ def error_on_invalid_device_count_status():
     with contextlib.suppress(Exception):
         # future pytorch will fix the issue, device_count will not be cached
         # at that time, `.cache_info().currsize` will error out
-        cache_entries = torch.cuda.device_count.cache_info().currsize
+        cache_entries = torch.cuda.device_count.cache_info(  # type: ignore
+        ).currsize
     if cache_entries != 0:
         # the function is already called, and the result is cached
         remembered = torch.cuda.device_count()
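
A minimal sketch of how the new `fork_new_process_for_each_test` decorator composes with pytest parametrization, following the pattern in the updated tests above; the test name, parameter, and body here are hypothetical, only the decorator and import come from the diff:

import pytest

from ..utils import fork_new_process_for_each_test


@pytest.mark.parametrize("TP_SIZE", [1, 2])
@fork_new_process_for_each_test
def test_example(TP_SIZE):
    # Each parametrized case runs in a forked child process; the parent
    # waits for it and then sends SIGTERM to the child's process group,
    # so leaked GPU worker processes cannot affect later tests.
    if TP_SIZE == 2:
        # pytest.skip raises Skipped, which the child converts to exit code 0
        pytest.skip("example skip path")
    assert TP_SIZE == 1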