diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index d3f9a0ab00f1..9af5fd3e336d 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -390,6 +390,12 @@ def __init__(self,
         # Lazy initialized fields
         self._request_tracker: RequestTracker
 
+    def shutdown_background_loop(self) -> None:
+        if self._background_loop_unshielded is not None:
+            self._background_loop_unshielded.cancel()
+            self._background_loop_unshielded = None
+        self.background_loop = None
+
     @classmethod
     def _get_executor_cls(
             cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 1efe2206abe8..3ba8cf8b250f 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -338,6 +338,13 @@ def __init__(
                 "vllm.llm_engine",
                 self.observability_config.otlp_traces_endpoint)
 
+        tokenizer_group = self.get_tokenizer_group()
+
+        # Ensure that the function doesn't contain a reference to self,
+        # to avoid engine GC issues
+        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
+            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
+
         # Create sequence output processor, e.g. for beam search or
         # speculative decoding.
         self.output_processor = (
@@ -346,10 +353,10 @@ def __init__(
                 self.detokenizer,
                 self.scheduler,
                 self.seq_counter,
-                self.get_tokenizer_for_seq,
+                get_tokenizer_for_seq,
                 stop_checker=StopChecker(
                     self.scheduler_config.max_model_len,
-                    self.get_tokenizer_for_seq,
+                    get_tokenizer_for_seq,
                 ),
             ))
 
@@ -481,10 +488,6 @@ def get_tokenizer(
     ) -> AnyTokenizer:
         return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
 
-    def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
-        return self.get_tokenizer_group().get_lora_tokenizer(
-            sequence.lora_request)
-
     def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
         init_kwargs = dict(
             tokenizer_id=self.model_config.tokenizer,
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 0fe4dd245b5e..00fbb6fe57ed 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -1,4 +1,5 @@
 import asyncio
+import gc
 import importlib
 import inspect
 import re
@@ -71,6 +72,8 @@ async def _force_log():
 
     yield
 
+    engine.shutdown_background_loop()
+
 
 router = APIRouter()
 
@@ -308,28 +311,38 @@ async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None:
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)
 
-    server = await build_server(
-        args,
-        llm_engine,
-        **uvicorn_kwargs,
-    )
-
-    loop = asyncio.get_running_loop()
-
-    server_task = loop.create_task(server.serve())
-
-    def signal_handler() -> None:
-        # prevents the uvicorn signal handler to exit early
-        server_task.cancel()
-
-    loop.add_signal_handler(signal.SIGINT, signal_handler)
-    loop.add_signal_handler(signal.SIGTERM, signal_handler)
-    try:
-        await server_task
-    except asyncio.CancelledError:
-        print("Gracefully stopping http server")
-        await server.shutdown()
+    server = await build_server(
+        args,
+        llm_engine,
+        **uvicorn_kwargs,
+    )
+
+    loop = asyncio.get_running_loop()
+
+    server_task = loop.create_task(server.serve())
+
+    def signal_handler() -> None:
+        # prevents the uvicorn signal handler to exit early
+        server_task.cancel()
+
+    loop.add_signal_handler(signal.SIGINT, signal_handler)
+    loop.add_signal_handler(signal.SIGTERM, signal_handler)
+
+    try:
+        await server_task
+    except asyncio.CancelledError:
+        print("Gracefully stopping http server")
+        await server.shutdown()
+    finally:
+        # Clean up globals
+        for var in ("openai_serving_chat", "openai_serving_completion",
+                    "openai_serving_embedding", "openai_serving_tokenization",
+                    "engine_args", "engine"):
+            globals().pop(var, None)
+
+        # This is required for the LLMEngine destructor to run
+        gc.collect()
 
 
 if __name__ == "__main__":
diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py
index e1e92958e667..5aa8a2d8ab75 100644
--- a/vllm/executor/multiproc_gpu_executor.py
+++ b/vllm/executor/multiproc_gpu_executor.py
@@ -1,8 +1,5 @@
 import asyncio
 import os
-import signal
-import threading
-import weakref
 from functools import partial
 from typing import Any, List, Optional
 
@@ -118,23 +115,8 @@ def _init_executor(self) -> None:
                 self.non_driver_workers.append(worker)
 
             self.worker_monitor = WorkerMonitor(self.workers, result_handler)
-            result_handler.start()
             self.worker_monitor.start()
 
-        # Set up signal handlers to shutdown the executor cleanly
-        # sometimes gc does not work well
-
-        # Use weakref to avoid holding a reference to self
-        ref = weakref.ref(self)
-
-        def shutdown(signum, frame):
-            if executor := ref():
-                executor.shutdown()
-
-        if threading.current_thread() is threading.main_thread():
-            signal.signal(signal.SIGINT, shutdown)
-            signal.signal(signal.SIGTERM, shutdown)
-
         self.driver_worker = self._create_worker(
             distributed_init_method=distributed_init_method)
         self._run_workers("init_device")
diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py
index 28c8e8699f08..3becb494d4ed 100644
--- a/vllm/executor/multiproc_worker_utils.py
+++ b/vllm/executor/multiproc_worker_utils.py
@@ -5,6 +5,7 @@
 import threading
 import traceback
 import uuid
+import weakref
 from dataclasses import dataclass
 from multiprocessing import Queue
 from multiprocessing.connection import wait
@@ -76,7 +77,7 @@ class ResultHandler(threading.Thread):
     """Handle results from all workers (in background thread)"""
 
     def __init__(self) -> None:
-        super().__init__(daemon=True)
+        super().__init__(daemon=False)
         self.result_queue = mp.Queue()
         self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
 
@@ -100,27 +101,51 @@ class WorkerMonitor(threading.Thread):
 
     def __init__(self, workers: List['ProcessWorkerWrapper'],
                  result_handler: ResultHandler):
-        super().__init__(daemon=True)
+        super().__init__(daemon=False)
         self.workers = workers
         self.result_handler = result_handler
         self._close = False
 
+        # Set up a handler to ensure that the threads and worker
+        # processes are shut down in the case the interpreter exits due
+        # to an unhandled exception. GC does not appear to be reliable
+        # for this.
+        ref = weakref.ref(self)
+        old_handler = sys.excepthook
+
+        def handler(*args):
+            old_handler(*args)
+            if (monitor := ref()) is not None:
+                monitor.close()
+
+        sys.excepthook = handler
+
     def run(self) -> None:
+        # We are responsible for starting the result handler thread
+        self.result_handler.start()
+
         # Blocks until any worker exits
         dead_sentinels = wait([w.process.sentinel for w in self.workers])
         if not self._close:
             self._close = True
 
-            # Kill / cleanup all workers
-            for worker in self.workers:
-                process = worker.process
-                if process.sentinel in dead_sentinels:
-                    process.join(JOIN_TIMEOUT_S)
-                if process.exitcode is not None and process.exitcode != 0:
-                    logger.error("Worker %s pid %s died, exit code: %s",
-                                 process.name, process.pid, process.exitcode)
+            if not sys.is_finalizing():
+                # Kill / cleanup all workers
+                died_count = 0
+                for worker in self.workers:
+                    process = worker.process
+                    if process.sentinel in dead_sentinels:
+                        process.join(JOIN_TIMEOUT_S)
+                    if process.exitcode is not None and process.exitcode != 0:
+                        died_count += 1
+                        logger.error("Worker %s pid %s died, exit code: %s",
+                                     process.name, process.pid,
+                                     process.exitcode)
+                if died_count < len(self.workers):
+                    logger.info(
+                        "Killing remaining local vLLM worker processes")
+
             # Cleanup any remaining workers
-            logger.info("Killing local vLLM worker processes")
             for worker in self.workers:
                 worker.kill_worker()
             # Must be done after worker task queues are all closed
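
Note on the WorkerMonitor change above: chaining sys.excepthook through a weakref lets the hook close the monitor on an unhandled exception without keeping the monitor (and therefore the engine/executor it references) alive. The standalone sketch below is illustrative only; the Monitor class is hypothetical and not part of vLLM, it just exercises the same pattern in isolation.

import sys
import weakref


class Monitor:
    """Toy stand-in for a background monitor that owns cleanup logic."""

    def __init__(self) -> None:
        self.closed = False

        # Chain the previous excepthook and capture only a weak reference
        # to self, so installing the hook does not keep this object alive.
        ref = weakref.ref(self)
        old_handler = sys.excepthook

        def handler(*args):
            old_handler(*args)
            if (monitor := ref()) is not None:
                monitor.close()

        sys.excepthook = handler

    def close(self) -> None:
        if not self.closed:
            self.closed = True
            print("monitor closed")


if __name__ == "__main__":
    monitor = Monitor()
    # The unhandled exception triggers sys.excepthook, which first prints
    # the traceback via the chained default hook and then closes the monitor.
    raise RuntimeError("boom")

Because the hook holds only a weak reference, it never extends the monitor's lifetime, so ordinary garbage collection of the owning executor is unaffected.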