[Core] Improve choice of Python multiprocessing method #8823

Merged: 4 commits, Sep 29, 2024
11 changes: 9 additions & 2 deletions vllm/executor/multiproc_gpu_executor.py
@@ -15,8 +15,8 @@
 from vllm.sequence import ExecuteModelRequest
 from vllm.triton_utils import maybe_set_triton_cache_manager
 from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
-                        get_distributed_init_method, get_open_port,
-                        get_vllm_instance_id, make_async,
+                        cuda_is_initialized, get_distributed_init_method,
+                        get_open_port, get_vllm_instance_id, make_async,
                         update_environment_variables)

 logger = init_logger(__name__)
@@ -122,6 +122,13 @@ def _check_executor_parameters(self):
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})

if (cuda_is_initialized()
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
logger.warning("CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (
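Note on the hunk above: a CUDA context created in the parent process does not survive fork(), so forked workers fail as soon as they touch the GPU, while spawn starts a fresh interpreter. A minimal sketch of the failure mode this guard avoids (standalone illustration, not part of the diff; assumes a CUDA-enabled PyTorch build with at least one GPU):

import multiprocessing

import torch


def use_gpu():
    # In a child forked after the parent initialized CUDA, this raises
    # "RuntimeError: Cannot re-initialize CUDA in forked subprocess".
    # In a spawned child it works.
    print(torch.empty(1, device="cuda"))


if __name__ == "__main__":
    torch.cuda.init()  # parent initializes CUDA first
    ctx = multiprocessing.get_context("spawn")  # "fork" would break here
    worker = ctx.Process(target=use_gpu)
    worker.start()
    worker.join()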
17 changes: 10 additions & 7 deletions vllm/executor/multiproc_worker_utils.py
@@ -27,9 +27,6 @@

 JOIN_TIMEOUT_S = 2

-mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
-mp = multiprocessing.get_context(mp_method)
-

 @dataclass
 class Result(Generic[T]):
@@ -77,7 +74,7 @@ class ResultHandler(threading.Thread):

     def __init__(self) -> None:
         super().__init__(daemon=True)
-        self.result_queue = mp.Queue()
+        self.result_queue = get_mp_context().Queue()
         self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}

     def run(self):
@@ -147,10 +144,11 @@ class ProcessWorkerWrapper:

     def __init__(self, result_handler: ResultHandler,
                  worker_factory: Callable[[], Any]) -> None:
-        self._task_queue = mp.Queue()
+        self.mp = get_mp_context()
+        self._task_queue = self.mp.Queue()
         self.result_queue = result_handler.result_queue
         self.tasks = result_handler.tasks
-        self.process: BaseProcess = mp.Process(  # type: ignore[attr-defined]
+        self.process: BaseProcess = self.mp.Process(  # type: ignore[attr-defined]
             target=_run_worker_process,
             name="VllmWorkerProcess",
             kwargs=dict(
@@ -204,7 +202,7 @@ def _run_worker_process(
"""Worker process event loop"""

# Add process-specific prefix to stdout and stderr
process_name = mp.current_process().name
process_name = get_mp_context().current_process().name
pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
@@ -269,3 +267,8 @@ def write_with_prefix(s: str):

     file.start_new_line = True  # type: ignore[attr-defined]
     file.write = write_with_prefix  # type: ignore[method-assign]
+
+
+def get_mp_context():
+    mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
+    return multiprocessing.get_context(mp_method)
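The reason for replacing the module-level `mp` context with `get_mp_context()` is evaluation order: the old code captured VLLM_WORKER_MULTIPROC_METHOD once at import time, before `env_setup()` or the executor check above had a chance to set it. A standalone sketch of the difference (illustrative names, not vLLM code):

import multiprocessing
import os


# Eager: the start method is frozen when the module is imported.
_EAGER_CTX = multiprocessing.get_context(
    os.environ.get("WORKER_MP_METHOD", "fork"))


# Lazy: the environment variable is re-read on every call, so changes
# made after import are respected.
def get_ctx():
    return multiprocessing.get_context(
        os.environ.get("WORKER_MP_METHOD", "fork"))


os.environ["WORKER_MP_METHOD"] = "spawn"
print(_EAGER_CTX.get_start_method())  # still "fork"
print(get_ctx().get_start_method())   # "spawn"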
26 changes: 26 additions & 0 deletions vllm/scripts.py
@@ -12,8 +12,11 @@
 from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.openai.api_server import run_server
 from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.logger import init_logger
 from vllm.utils import FlexibleArgumentParser

+logger = init_logger(__name__)
+

def register_signal_handlers():

@@ -114,7 +117,30 @@ def _add_query_options(
     return parser


+def env_setup():
+    # The safest multiprocessing method is `spawn`, as the default `fork` method
+    # is not compatible with some accelerators. The default method will be
+    # changing in future versions of Python, so we should use it explicitly when
+    # possible.
+    #
+    # We only set it here in the CLI entrypoint, because changing to `spawn`
+    # could break some existing code using vLLM as a library. `spawn` will cause
+    # unexpected behavior if the code is not protected by
+    # `if __name__ == "__main__":`.
+    #
+    # References:
+    # - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
+    # - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
+    # - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
+    # - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders
+    if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
+        logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+
 def main():
+    env_setup()
+
     parser = FlexibleArgumentParser(description="vLLM CLI")
     subparsers = parser.add_subparsers(required=True)

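The `if __name__ == "__main__":` caveat in env_setup's comment is why the spawn default is applied only in the CLI entrypoint. Under spawn, each worker re-imports the parent's main module, so any unguarded top-level work runs again in every child. A sketch of the pattern library users need (hypothetical script, not part of this diff):

# my_script.py -- hypothetical user code driving vLLM as a library
from vllm import LLM

# Without this guard, a spawned worker re-importing this module would
# construct the LLM again and try to spawn workers of its own.
if __name__ == "__main__":
    llm = LLM(model="facebook/opt-125m")
    print(llm.generate("Hello, my name is"))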
7 changes: 7 additions & 0 deletions vllm/utils.py
@@ -1091,6 +1091,13 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)


+def cuda_is_initialized() -> bool:
+    """Check if CUDA is initialized."""
+    if not torch.cuda._is_compiled():
+        return False
+    return torch.cuda.is_initialized()
+
+
 def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
     """Make an instance method that weakly references
     its associated instance and no-ops once that
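cuda_is_initialized() checks torch.cuda._is_compiled() first so it is safe to call on CPU-only builds, and torch.cuda.is_initialized() merely reports whether a CUDA context already exists; calling it does not create one. A usage sketch mirroring the executor check above (illustrative helper, assumes a CUDA-enabled build):

import torch


def pick_start_method() -> str:
    # Once this process owns a CUDA context, forking workers is unsafe,
    # so fall back to spawn.
    if torch.cuda._is_compiled() and torch.cuda.is_initialized():
        return "spawn"
    return "fork"


if __name__ == "__main__":
    print(pick_start_method())  # "fork": CUDA untouched so far
    torch.cuda.init()           # eagerly create the CUDA context
    print(pick_start_method())  # "spawn"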