From cd230c6f80f779dc4469adfc8d7a5edcc55451c1 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 18 Sep 2024 15:22:01 -0700 Subject: [PATCH] [Frontend] Use MQLLMEngine for embeddings models too --- vllm/engine/multiprocessing/__init__.py | 7 +- vllm/engine/multiprocessing/client.py | 106 +++++++++++++++++------- vllm/engine/multiprocessing/engine.py | 23 ++--- 3 files changed, 90 insertions(+), 46 deletions(-) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index ba5c6e15fc82..700332864d17 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -2,6 +2,7 @@ from enum import Enum from typing import List, Mapping, Optional, Union +from vllm import PoolingParams from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput @@ -21,9 +22,9 @@ class MQEngineDeadError(RuntimeError): @dataclass -class RPCGenerateRequest: +class RPCProcessRequest: inputs: PromptInputs - sampling_params: SamplingParams + params: Union[SamplingParams, PoolingParams] request_id: str lora_request: Optional[LoRARequest] = None trace_headers: Optional[Mapping[str, str]] = None @@ -55,7 +56,7 @@ class RPCStartupResponse: tracing_enabled: bool -RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest, +RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest, RPCStartupRequest] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 2cb4de79131f..aa9dbbd448af 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -11,6 +11,7 @@ from zmq import Frame # type: ignore[attr-defined] from zmq.asyncio import Socket +from vllm import PoolingParams from vllm.config import DecodingConfig, EngineConfig, ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs # yapf conflicts with isort for this block @@ -19,8 +20,8 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCGenerateRequest, - RPCHealthRequest, RPCStartupRequest, + RPCError, RPCHealthRequest, + RPCProcessRequest, RPCStartupRequest, RPCStartupResponse) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT @@ -111,20 +112,8 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig): @staticmethod def is_unsupported_config(engine_args: AsyncEngineArgs): - if engine_args.pipeline_parallel_size > 1: - return True - - is_embedding = ModelConfig( - model=engine_args.model, - revision=engine_args.revision, - tokenizer=engine_args.model, - tokenizer_mode="auto", - trust_remote_code=engine_args.trust_remote_code, - quantization=engine_args.quantization, - seed=0, - dtype="auto").embedding_mode - - return is_embedding + # Pipeline parallel not yet supported + return engine_args.pipeline_parallel_size > 1 @contextmanager def get_data_socket(self) -> Iterator[Socket]: @@ -382,12 +371,9 @@ def errored(self) -> bool: @property def dead_error(self) -> BaseException: - if self._errored_with is not None: - return ENGINE_DEAD_ERROR(self._errored_with) - else: - return ENGINE_DEAD_ERROR() + return ENGINE_DEAD_ERROR(self._errored_with) - async def generate( + def generate( self, inputs: PromptInputs, sampling_params: SamplingParams, @@ -396,6 +382,67 @@ async def generate( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> 
AsyncGenerator[RequestOutput, None]:
+        """Generate outputs for a request.
+
+        This method adds the request into the waiting queue of the LLMEngine
+        and streams the outputs from the LLMEngine back to the caller as they
+        are produced.
+
+        Args:
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
+            sampling_params: The sampling parameters of the request.
+            request_id: The unique id of the request.
+            lora_request: LoRA request to use for generation, if any.
+            trace_headers: OpenTelemetry trace headers.
+            prompt_adapter_request: Prompt Adapter request to use
+                for generation, if any.
+        """
+        return self._process_request(inputs, sampling_params, request_id,
+                                     lora_request, trace_headers,
+                                     prompt_adapter_request)
+
+    def encode(
+        self,
+        inputs: PromptInputs,
+        pooling_params: PoolingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+        """Generate outputs for a request from an embedding model.
+
+        This method adds the request into the waiting queue of the LLMEngine
+        and streams the outputs from the LLMEngine back to the caller as they
+        are produced.
+
+        Args:
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
+            pooling_params: The pooling parameters of the request.
+            request_id: The unique id of the request.
+            lora_request: LoRA request to use for this request, if any.
+            trace_headers: OpenTelemetry trace headers.
+
+        Yields:
+            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            for the request.
+        """
+        return self._process_request(inputs, pooling_params, request_id,
+                                     lora_request, trace_headers)
+
+    async def _process_request(
+        self,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+    ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
+            EmbeddingRequestOutput, None]]:
+        """Send an RPCProcessRequest to the RPCServer and stream responses."""

         # If already dead, error out.
@@ -410,19 +457,19 @@ async def generate(

         try:
             # 2) Detach logits processors so that they can be pickled
             # separately (may require cloudpickle which is slower)
-            if sampling_params.logits_processors:
+            if isinstance(params, SamplingParams) and params.logits_processors:
                 # Defensive shallow copy
-                sampling_params = copy.copy(sampling_params)
-                logits_processors = sampling_params.logits_processors
-                sampling_params.logits_processors = None
+                params = copy.copy(params)
+                logits_processors = params.logits_processors
+                params.logits_processors = None
                 lp_bytes = cloudpickle.dumps(logits_processors)
             else:
                 lp_bytes = None

             request_bytes = pickle.dumps(
-                RPCGenerateRequest(
+                RPCProcessRequest(
                     inputs=inputs,
-                    sampling_params=sampling_params,
+                    params=params,
                     request_id=request_id,
                     lora_request=lora_request,
                     trace_headers=trace_headers,
@@ -452,8 +499,3 @@ async def generate(
                 await self.abort(request_id)
             finally:
                 self.output_queues.pop(request_id)
-
-    async def encode(self, *args,
-                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
-        raise NotImplementedError(
-            "Embeddings not supported with multiprocessing backend")
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index 70cd6e5cb600..f4ca23157085 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -6,7 +6,7 @@
 import cloudpickle
 import zmq

-from vllm import AsyncEngineArgs, LLMEngine
+from vllm import AsyncEngineArgs, LLMEngine, SamplingParams
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 # yapf conflicts with isort for this block
@@ -15,8 +15,8 @@
                                          IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                          IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
                                          VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
-                                         RPCError, RPCGenerateRequest,
-                                         RPCHealthRequest, RPCStartupRequest,
+                                         RPCError, RPCHealthRequest,
+                                         RPCProcessRequest, RPCStartupRequest,
                                          RPCStartupResponse)
 # yapf: enable
 from vllm.logger import init_logger
@@ -39,8 +39,8 @@ class MQLLMEngine:
     in a concurrent manner. It runs a background loop and uses zeromq
     to receive new requests and stream outputs incrementally via ipc.

-    The :class:`LLMEngine.generate` is kicked off when a new
-    RPCGenerateRequest is received by the input_socket.
+    The :class:`LLMEngine` generate or encode process is kicked off when a new
+    RPCProcessRequest is received by the input_socket.
The self.engine_loop checks the input_socket for new requests, adds them to the LLMEngine if there are any, calls the internal @@ -213,12 +213,13 @@ def handle_new_input(self): frames = self.input_socket.recv_multipart(copy=False) request = pickle.loads(frames[0].buffer) - if isinstance(request, RPCGenerateRequest): + if isinstance(request, RPCProcessRequest): if len(frames) > 1: # Use cloudpickle for logits processors + assert isinstance(request.params, SamplingParams) lprocs = cloudpickle.loads(frames[1].buffer) - request.sampling_params.logits_processors = lprocs - self._handle_generate_request(request) + request.params.logits_processors = lprocs + self._handle_process_request(request) elif isinstance(request, RPCAbortRequest): self._handle_abort_request(request) elif isinstance(request, RPCHealthRequest): @@ -231,8 +232,8 @@ def handle_new_input(self): self._send_unhealthy(e) raise e - def _handle_generate_request(self, request: RPCGenerateRequest): - """Handle RPCGenerateRequest by adding it to the LLMEngine.""" + def _handle_process_request(self, request: RPCProcessRequest): + """Handle RPCProcessRequest by adding it to the LLMEngine.""" request_id = request.request_id if self._errored_with is not None: @@ -245,7 +246,7 @@ def _handle_generate_request(self, request: RPCGenerateRequest): self.engine.add_request( request_id=request_id, inputs=request.inputs, - params=request.sampling_params, + params=request.params, lora_request=request.lora_request, trace_headers=request.trace_headers, prompt_adapter_request=request.prompt_adapter_request)
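
Reviewer note: below is a minimal sketch of how the unified request path can be
exercised once this patch lands. The `demo` coroutine and the way the client is
obtained are hypothetical (construction of the `MQLLMEngineClient` and startup
of the `MQLLMEngine` process are omitted); the `generate`/`encode` signatures
match the ones added above. Both calls funnel into `_process_request`, which
pickles a single `RPCProcessRequest`; the engine side distinguishes sampling
from pooling work by the runtime type of `params`.

    from vllm import PoolingParams, SamplingParams


    async def demo(client):
        # client: an already-connected MQLLMEngineClient (setup omitted here).

        # Text generation: params is a SamplingParams, so the client may also
        # detach any logits processors and cloudpickle them separately.
        async for out in client.generate(
                inputs="Hello, my name is",
                sampling_params=SamplingParams(max_tokens=16),
                request_id="gen-0"):
            print(out.outputs[0].text)

        # Embeddings: params is a PoolingParams carried by the same
        # RPCProcessRequest; encode() no longer raises NotImplementedError
        # with the multiprocessing backend.
        async for out in client.encode(
                inputs="Hello, my name is",
                pooling_params=PoolingParams(),
                request_id="emb-0"):
            print(len(out.outputs.embedding))

Collapsing both request types into RPCProcessRequest keeps the IPC protocol to
one inference message, so handle_new_input needs a single isinstance branch
rather than parallel generate/encode code paths.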