[BugFix] Fix frontend multiprocessing hang (vllm-project#7217)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
2 people authored and kylesayrs committed Aug 17, 2024
1 parent c13cdd5 commit df8378a
Showing 3 changed files with 67 additions and 5 deletions.
35 changes: 35 additions & 0 deletions tests/entrypoints/openai/test_mp_crash.py
@@ -0,0 +1,35 @@
+from typing import Any
+
+import pytest
+
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.api_server import build_async_engine_client
+from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.utils import FlexibleArgumentParser
+
+
+def crashing_from_engine_args(
+        cls,
+        engine_args: Any = None,
+        start_engine_loop: Any = None,
+        usage_context: Any = None,
+        stat_loggers: Any = None,
+) -> "AsyncLLMEngine":
+    raise Exception("foo")
+
+
+@pytest.mark.asyncio
+async def test_mp_crash_detection(monkeypatch):
+
+    with pytest.raises(RuntimeError) as excinfo, monkeypatch.context() as m:
+        m.setattr(AsyncLLMEngine, "from_engine_args",
+                  crashing_from_engine_args)
+        parser = FlexibleArgumentParser(
+            description="vLLM's remote OpenAI server.")
+        parser = make_arg_parser(parser)
+        args = parser.parse_args([])
+
+        async with build_async_engine_client(args):
+            pass
+    assert "The server process died before responding to the readiness probe"\
+        in str(excinfo.value)
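
Because the RuntimeError raised in api_server.py (next file) uses "raise ... from e", the original readiness-probe TimeoutError survives as the chained cause. A hypothetical extra assertion, not part of the commit, that the test could make:

    # Hypothetical follow-up check (not in the commit): the RuntimeError is
    # chained "from" the TimeoutError raised by the readiness probe.
    assert isinstance(excinfo.value.__cause__, TimeoutError)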
11 changes: 10 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -120,9 +120,18 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:

         # Build RPCClient, which conforms to AsyncEngineClient Protocol.
         async_engine_client = AsyncEngineRPCClient(rpc_path)
-        await async_engine_client.setup()
 
         try:
+            while True:
+                try:
+                    await async_engine_client.setup()
+                    break
+                except TimeoutError as e:
+                    if not rpc_server_process.is_alive():
+                        raise RuntimeError(
+                            "The server process died before "
+                            "responding to the readiness probe") from e
+
             yield async_engine_client
         finally:
             # Ensure rpc server process was terminated
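
The change above replaces a single blocking setup() call with a bounded retry loop: a TimeoutError is tolerated while the server process is still alive, and converted into a hard RuntimeError once it is not. A minimal standalone sketch of the pattern (illustrative only, not vLLM code; setup and process are hypothetical stand-ins):

    import multiprocessing


    def wait_until_ready(setup, process: multiprocessing.Process) -> None:
        """Retry setup() until it succeeds, failing fast if the child died."""
        while True:
            try:
                setup()  # expected to raise TimeoutError while the peer is down
                return
            except TimeoutError as e:
                # A timeout is only fatal once the server process is gone;
                # otherwise the server may just be slow to start, so retry.
                if not process.is_alive():
                    raise RuntimeError(
                        "The server process died before "
                        "responding to the readiness probe") from e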
26 changes: 22 additions & 4 deletions vllm/entrypoints/openai/rpc/client.py
@@ -18,6 +18,9 @@
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 
+# Time to wait before checking if the server process is alive.
+SERVER_START_TIMEOUT_MS = 1000
+
 
 class AsyncEngineRPCClient:
 
@@ -61,7 +64,16 @@ def socket(self):
             socket.connect(self.rpc_path)
             yield socket
         finally:
-            socket.close()
+            # linger == 0 means discard unsent messages
+            # when the socket is closed. This is necessary
+            # because otherwise self.context.destroy() will
+            # wait for 30 seconds until unsent messages are
+            # received, which is impossible if the server
+            # crashed. In the absence of a server crash we
+            # always expect a response before closing the
+            # socket anyway.
+            # Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24
+            socket.close(linger=0)
 
     async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                          expected_type: Any,
@@ -85,14 +97,19 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,

             return data
 
-    async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
-                                        error_message: str):
+    async def _send_one_way_rpc_request(self,
+                                        request: RPC_REQUEST_TYPE,
+                                        error_message: str,
+                                        timeout: Optional[int] = None):
         """Send one-way RPC request to trigger an action."""
         with self.socket() as socket:
             # Ping RPC Server with request.
             await socket.send(cloudpickle.dumps(request))
 
             # Await acknowledgement from RPCServer.
+            if timeout is not None and await socket.poll(timeout=timeout) == 0:
+                raise TimeoutError(f"server didn't reply within {timeout} ms")
+
             response = cloudpickle.loads(await socket.recv())
 
             if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
@@ -117,7 +134,8 @@ async def wait_for_server(self):

         await self._send_one_way_rpc_request(
             request=RPCUtilityRequest.IS_SERVER_READY,
-            error_message="Unable to start RPC Server.")
+            error_message="Unable to start RPC Server.",
+            timeout=SERVER_START_TIMEOUT_MS)
 
     async def _get_model_config_rpc(self) -> ModelConfig:
         """Get the ModelConfig object from the RPC Server"""
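
Both client-side changes address how ZeroMQ behaves when the peer never answers: recv() blocks forever, and close()/context.destroy() can wait on unsent messages. A rough self-contained sketch of the poll-then-recv plus linger=0 pattern (assumes pyzmq; the address and payload are made up for illustration):

    import zmq


    def probe(rpc_path: str, timeout_ms: int = 1000) -> bytes:
        """Send one request and wait at most timeout_ms for a reply."""
        context = zmq.Context()
        socket = context.socket(zmq.DEALER)
        try:
            socket.connect(rpc_path)
            socket.send(b"is_server_ready")
            # poll() returns 0 on timeout, so we never block forever in recv().
            if socket.poll(timeout=timeout_ms) == 0:
                raise TimeoutError(f"server didn't reply within {timeout_ms} ms")
            return socket.recv()
        finally:
            # linger=0 discards unsent messages so context.term() cannot hang
            # waiting on a peer that may have crashed.
            socket.close(linger=0)
            context.term()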
