[ci][distributed] try to fix pp test (#7054)
youkaichao authored Aug 2, 2024
1 parent 3bb4b1e commit 2523577
Showing 4 changed files with 45 additions and 3 deletions.
4 changes: 3 additions & 1 deletion tests/distributed/test_pipeline_parallel.py
@@ -9,7 +9,7 @@

import pytest

from ..utils import compare_two_settings
from ..utils import compare_two_settings, fork_new_process_for_each_test

VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"

@@ -28,6 +28,7 @@
    (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
    (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
])
@fork_new_process_for_each_test
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
                    DIST_BACKEND):
    if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
@@ -77,6 +78,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
"FLASH_ATTN",
"FLASHINFER",
])
@fork_new_process_for_each_test
def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND):
cudagraph_args = [
# use half precision for speed and memory savings in CI environment
39 changes: 39 additions & 0 deletions tests/utils.py
@@ -1,4 +1,6 @@
import functools
import os
import signal
import subprocess
import sys
import time
@@ -336,3 +338,40 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
f'{dur_s=:.02f} ({threshold_bytes/2**30=})')

time.sleep(5)


def fork_new_process_for_each_test(f):

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Make the process the leader of its own process group
        # to avoid sending SIGTERM to the parent process
        os.setpgrp()
        from _pytest.outcomes import Skipped
        pid = os.fork()
        if pid == 0:
            try:
                f(*args, **kwargs)
            except Skipped as e:
                # convert Skipped to exit code 0
                print(str(e))
                os._exit(0)
            except Exception:
                import traceback
                traceback.print_exc()
                os._exit(1)
            else:
                os._exit(0)
        else:
            pgid = os.getpgid(pid)
            _pid, _exitcode = os.waitpid(pid, 0)
            # temporarily ignore SIGTERM in the parent so the killpg below
            # does not terminate it
            old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
            # kill all child processes
            os.killpg(pgid, signal.SIGTERM)
            # restore the signal handler
            signal.signal(signal.SIGTERM, old_signal_handler)
            assert _exitcode == 0, (f"function {f} failed when called with"
                                    f" args {args} and kwargs {kwargs}")

    return wrapper
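
For reference, a minimal sketch (not part of this commit) of how the decorator might be applied to an ordinary pytest test; the test name and parameter below are hypothetical, while the real call sites are the pipeline-parallel tests in the first file above:

import pytest

from ..utils import fork_new_process_for_each_test


@pytest.mark.parametrize("tp_size", [1, 2])  # hypothetical parameter
@fork_new_process_for_each_test
def test_example(tp_size):
    # The body runs in a forked child; any CUDA or distributed state it
    # creates dies with that child, so the next test starts from a clean
    # process.
    assert tp_size >= 1

Because the child always leaves via os._exit, pytest only ever sees the parent's assertion on the exit code, which is why Skipped is translated back into exit code 0 inside the child.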
@@ -3,7 +3,7 @@
from typing import List, Optional

try:
    from ray.exceptions import ActorDiedError
    from ray.exceptions import ActorDiedError  # type: ignore
except ImportError:
    # For older versions of Ray
    from ray.exceptions import RayActorError as ActorDiedError  # type: ignore
3 changes: 2 additions & 1 deletion vllm/utils.py
@@ -928,7 +928,8 @@ def error_on_invalid_device_count_status():
    with contextlib.suppress(Exception):
        # future pytorch will fix the issue, device_count will not be cached
        # at that time, `.cache_info().currsize` will error out
        cache_entries = torch.cuda.device_count.cache_info().currsize
        cache_entries = torch.cuda.device_count.cache_info(  # type: ignore
        ).currsize
    if cache_entries != 0:
        # the function is already called, and the result is cached
        remembered = torch.cuda.device_count()
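
The guard above relies on standard functools.lru_cache bookkeeping; a small self-contained sketch (plain Python, not vLLM code) of the same currsize check:

import functools


@functools.lru_cache(maxsize=None)
def device_count():
    # stand-in for the cached torch.cuda.device_count
    return 8


assert device_count.cache_info().currsize == 0  # never called, nothing cached
device_count()
assert device_count.cache_info().currsize == 1  # first call populated the cache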
