[Frontend] Use MQLLMEngine for embeddings models too (vllm-project#8584)
njhill authored and dtrifiro committed Sep 27, 2024
1 parent 48a4faf commit 38d96ec
Showing 3 changed files with 90 additions and 46 deletions.
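At a glance: the multiprocessing front-end previously rejected embedding models (`encode` raised `NotImplementedError` and `is_unsupported_config` screened them out); this commit routes embeddings through the same `MQLLMEngine` path as generation. As a rough sketch of what the change enables, assuming an already-started `MQLLMEngineClient` (its construction and IPC setup are outside this diff, and the function name here is illustrative):

```python
from vllm import PoolingParams


async def embed_one(client, text: str, request_id: str):
    """Minimal sketch: stream embeddings from an MQLLMEngineClient.

    `client` is assumed to be a connected MQLLMEngineClient; this helper
    is illustrative and not part of the commit.
    """
    final_output = None
    # encode() now mirrors generate(): it returns an async generator
    # that streams EmbeddingRequestOutput objects for the request.
    async for output in client.encode(inputs=text,
                                      pooling_params=PoolingParams(),
                                      request_id=request_id):
        final_output = output
    # The pooled vector lives on EmbeddingRequestOutput.outputs.
    return final_output.outputs.embedding
```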
7 changes: 4 additions & 3 deletions vllm/engine/multiprocessing/__init__.py
@@ -2,6 +2,7 @@
 from enum import Enum
 from typing import List, Mapping, Optional, Union
 
+from vllm import PoolingParams
 from vllm.inputs import PromptInputs
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
@@ -21,9 +22,9 @@ class MQEngineDeadError(RuntimeError):


 @dataclass
-class RPCGenerateRequest:
+class RPCProcessRequest:
     inputs: PromptInputs
-    sampling_params: SamplingParams
+    params: Union[SamplingParams, PoolingParams]
     request_id: str
     lora_request: Optional[LoRARequest] = None
     trace_headers: Optional[Mapping[str, str]] = None
@@ -55,7 +56,7 @@ class RPCStartupResponse:
     tracing_enabled: bool
 
 
-RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest,
+RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest,
                       RPCStartupRequest]
 
 REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
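To make the rename concrete, a hedged sketch of building the unified request for each mode (prompts and request ids are illustrative; in practice the client constructs these internally in `_process_request`):

```python
from vllm import PoolingParams, SamplingParams
from vllm.engine.multiprocessing import RPCProcessRequest

# One request type now covers both kinds of work:
generation_request = RPCProcessRequest(inputs="Hello, world",
                                       params=SamplingParams(max_tokens=16),
                                       request_id="gen-1")
embedding_request = RPCProcessRequest(inputs="Hello, world",
                                      params=PoolingParams(),
                                      request_id="emb-1")
```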
106 changes: 74 additions & 32 deletions vllm/engine/multiprocessing/client.py
@@ -11,6 +11,7 @@
 from zmq import Frame  # type: ignore[attr-defined]
 from zmq.asyncio import Socket
 
+from vllm import PoolingParams
 from vllm.config import DecodingConfig, EngineConfig, ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 # yapf conflicts with isort for this block
@@ -19,8 +20,8 @@
                                           IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                           IPC_OUTPUT_EXT, RPC_REQUEST_T,
                                           VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
-                                          RPCError, RPCGenerateRequest,
-                                          RPCHealthRequest, RPCStartupRequest,
+                                          RPCError, RPCHealthRequest,
+                                          RPCProcessRequest, RPCStartupRequest,
                                           RPCStartupResponse)
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
@@ -111,20 +112,8 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig):

     @staticmethod
     def is_unsupported_config(engine_args: AsyncEngineArgs):
-        if engine_args.pipeline_parallel_size > 1:
-            return True
-
-        is_embedding = ModelConfig(
-            model=engine_args.model,
-            revision=engine_args.revision,
-            tokenizer=engine_args.model,
-            tokenizer_mode="auto",
-            trust_remote_code=engine_args.trust_remote_code,
-            quantization=engine_args.quantization,
-            seed=0,
-            dtype="auto").embedding_mode
-
-        return is_embedding
+        # Pipeline parallel not yet supported
+        return engine_args.pipeline_parallel_size > 1
 
     @contextmanager
     def get_data_socket(self) -> Iterator[Socket]:
@@ -382,12 +371,9 @@ def errored(self) -> bool:

     @property
     def dead_error(self) -> BaseException:
-        if self._errored_with is not None:
-            return ENGINE_DEAD_ERROR(self._errored_with)
-        else:
-            return ENGINE_DEAD_ERROR()
+        return ENGINE_DEAD_ERROR(self._errored_with)
 
-    async def generate(
+    def generate(
         self,
         inputs: PromptInputs,
         sampling_params: SamplingParams,
@@ -396,6 +382,67 @@ async def generate(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncGenerator[RequestOutput, None]:
+        """Generate outputs for a request.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
+
+        Args:
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
+            sampling_params: The sampling parameters of the request.
+            request_id: The unique id of the request.
+            lora_request: LoRA request to use for generation, if any.
+            trace_headers: OpenTelemetry trace headers.
+            prompt_adapter_request: Prompt Adapter request to use
+                for generation, if any.
+        """
+        return self._process_request(inputs, sampling_params, request_id,
+                                     lora_request, trace_headers,
+                                     prompt_adapter_request)
+
+    def encode(
+        self,
+        inputs: PromptInputs,
+        pooling_params: PoolingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+        """Generate outputs for a request from an embedding model.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
+
+        Args:
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
+            pooling_params: The pooling parameters of the request.
+            request_id: The unique id of the request.
+            lora_request: LoRA request to use for generation, if any.
+            trace_headers: OpenTelemetry trace headers.
+
+        Yields:
+            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            for the request.
+        """
+        return self._process_request(inputs, pooling_params, request_id,
+                                     lora_request, trace_headers)
+
+    async def _process_request(
+        self,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+    ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
+            EmbeddingRequestOutput, None]]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""
 
         # If already dead, error out.
@@ -410,19 +457,19 @@
         try:
             # 2) Detach logits processors so that they can be pickled
             # separately (may require cloudpickle which is slower)
-            if sampling_params.logits_processors:
+            if isinstance(params, SamplingParams) and params.logits_processors:
                 # Defensive shallow copy
-                sampling_params = copy.copy(sampling_params)
-                logits_processors = sampling_params.logits_processors
-                sampling_params.logits_processors = None
+                params = copy.copy(params)
+                logits_processors = params.logits_processors
+                params.logits_processors = None
                 lp_bytes = cloudpickle.dumps(logits_processors)
             else:
                 lp_bytes = None
 
             request_bytes = pickle.dumps(
-                RPCGenerateRequest(
+                RPCProcessRequest(
                     inputs=inputs,
-                    sampling_params=sampling_params,
+                    params=params,
                     request_id=request_id,
                     lora_request=lora_request,
                     trace_headers=trace_headers,
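The detach logic above exists because logits processors are arbitrary Python callables: plain pickle often cannot serialize them, while cloudpickle can but is slower, so they travel in a separate frame. A standalone sketch of the pattern, with an illustrative function name that is not part of the client:

```python
import copy
import pickle
from typing import Optional, Tuple, Union

import cloudpickle

from vllm import PoolingParams, SamplingParams


def serialize_params(
        params: Union[SamplingParams, PoolingParams]
) -> Tuple[bytes, Optional[bytes]]:
    """Pickle params cheaply; cloudpickle only the logits processors.

    PoolingParams never carries logits processors, which is why only
    SamplingParams can trigger the second frame.
    """
    lp_bytes = None
    if isinstance(params, SamplingParams) and params.logits_processors:
        params = copy.copy(params)  # defensive shallow copy
        lp_bytes = cloudpickle.dumps(params.logits_processors)
        params.logits_processors = None
    return pickle.dumps(params), lp_bytes
```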
@@ -452,8 +499,3 @@ async def generate(
                 await self.abort(request_id)
         finally:
             self.output_queues.pop(request_id)
-
-    async def encode(self, *args,
-                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
-        raise NotImplementedError(
-            "Embeddings not supported with multiprocessing backend")
23 changes: 12 additions & 11 deletions vllm/engine/multiprocessing/engine.py
@@ -6,7 +6,7 @@
 import cloudpickle
 import zmq
 
-from vllm import AsyncEngineArgs, LLMEngine
+from vllm import AsyncEngineArgs, LLMEngine, SamplingParams
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 # yapf conflicts with isort for this block
@@ -15,8 +15,8 @@
                                           IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                           IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
                                           VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
-                                          RPCError, RPCGenerateRequest,
-                                          RPCHealthRequest, RPCStartupRequest,
+                                          RPCError, RPCHealthRequest,
+                                          RPCProcessRequest, RPCStartupRequest,
                                           RPCStartupResponse)
 # yapf: enable
 from vllm.logger import init_logger
@@ -39,8 +39,8 @@ class MQLLMEngine:
     in concurrnet manner. It runs a background loop and uses zeromq to
     receive new requests and stream outputs incrementally via ipc.
 
-    The :class:`LLMEngine.generate` is kicked off when a new
-    RPCGenerateRequest is received by the input_socket.
+    The :class:`LLMEngine` generate or encode process is kicked off when a new
+    RPCProcessRequest is received by the input_socket.
 
     The self.engine_loop checks the input_socket for new requests,
     adds them to the LLMEngine if there are any, calls the internal

@@ -213,12 +213,13 @@ def handle_new_input(self):
             frames = self.input_socket.recv_multipart(copy=False)
             request = pickle.loads(frames[0].buffer)
 
-            if isinstance(request, RPCGenerateRequest):
+            if isinstance(request, RPCProcessRequest):
                 if len(frames) > 1:
                     # Use cloudpickle for logits processors
+                    assert isinstance(request.params, SamplingParams)
                     lprocs = cloudpickle.loads(frames[1].buffer)
-                    request.sampling_params.logits_processors = lprocs
-                self._handle_generate_request(request)
+                    request.params.logits_processors = lprocs
+                self._handle_process_request(request)
             elif isinstance(request, RPCAbortRequest):
                 self._handle_abort_request(request)
             elif isinstance(request, RPCHealthRequest):
@@ -231,8 +232,8 @@ def handle_new_input(self):
             self._send_unhealthy(e)
             raise e
 
-    def _handle_generate_request(self, request: RPCGenerateRequest):
-        """Handle RPCGenerateRequest by adding it to the LLMEngine."""
+    def _handle_process_request(self, request: RPCProcessRequest):
+        """Handle RPCProcessRequest by adding it to the LLMEngine."""
         request_id = request.request_id
 
         if self._errored_with is not None:
@@ -245,7 +246,7 @@ def _handle_generate_request(self, request: RPCGenerateRequest):
         self.engine.add_request(
             request_id=request_id,
             inputs=request.inputs,
-            params=request.sampling_params,
+            params=request.params,
             lora_request=request.lora_request,
             trace_headers=request.trace_headers,
             prompt_adapter_request=request.prompt_adapter_request)
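Putting the engine side together, the receive path reverses that split serialization before handing the request to `LLMEngine.add_request`. A condensed sketch (socket plumbing elided; `frames` and `engine` stand in for the raw zmq frame buffers and the wrapped engine):

```python
import pickle

import cloudpickle

from vllm import SamplingParams
from vllm.engine.multiprocessing import RPCProcessRequest


def dispatch(frames, engine):
    # Frame 0 always holds the pickled request; frame 1, if present,
    # holds cloudpickled logits processors (sampling requests only).
    request = pickle.loads(frames[0])
    if isinstance(request, RPCProcessRequest):
        if len(frames) > 1:
            assert isinstance(request.params, SamplingParams)
            request.params.logits_processors = cloudpickle.loads(frames[1])
        engine.add_request(
            request_id=request.request_id,
            inputs=request.inputs,
            params=request.params,
            lora_request=request.lora_request,
            trace_headers=request.trace_headers,
            prompt_adapter_request=request.prompt_adapter_request)
```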
