diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 67b9e5cf5cc0..b8f8eea44573 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -79,3 +79,6 @@ async def do_log_stats(
         model_output: Optional[List[SamplerOutput]] = None,
     ) -> None:
         pass
+
+    async def check_health(self) -> None:
+        """Raise if unhealthy"""
diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py
index 0f05b59cb2e9..7187bcdbe77b 100644
--- a/vllm/entrypoints/openai/rpc/__init__.py
+++ b/vllm/entrypoints/openai/rpc/__init__.py
@@ -8,6 +8,7 @@
 from vllm.sampling_params import SamplingParams
 
 VLLM_RPC_SUCCESS_STR = "SUCCESS"
+VLLM_RPC_HEALTHY_STR = "HEALTHY"
 
 
 @dataclass
@@ -29,6 +30,7 @@ class RPCUtilityRequest(Enum):
     IS_SERVER_READY = 1
     GET_MODEL_CONFIG = 2
     DO_LOG_STATS = 3
+    CHECK_HEALTH = 4
 
 
 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py
index 1e8a98d6418f..bf07a05feb9c 100644
--- a/vllm/entrypoints/openai/rpc/client.py
+++ b/vllm/entrypoints/openai/rpc/client.py
@@ -6,6 +6,7 @@
 from vllm.config import DecodingConfig, ModelConfig
 from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
+                                         VLLM_RPC_HEALTHY_STR,
                                          VLLM_RPC_SUCCESS_STR,
                                          RPCAbortRequest, RPCGenerateRequest,
                                          RPCUtilityRequest)
 from vllm.inputs import PromptInputs
@@ -153,3 +154,26 @@ async def generate(
             yield request_output
 
         socket.close()
+
+    async def check_health(self) -> None:
+        """Raise if unhealthy"""
+
+        # Connect to socket.
+        socket = self.context.socket(zmq.constants.DEALER)
+        socket.connect(self.path)
+
+        # Ping RPCServer with CHECK_HEALTH request.
+        await socket.send(pickle.dumps(RPCUtilityRequest.CHECK_HEALTH))
+
+        # Await the reply from the server.
+        # TODO: do we need an internal timeout here,
+        # or do we rely on the external probe's timeout?
+        health_message = pickle.loads(await socket.recv())
+        socket.close()
+
+        if isinstance(health_message, Exception):
+            raise health_message
+
+        if health_message != VLLM_RPC_HEALTHY_STR:
+            raise ValueError("Expected healthy response from backend but got "
+                             f"{health_message}")
diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py
index 6385eaa1b226..7e936b48b030 100644
--- a/vllm/entrypoints/openai/rpc/server.py
+++ b/vllm/entrypoints/openai/rpc/server.py
@@ -8,7 +8,8 @@
 from typing_extensions import Never
 
 from vllm import AsyncEngineArgs, AsyncLLMEngine
-from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
+from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR,
+                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                          RPCGenerateRequest, RPCUtilityRequest)
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
@@ -96,6 +97,17 @@ async def generate(self, identity, generate_request: RPCGenerateRequest):
             self.socket.send_multipart(
                 [identity, pickle.dumps(e, pickle.HIGHEST_PROTOCOL)])
 
+    async def check_health(self, identity):
+        try:
+            await self.engine.check_health()
+            await self.socket.send_multipart([
+                identity,
+                pickle.dumps(VLLM_RPC_HEALTHY_STR, pickle.HIGHEST_PROTOCOL)
+            ])
+        except Exception as e:
+            await self.socket.send_multipart(
+                [identity, pickle.dumps(e, pickle.HIGHEST_PROTOCOL)])
+
     def _make_handler_coro(self, identity,
                            message) -> Coroutine[Any, Any, Never]:
         """Route the zmq message to the handler coroutine."""
@@ -115,6 +127,8 @@ def _make_handler_coro(self, identity,
                 return self.do_log_stats(identity)
             elif request == RPCUtilityRequest.IS_SERVER_READY:
                 return self.is_server_ready(identity)
+            elif request == RPCUtilityRequest.CHECK_HEALTH:
+                return self.check_health(identity)
             else:
                 raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
 
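
Usage note (not part of the diff above): check_health() follows a raise-on-unhealthy contract, so an HTTP liveness probe only needs to map "no exception" to a 200 and any exception to a 500. The sketch below shows that wiring under stated assumptions: the FastAPI app, the /health route, and the stub client are illustrative only; a real deployment would pass in the RPC client from vllm/entrypoints/openai/rpc/client.py (or any object implementing the protocol's check_health) instead of the stub.

from fastapi import FastAPI, Response

app = FastAPI()


class _StubEngineClient:
    """Stand-in for the RPC client; per protocol.py it only needs to expose
    `async def check_health(self) -> None` and raise if the engine is unhealthy."""

    async def check_health(self) -> None:
        return  # a real client raises when the backend reports a failure


engine_client = _StubEngineClient()


@app.get("/health")
async def health() -> Response:
    # An exception from check_health() is the only unhealthy signal,
    # so translate "no exception" into 200 and any exception into 500.
    try:
        await engine_client.check_health()
        return Response(status_code=200)
    except Exception:
        return Response(status_code=500)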