Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Minor optimizations to zmq decoupled front-end #7957

Merged
merged 2 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 37 additions & 44 deletions vllm/entrypoints/openai/rpc/client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import asyncio
import pickle
from contextlib import contextmanager, suppress
from typing import Any, AsyncGenerator, Mapping, Optional
from typing import Any, AsyncGenerator, Iterator, Mapping, Optional
from uuid import uuid4

import cloudpickle
import zmq
import zmq.asyncio
from zmq.asyncio import Socket

from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
Expand Down Expand Up @@ -115,18 +117,21 @@ def __init__(self, rpc_path: str):
self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)

# IPC connection to RPC Server (uses unix sockets).
self.to_rpc_server = self.context.socket(zmq.constants.DEALER)
self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER)
self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
self.to_rpc_server.bind(rpc_path)

# In process proxy to RPC Server (uses memory-based messaging).
self.from_api_server = self.context.socket(zmq.constants.ROUTER)
self.from_api_server: Socket = self.context.socket(
zmq.constants.ROUTER)
self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
self.from_api_server.bind(INPROC_PROXY_PATH)

# Asyncio background task for the proxy.
self.proxy_task = asyncio.create_task(
self.proxy_in_task = asyncio.create_task(
self.run_proxy(self.from_api_server, self.to_rpc_server))
self.proxy_out_task = asyncio.create_task(
self.run_proxy(self.to_rpc_server, self.from_api_server))

# Since we open 1 inproc socket per request, we have a hard cap on
# the number of requests that can run in vLLM w. frontend
Expand All @@ -136,20 +141,11 @@ def __init__(self, rpc_path: str):
# 1 for generate(), 1 for abort(), do_log_stats(), check_health()
self.limit_concurrency = socket_limit // 2 - 2

async def run_proxy(self, socket_from, socket_to):
async def run_proxy(self, socket_from: Socket, socket_to: Socket):
"""Background task that runs a proxy"""
poller = zmq.asyncio.Poller()
poller.register(socket_from, zmq.constants.POLLIN)
poller.register(socket_to, zmq.constants.POLLIN)
while True:
events_lst = await poller.poll()
events = dict(events_lst)
if socket_from in events:
identity, msg = await socket_from.recv_multipart()
await socket_to.send_multipart([identity, msg])
if socket_to in events:
identity, msg = await socket_to.recv_multipart()
await socket_from.send_multipart([identity, msg])
frames = await socket_from.recv_multipart(copy=False)
await socket_to.send_multipart(frames, copy=False)

async def setup(self):
"""Setup the client before it starts sending server requests."""
Expand Down Expand Up @@ -180,7 +176,7 @@ def close(self):
self.context.destroy()

@contextmanager
def to_proxy_socket(self):
def to_proxy_socket(self) -> Iterator[Socket]:
# Connect to the RPCServer via the proxy.

# Raise a sensible error if the client was already closed.
Expand Down Expand Up @@ -208,15 +204,17 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,

with self.to_proxy_socket() as socket:
# Ping RPCServer with a request.
await socket.send_multipart([cloudpickle.dumps(request)])
await socket.send_multipart((cloudpickle.dumps(request), ),
copy=False)

# Make sure the server responds
if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")

# Await the data from the Server.
data = cloudpickle.loads(await socket.recv())
frame = await socket.recv(copy=False)
data = pickle.loads(frame.buffer)

if isinstance(data, Exception):
# Re-raise exceptions returned by the server
Expand All @@ -234,23 +232,22 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,

return data

async def _send_one_way_rpc_request(
self,
request: RPC_REQUEST_TYPE,
error_message: str,
socket: Optional[zmq.asyncio.Socket] = None):
async def _send_one_way_rpc_request(self,
request: RPC_REQUEST_TYPE,
error_message: str,
socket: Optional[Socket] = None):
"""Send one-way RPC request to trigger an action."""

async def do_rpc_call(socket: zmq.asyncio.Socket,
request: RPC_REQUEST_TYPE):
async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):

await socket.send_multipart([cloudpickle.dumps(request)])
await socket.send_multipart((cloudpickle.dumps(request), ))

if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")

return cloudpickle.loads(await socket.recv())
frame = await socket.recv(copy=False)
return pickle.loads(frame.buffer)

# Make a new socket connection.
if socket is None:
Expand Down Expand Up @@ -386,21 +383,19 @@ async def generate(
try:
with self.to_proxy_socket() as socket:
# Send RPCGenerateRequest to the RPCServer.
await socket.send_multipart([
cloudpickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request))
])
await socket.send_multipart((cloudpickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)), ))

# Stream back the results from the RPC Server.
while not finished:
message = await socket.recv()
request_output = cloudpickle.loads(message)
message = await socket.recv(copy=False)
request_output = pickle.loads(message.buffer)

if isinstance(request_output, Exception):
# On exception, check if the server is still healthy
Expand All @@ -424,9 +419,7 @@ async def generate(
if not finished and not self._errored:
await self.abort(request_id)

async def check_health(self,
socket: Optional[zmq.asyncio.Socket] = None
) -> None:
async def check_health(self, socket: Optional[Socket] = None) -> None:
"""Raise if unhealthy"""

await self._send_one_way_rpc_request(
Expand All @@ -451,4 +444,4 @@ async def stop_profile(self) -> None:

await self._send_one_way_rpc_request(
request=RPCUtilityRequest.STOP_PROFILE,
error_message="RPCRequest STOP_PROFILE failed.")
error_message="RPCRequest STOP_PROFILE failed.")
41 changes: 23 additions & 18 deletions vllm/entrypoints/openai/rpc/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import pickle
import signal
from typing import Any, Coroutine, Union

Expand All @@ -7,6 +8,7 @@
import zmq
import zmq.asyncio
from typing_extensions import Never
from zmq.asyncio import Socket

from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
Expand Down Expand Up @@ -35,7 +37,7 @@ def __init__(self, async_engine_args: AsyncEngineArgs,
self.context = zmq.asyncio.Context()

# Init socket.
self.socket = self.context.socket(zmq.constants.DEALER)
self.socket: Socket = self.context.socket(zmq.constants.DEALER)
self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
self.socket.connect(rpc_path)

Expand Down Expand Up @@ -63,30 +65,31 @@ async def get_config(self, identity, request):
else:
raise ValueError("Unknown Config Request: %s", request)

await self.socket.send_multipart(
[identity, cloudpickle.dumps(config)])
await self.socket.send_multipart((identity, pickle.dumps(config)),
copy=False)

except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)

async def is_tracing_enabled(self, identity):
"""Send the is_tracing_enabled flag"""
tracing_flag = await self.engine.is_tracing_enabled()

await self.socket.send_multipart(
[identity, cloudpickle.dumps(tracing_flag)])
(identity, pickle.dumps(tracing_flag)))

async def do_log_stats(self, identity):
"""Log stats and confirm success."""
await self.engine.do_log_stats()

await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))

async def is_server_ready(self, identity):
"""Notify the client that we are ready."""
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))

async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
Expand All @@ -96,7 +99,7 @@ async def abort(self, identity, request: RPCAbortRequest):
result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
except Exception as e:
result = e
await self.socket.send_multipart([identity, cloudpickle.dumps(result)])
await self.socket.send_multipart((identity, pickle.dumps(result)))

async def generate(self, identity, generate_request: RPCGenerateRequest):
try:
Expand All @@ -110,39 +113,41 @@ async def generate(self, identity, generate_request: RPCGenerateRequest):

async for request_output in results_generator:
await self.socket.send_multipart(
[identity, cloudpickle.dumps(request_output)])
(identity, pickle.dumps(request_output)), copy=False)

except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)

async def check_health(self, identity):
try:
await self.engine.check_health()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))

except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)

async def start_profile(self, identity):
logger.info("Starting profiler...")
await self.engine.start_profile()
logger.info("Profiler started.")

await self.socket.send_multipart([
await self.socket.send_multipart((
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
pickle.dumps(VLLM_RPC_SUCCESS_STR),
))

async def stop_profile(self, identity):
logger.info("Stopping profiler...")
await self.engine.stop_profile()
logger.info("Profiler stopped.")

await self.socket.send_multipart([
await self.socket.send_multipart((
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
pickle.dumps(VLLM_RPC_SUCCESS_STR),
))

def _make_handler_coro(self, identity,
message) -> Coroutine[Any, Any, Never]:
Expand Down
Loading