diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py
index a472e12e8ca4..c457555c54b9 100644
--- a/vllm/entrypoints/openai/rpc/client.py
+++ b/vllm/entrypoints/openai/rpc/client.py
@@ -1,11 +1,13 @@
 import asyncio
+import pickle
 from contextlib import contextmanager, suppress
-from typing import Any, AsyncGenerator, Mapping, Optional
+from typing import Any, AsyncGenerator, Iterator, Mapping, Optional
 from uuid import uuid4
 
 import cloudpickle
 import zmq
 import zmq.asyncio
+from zmq.asyncio import Socket
 
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
@@ -115,18 +117,21 @@ def __init__(self, rpc_path: str):
         self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)
 
         # IPC connection to RPC Server (uses unix sockets).
-        self.to_rpc_server = self.context.socket(zmq.constants.DEALER)
+        self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER)
         self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
         self.to_rpc_server.bind(rpc_path)
 
         # In process proxy to RPC Server (uses memory-based messaging).
-        self.from_api_server = self.context.socket(zmq.constants.ROUTER)
+        self.from_api_server: Socket = self.context.socket(
+            zmq.constants.ROUTER)
         self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
         self.from_api_server.bind(INPROC_PROXY_PATH)
 
         # Asyncio background task for the proxy.
-        self.proxy_task = asyncio.create_task(
+        self.proxy_in_task = asyncio.create_task(
             self.run_proxy(self.from_api_server, self.to_rpc_server))
+        self.proxy_out_task = asyncio.create_task(
+            self.run_proxy(self.to_rpc_server, self.from_api_server))
 
         # Since we open 1 inproc socket per request, we have a hard cap on
         # the number of requests that can run in vLLM w. frontend
@@ -136,20 +141,11 @@ def __init__(self, rpc_path: str):
         # 1 for generate(), 1 for abort(), do_log_stats(), check_health()
         self.limit_concurrency = socket_limit // 2 - 2
 
-    async def run_proxy(self, socket_from, socket_to):
+    async def run_proxy(self, socket_from: Socket, socket_to: Socket):
         """Background task that runs a proxy"""
-        poller = zmq.asyncio.Poller()
-        poller.register(socket_from, zmq.constants.POLLIN)
-        poller.register(socket_to, zmq.constants.POLLIN)
         while True:
-            events_lst = await poller.poll()
-            events = dict(events_lst)
-            if socket_from in events:
-                identity, msg = await socket_from.recv_multipart()
-                await socket_to.send_multipart([identity, msg])
-            if socket_to in events:
-                identity, msg = await socket_to.recv_multipart()
-                await socket_from.send_multipart([identity, msg])
+            frames = await socket_from.recv_multipart(copy=False)
+            await socket_to.send_multipart(frames, copy=False)
 
     async def setup(self):
         """Setup the client before it starts sending server requests."""
@@ -180,7 +176,7 @@ def close(self):
         self.context.destroy()
 
     @contextmanager
-    def to_proxy_socket(self):
+    def to_proxy_socket(self) -> Iterator[Socket]:
         # Connect to the RPCServer via the proxy.
 
         # Raise a sensible error if the client was already closed.
@@ -208,7 +204,8 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
 
         with self.to_proxy_socket() as socket:
             # Ping RPCServer with a request.
-            await socket.send_multipart([cloudpickle.dumps(request)])
+            await socket.send_multipart((cloudpickle.dumps(request), ),
+                                        copy=False)
 
             # Make sure the server responds
             if await socket.poll(timeout=self._data_timeout) == 0:
@@ -216,7 +213,8 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                    f"{self._data_timeout} ms")
 
             # Await the data from the Server.
-            data = cloudpickle.loads(await socket.recv())
+            frame = await socket.recv(copy=False)
+            data = pickle.loads(frame.buffer)
 
             if isinstance(data, Exception):
                 # Re-raise exceptions returned by the server
@@ -234,23 +232,22 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
 
             return data
 
-    async def _send_one_way_rpc_request(
-            self,
-            request: RPC_REQUEST_TYPE,
-            error_message: str,
-            socket: Optional[zmq.asyncio.Socket] = None):
+    async def _send_one_way_rpc_request(self,
+                                        request: RPC_REQUEST_TYPE,
+                                        error_message: str,
+                                        socket: Optional[Socket] = None):
         """Send one-way RPC request to trigger an action."""
 
-        async def do_rpc_call(socket: zmq.asyncio.Socket,
-                              request: RPC_REQUEST_TYPE):
+        async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):
 
-            await socket.send_multipart([cloudpickle.dumps(request)])
+            await socket.send_multipart((cloudpickle.dumps(request), ))
 
             if await socket.poll(timeout=self._data_timeout) == 0:
                 raise TimeoutError("Server didn't reply within "
                                    f"{self._data_timeout} ms")
 
-            return cloudpickle.loads(await socket.recv())
+            frame = await socket.recv(copy=False)
+            return pickle.loads(frame.buffer)
 
         # Make a new socket connection.
         if socket is None:
@@ -386,21 +383,19 @@ async def generate(
         try:
             with self.to_proxy_socket() as socket:
                 # Send RPCGenerateRequest to the RPCServer.
-                await socket.send_multipart([
-                    cloudpickle.dumps(
-                        RPCGenerateRequest(
-                            inputs=inputs,
-                            sampling_params=sampling_params,
-                            request_id=request_id,
-                            lora_request=lora_request,
-                            trace_headers=trace_headers,
-                            prompt_adapter_request=prompt_adapter_request))
-                ])
+                await socket.send_multipart((cloudpickle.dumps(
+                    RPCGenerateRequest(
+                        inputs=inputs,
+                        sampling_params=sampling_params,
+                        request_id=request_id,
+                        lora_request=lora_request,
+                        trace_headers=trace_headers,
+                        prompt_adapter_request=prompt_adapter_request)), ))
 
                 # Stream back the results from the RPC Server.
                 while not finished:
-                    message = await socket.recv()
-                    request_output = cloudpickle.loads(message)
+                    message = await socket.recv(copy=False)
+                    request_output = pickle.loads(message.buffer)
 
                     if isinstance(request_output, Exception):
                         # On exception, check if the server is still healthy
@@ -424,9 +419,7 @@ async def generate(
             if not finished and not self._errored:
                 await self.abort(request_id)
 
-    async def check_health(self,
-                           socket: Optional[zmq.asyncio.Socket] = None
-                           ) -> None:
+    async def check_health(self, socket: Optional[Socket] = None) -> None:
         """Raise if unhealthy"""
 
         await self._send_one_way_rpc_request(
@@ -451,4 +444,4 @@ async def stop_profile(self) -> None:
 
         await self._send_one_way_rpc_request(
             request=RPCUtilityRequest.STOP_PROFILE,
-            error_message="RPCRequest STOP_PROFILE failed.")
\ No newline at end of file
+            error_message="RPCRequest STOP_PROFILE failed.")
diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py
index 738d12bbef05..bebc2faedb68 100644
--- a/vllm/entrypoints/openai/rpc/server.py
+++ b/vllm/entrypoints/openai/rpc/server.py
@@ -1,4 +1,5 @@
 import asyncio
+import pickle
 import signal
 from typing import Any, Coroutine, Union
 
@@ -7,6 +8,8 @@
 import zmq
 import zmq.asyncio
 from typing_extensions import Never
+from zmq import Frame  # type: ignore[attr-defined]
+from zmq.asyncio import Socket
 
 from vllm import AsyncEngineArgs, AsyncLLMEngine
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
@@ -35,7 +38,7 @@ def __init__(self, async_engine_args: AsyncEngineArgs,
         self.context = zmq.asyncio.Context()
 
         # Init socket.
-        self.socket = self.context.socket(zmq.constants.DEALER)
+        self.socket: Socket = self.context.socket(zmq.constants.DEALER)
         self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
         self.socket.connect(rpc_path)
 
@@ -63,30 +66,31 @@ async def get_config(self, identity, request):
             else:
                 raise ValueError("Unknown Config Request: %s", request)
 
-            await self.socket.send_multipart(
-                [identity, cloudpickle.dumps(config)])
+            await self.socket.send_multipart((identity, pickle.dumps(config)),
+                                             copy=False)
 
         except Exception as e:
-            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
+            await self.socket.send_multipart((identity, pickle.dumps(e)),
+                                             copy=False)
 
     async def is_tracing_enabled(self, identity):
         """Send the is_tracing_enabled flag"""
         tracing_flag = await self.engine.is_tracing_enabled()
 
         await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(tracing_flag)])
+            (identity, pickle.dumps(tracing_flag)))
 
     async def do_log_stats(self, identity):
         """Log stats and confirm success."""
         await self.engine.do_log_stats()
 
         await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
+            (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
 
     async def is_server_ready(self, identity):
         """Notify the client that we are ready."""
         await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
+            (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
 
     async def abort(self, identity, request: RPCAbortRequest):
         """Abort request and notify the client of success."""
@@ -96,7 +100,7 @@ async def abort(self, identity, request: RPCAbortRequest):
             result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
         except Exception as e:
             result = e
-        await self.socket.send_multipart([identity, cloudpickle.dumps(result)])
+        await self.socket.send_multipart((identity, pickle.dumps(result)))
 
     async def generate(self, identity, generate_request: RPCGenerateRequest):
         try:
@@ -110,45 +114,47 @@ async def generate(self, identity, generate_request: RPCGenerateRequest):
 
             async for request_output in results_generator:
                 await self.socket.send_multipart(
-                    [identity, cloudpickle.dumps(request_output)])
+                    (identity, pickle.dumps(request_output)), copy=False)
 
         except Exception as e:
-            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
+            await self.socket.send_multipart((identity, pickle.dumps(e)),
+                                             copy=False)
 
     async def check_health(self, identity):
         try:
             await self.engine.check_health()
             await self.socket.send_multipart(
-                [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
+                (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
 
         except Exception as e:
-            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
+            await self.socket.send_multipart((identity, pickle.dumps(e)),
+                                             copy=False)
 
     async def start_profile(self, identity):
         logger.info("Starting profiler...")
         await self.engine.start_profile()
         logger.info("Profiler started.")
 
-        await self.socket.send_multipart([
+        await self.socket.send_multipart((
             identity,
-            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ])
+            pickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ))
 
     async def stop_profile(self, identity):
         logger.info("Stopping profiler...")
         await self.engine.stop_profile()
         logger.info("Profiler stopped.")
 
-        await self.socket.send_multipart([
+        await self.socket.send_multipart((
             identity,
-            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ])
+            pickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ))
 
     def _make_handler_coro(self, identity,
-                           message) -> Coroutine[Any, Any, Never]:
+                           message: Frame) -> Coroutine[Any, Any, Never]:
         """Route the zmq message to the handler coroutine."""
 
-        request = cloudpickle.loads(message)
+        request = cloudpickle.loads(message.buffer)
 
         if isinstance(request, RPCGenerateRequest):
             return self.generate(identity, request)
@@ -189,7 +195,7 @@ async def run_server_loop(self):
         running_tasks = set()
         while True:
             # Wait for a request.
-            identity, message = await self.socket.recv_multipart()
+            identity, message = await self.socket.recv_multipart(copy=False)
 
             # Process the request async.
             task = asyncio.create_task(
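Note: the sketch below is not part of the patch. It is a minimal standalone illustration of the zero-copy pattern the diff adopts on both ends: pickle.dumps plus send_multipart(copy=False) on the sending side, and recv_multipart(copy=False) plus pickle.loads(frame.buffer) on the receiving side. It uses blocking pyzmq sockets and a made-up inproc endpoint purely for demonstration; the patched client and server use zmq.asyncio sockets in the same way.

import pickle

import zmq

# Illustrative only: endpoint name and payload are not from the patch.
ctx = zmq.Context()
sender = ctx.socket(zmq.PAIR)
receiver = ctx.socket(zmq.PAIR)
sender.bind("inproc://zero-copy-demo")
receiver.connect("inproc://zero-copy-demo")

# Sending side: serialize once with pickle; copy=False lets pyzmq hand the
# serialized buffer to its IO machinery without an extra memcpy.
payload = {"request_id": "demo", "text": "hello"}
sender.send_multipart((pickle.dumps(payload), ), copy=False)

# Receiving side: copy=False returns zmq.Frame objects instead of bytes.
# Frame.buffer is a memoryview, which pickle.loads accepts directly, so the
# message body is deserialized without being copied into a new bytes object.
frames = receiver.recv_multipart(copy=False)
obj = pickle.loads(frames[0].buffer)
assert obj == payload

sender.close()
receiver.close()
ctx.term()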