Skip to content

Commit

Permalink
[Frontend] Minor optimizations to zmq decoupled front-end (vllm-project#7957)
Browse files Browse the repository at this point in the history

Co-authored-by: Robert Shaw <rshaw@neuralmagic>
  • Loading branch information
2 people authored and Jeffwan committed Sep 19, 2024
1 parent e55c56a commit 31d3f9e
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 65 deletions.
81 changes: 37 additions & 44 deletions vllm/entrypoints/openai/rpc/client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import asyncio
import pickle
from contextlib import contextmanager, suppress
from typing import Any, AsyncGenerator, Mapping, Optional
from typing import Any, AsyncGenerator, Iterator, Mapping, Optional
from uuid import uuid4

import cloudpickle
import zmq
import zmq.asyncio
from zmq.asyncio import Socket

from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
Expand Down Expand Up @@ -115,18 +117,21 @@ def __init__(self, rpc_path: str):
self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)

# IPC connection to RPC Server (uses unix sockets).
self.to_rpc_server = self.context.socket(zmq.constants.DEALER)
self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER)
self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
self.to_rpc_server.bind(rpc_path)

# In process proxy to RPC Server (uses memory-based messaging).
self.from_api_server = self.context.socket(zmq.constants.ROUTER)
self.from_api_server: Socket = self.context.socket(
zmq.constants.ROUTER)
self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
self.from_api_server.bind(INPROC_PROXY_PATH)

# Asyncio background task for the proxy.
self.proxy_task = asyncio.create_task(
self.proxy_in_task = asyncio.create_task(
self.run_proxy(self.from_api_server, self.to_rpc_server))
self.proxy_out_task = asyncio.create_task(
self.run_proxy(self.to_rpc_server, self.from_api_server))

# Since we open 1 inproc socket per request, we have a hard cap on
# the number of requests that can run in vLLM w. frontend
Expand All @@ -136,20 +141,11 @@ def __init__(self, rpc_path: str):
# 1 for generate(), 1 for abort(), do_log_stats(), check_health()
self.limit_concurrency = socket_limit // 2 - 2

async def run_proxy(self, socket_from, socket_to):
async def run_proxy(self, socket_from: Socket, socket_to: Socket):
"""Background task that runs a proxy"""
poller = zmq.asyncio.Poller()
poller.register(socket_from, zmq.constants.POLLIN)
poller.register(socket_to, zmq.constants.POLLIN)
while True:
events_lst = await poller.poll()
events = dict(events_lst)
if socket_from in events:
identity, msg = await socket_from.recv_multipart()
await socket_to.send_multipart([identity, msg])
if socket_to in events:
identity, msg = await socket_to.recv_multipart()
await socket_from.send_multipart([identity, msg])
frames = await socket_from.recv_multipart(copy=False)
await socket_to.send_multipart(frames, copy=False)

async def setup(self):
"""Setup the client before it starts sending server requests."""
Expand Down Expand Up @@ -180,7 +176,7 @@ def close(self):
self.context.destroy()

@contextmanager
def to_proxy_socket(self):
def to_proxy_socket(self) -> Iterator[Socket]:
# Connect to the RPCServer via the proxy.

# Raise a sensible error if the client was already closed.
Expand Down Expand Up @@ -208,15 +204,17 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,

with self.to_proxy_socket() as socket:
# Ping RPCServer with a request.
await socket.send_multipart([cloudpickle.dumps(request)])
await socket.send_multipart((cloudpickle.dumps(request), ),
copy=False)

# Make sure the server responds
if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")

# Await the data from the Server.
data = cloudpickle.loads(await socket.recv())
frame = await socket.recv(copy=False)
data = pickle.loads(frame.buffer)

if isinstance(data, Exception):
# Re-raise exceptions returned by the server
Expand All @@ -234,23 +232,22 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,

return data

async def _send_one_way_rpc_request(
self,
request: RPC_REQUEST_TYPE,
error_message: str,
socket: Optional[zmq.asyncio.Socket] = None):
async def _send_one_way_rpc_request(self,
request: RPC_REQUEST_TYPE,
error_message: str,
socket: Optional[Socket] = None):
"""Send one-way RPC request to trigger an action."""

async def do_rpc_call(socket: zmq.asyncio.Socket,
request: RPC_REQUEST_TYPE):
async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):

await socket.send_multipart([cloudpickle.dumps(request)])
await socket.send_multipart((cloudpickle.dumps(request), ))

if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")

return cloudpickle.loads(await socket.recv())
frame = await socket.recv(copy=False)
return pickle.loads(frame.buffer)

# Make a new socket connection.
if socket is None:
Expand Down Expand Up @@ -386,21 +383,19 @@ async def generate(
try:
with self.to_proxy_socket() as socket:
# Send RPCGenerateRequest to the RPCServer.
await socket.send_multipart([
cloudpickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request))
])
await socket.send_multipart((cloudpickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)), ))

# Stream back the results from the RPC Server.
while not finished:
message = await socket.recv()
request_output = cloudpickle.loads(message)
message = await socket.recv(copy=False)
request_output = pickle.loads(message.buffer)

if isinstance(request_output, Exception):
# On exception, check if the server is still healthy
Expand All @@ -424,9 +419,7 @@ async def generate(
if not finished and not self._errored:
await self.abort(request_id)

async def check_health(self,
socket: Optional[zmq.asyncio.Socket] = None
) -> None:
async def check_health(self, socket: Optional[Socket] = None) -> None:
"""Raise if unhealthy"""

await self._send_one_way_rpc_request(
Expand All @@ -451,4 +444,4 @@ async def stop_profile(self) -> None:

await self._send_one_way_rpc_request(
request=RPCUtilityRequest.STOP_PROFILE,
error_message="RPCRequest STOP_PROFILE failed.")
error_message="RPCRequest STOP_PROFILE failed.")
48 changes: 27 additions & 21 deletions vllm/entrypoints/openai/rpc/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import pickle
import signal
from typing import Any, Coroutine, Union

Expand All @@ -7,6 +8,8 @@
import zmq
import zmq.asyncio
from typing_extensions import Never
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket

from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
Expand Down Expand Up @@ -35,7 +38,7 @@ def __init__(self, async_engine_args: AsyncEngineArgs,
self.context = zmq.asyncio.Context()

# Init socket.
self.socket = self.context.socket(zmq.constants.DEALER)
self.socket: Socket = self.context.socket(zmq.constants.DEALER)
self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
self.socket.connect(rpc_path)

Expand Down Expand Up @@ -63,30 +66,31 @@ async def get_config(self, identity, request):
else:
raise ValueError("Unknown Config Request: %s", request)

await self.socket.send_multipart(
[identity, cloudpickle.dumps(config)])
await self.socket.send_multipart((identity, pickle.dumps(config)),
copy=False)

except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)

async def is_tracing_enabled(self, identity):
"""Send the is_tracing_enabled flag"""
tracing_flag = await self.engine.is_tracing_enabled()

await self.socket.send_multipart(
[identity, cloudpickle.dumps(tracing_flag)])
(identity, pickle.dumps(tracing_flag)))

async def do_log_stats(self, identity):
"""Log stats and confirm success."""
await self.engine.do_log_stats()

await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))

async def is_server_ready(self, identity):
"""Notify the client that we are ready."""
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))

async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
Expand All @@ -96,7 +100,7 @@ async def abort(self, identity, request: RPCAbortRequest):
result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
except Exception as e:
result = e
await self.socket.send_multipart([identity, cloudpickle.dumps(result)])
await self.socket.send_multipart((identity, pickle.dumps(result)))

async def generate(self, identity, generate_request: RPCGenerateRequest):
try:
Expand All @@ -110,45 +114,47 @@ async def generate(self, identity, generate_request: RPCGenerateRequest):

async for request_output in results_generator:
await self.socket.send_multipart(
[identity, cloudpickle.dumps(request_output)])
(identity, pickle.dumps(request_output)), copy=False)

except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)

async def check_health(self, identity):
try:
await self.engine.check_health()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))

except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)

async def start_profile(self, identity):
logger.info("Starting profiler...")
await self.engine.start_profile()
logger.info("Profiler started.")

await self.socket.send_multipart([
await self.socket.send_multipart((
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
pickle.dumps(VLLM_RPC_SUCCESS_STR),
))

async def stop_profile(self, identity):
logger.info("Stopping profiler...")
await self.engine.stop_profile()
logger.info("Profiler stopped.")

await self.socket.send_multipart([
await self.socket.send_multipart((
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
pickle.dumps(VLLM_RPC_SUCCESS_STR),
))

def _make_handler_coro(self, identity,
message) -> Coroutine[Any, Any, Never]:
message: Frame) -> Coroutine[Any, Any, Never]:
"""Route the zmq message to the handler coroutine."""

request = cloudpickle.loads(message)
request = cloudpickle.loads(message.buffer)

if isinstance(request, RPCGenerateRequest):
return self.generate(identity, request)
Expand Down Expand Up @@ -189,7 +195,7 @@ async def run_server_loop(self):
running_tasks = set()
while True:
# Wait for a request.
identity, message = await self.socket.recv_multipart()
identity, message = await self.socket.recv_multipart(copy=False)

# Process the request async.
task = asyncio.create_task(
Expand Down

0 comments on commit 31d3f9e

Please sign in to comment.