Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix frontend multiprocessing hang #7217

Merged
35 changes: 35 additions & 0 deletions tests/entrypoints/openai/test_mp_crash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Any

import pytest

from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser


def crashing_from_engine_args(
    cls,
    engine_args: Any = None,
    start_engine_loop: Any = None,
    usage_context: Any = None,
    stat_loggers: Any = None,
) -> "AsyncLLMEngine":
    """Stand-in for ``AsyncLLMEngine.from_engine_args`` that always fails.

    The crash-detection test in this file monkeypatches this function over
    ``AsyncLLMEngine.from_engine_args``. Every argument is accepted and
    ignored; the parameters exist only so the stub can be called the same
    way as the method it replaces.

    Raises:
        Exception: always, with the message ``"foo"``.
    """
    raise Exception("foo")


@pytest.mark.asyncio
async def test_mp_crash_detection(monkeypatch):
    """Check that a failing engine constructor surfaces as a RuntimeError.

    ``AsyncLLMEngine.from_engine_args`` is replaced with a stub that always
    raises, default CLI arguments are parsed, and
    ``build_async_engine_client`` is entered. The test passes when that
    raises ``RuntimeError`` whose message reports that the server process
    died before responding to the readiness probe.
    """
    with pytest.raises(RuntimeError) as excinfo, monkeypatch.context() as m:
        m.setattr(AsyncLLMEngine, "from_engine_args",
                  crashing_from_engine_args)
        parser = FlexibleArgumentParser(
            description="vLLM's remote OpenAI server.")
        parser = make_arg_parser(parser)
        args = parser.parse_args([])

        async with build_async_engine_client(args):
            pass

    # This check must sit outside the ``pytest.raises`` block: once the
    # expected exception is raised inside it, the rest of that block is
    # skipped, so an assertion placed there would never run.
    expected = ("The server process died before "
                "responding to the readiness probe")
    assert expected in str(excinfo.value)
11 changes: 10 additions & 1 deletion vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,18 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:

# Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client = AsyncEngineRPCClient(rpc_path)
await async_engine_client.setup()

try:
while True:
try:
await async_engine_client.setup()
break
except TimeoutError as e:
if not rpc_server_process.is_alive():
raise RuntimeError(
"The server process died before "
"responding to the readiness probe") from e

yield async_engine_client
finally:
# Ensure rpc server process was terminated
Expand Down
26 changes: 22 additions & 4 deletions vllm/entrypoints/openai/rpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs

# Time to wait before checking if the server process is alive.
SERVER_START_TIMEOUT_MS = 1000


class AsyncEngineRPCClient:

Expand Down Expand Up @@ -61,7 +64,16 @@ def socket(self):
socket.connect(self.rpc_path)
yield socket
finally:
socket.close()
# linger == 0 means discard unsent messages
# when the socket is closed. This is necessary
# because otherwise self.context.destroy() will
# wait for 30 seconds until unsent messages are
# received, which is impossible if the server
# crashed. In the absence of a server crash we
# always expect a response before closing the
# socket anyway.
# Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24
socket.close(linger=0)

async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
expected_type: Any,
Expand All @@ -85,14 +97,19 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,

return data

async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
error_message: str):
async def _send_one_way_rpc_request(self,
request: RPC_REQUEST_TYPE,
error_message: str,
timeout: Optional[int] = None):
"""Send one-way RPC request to trigger an action."""
with self.socket() as socket:
# Ping RPC Server with request.
await socket.send(cloudpickle.dumps(request))

# Await acknowledgement from RPCServer.
if timeout is not None and await socket.poll(timeout=timeout) == 0:
raise TimeoutError(f"server didn't reply within {timeout} ms")

response = cloudpickle.loads(await socket.recv())

if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
Expand All @@ -117,7 +134,8 @@ async def wait_for_server(self):

await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_READY,
error_message="Unable to start RPC Server.")
error_message="Unable to start RPC Server.",
timeout=SERVER_START_TIMEOUT_MS)

async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server"""
Expand Down
Loading