[Frontend] Use MQLLMEngine for embeddings models too #8584

Merged · 1 commit · Sep 19, 2024
7 changes: 4 additions & 3 deletions vllm/engine/multiprocessing/__init__.py
@@ -2,6 +2,7 @@
from enum import Enum
from typing import List, Mapping, Optional, Union

from vllm import PoolingParams
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
@@ -21,9 +22,9 @@ class MQEngineDeadError(RuntimeError):


@dataclass
class RPCGenerateRequest:
class RPCProcessRequest:
inputs: PromptInputs
sampling_params: SamplingParams
params: Union[SamplingParams, PoolingParams]
request_id: str
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
@@ -55,7 +56,7 @@ class RPCStartupResponse:
tracing_enabled: bool


RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest,
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest,
RPCStartupRequest]

REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
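For context, a hedged sketch of what the renamed dataclass is meant to carry: the same RPCProcessRequest holds SamplingParams for generation and PoolingParams for embeddings in its single `params` field. The client builds these internally; the snippet below only illustrates the unified field, and it assumes vLLM at this PR's revision. The prompt text and request ids are made up.

```python
from vllm import PoolingParams, SamplingParams
from vllm.engine.multiprocessing import RPCProcessRequest

# A generation request carries SamplingParams in the unified `params` field...
gen_req = RPCProcessRequest(
    inputs="Hello, my name is",
    params=SamplingParams(max_tokens=16),
    request_id="gen-0",
)

# ...while an embedding request carries PoolingParams in the same field.
emb_req = RPCProcessRequest(
    inputs="Hello, my name is",
    params=PoolingParams(),
    request_id="emb-0",
)
```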
106 changes: 74 additions & 32 deletions vllm/engine/multiprocessing/client.py
@@ -11,6 +11,7 @@
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket

from vllm import PoolingParams
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block
@@ -19,8 +20,8 @@
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCGenerateRequest,
RPCHealthRequest, RPCStartupRequest,
RPCError, RPCHealthRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
@@ -111,20 +112,8 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig):

@staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs):
if engine_args.pipeline_parallel_size > 1:
return True

is_embedding = ModelConfig(
model=engine_args.model,
revision=engine_args.revision,
tokenizer=engine_args.model,
tokenizer_mode="auto",
trust_remote_code=engine_args.trust_remote_code,
quantization=engine_args.quantization,
seed=0,
dtype="auto").embedding_mode

return is_embedding
# Pipeline parallel not yet supported
return engine_args.pipeline_parallel_size > 1

@contextmanager
def get_data_socket(self) -> Iterator[Socket]:
@@ -382,12 +371,9 @@ def errored(self) -> bool:

@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
return ENGINE_DEAD_ERROR(self._errored_with)
else:
return ENGINE_DEAD_ERROR()
return ENGINE_DEAD_ERROR(self._errored_with)

async def generate(
def generate(
self,
inputs: PromptInputs,
sampling_params: SamplingParams,
@@ -396,6 +382,67 @@ async def generate(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.

Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.

Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
"""
return self._process_request(inputs, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request)

def encode(
self,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.

Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.

Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.

Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
"""
return self._process_request(inputs, pooling_params, request_id,
lora_request, trace_headers)

async def _process_request(
Review comment (Collaborator): Can this not be async?

Reply (@njhill, Member Author, Sep 19, 2024): async with yield actually means the method returns an async generator, not a coroutine, so the generate method still behaves exactly as it did before (tests would fail otherwise).

self,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
EmbeddingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""

# If already dead, error out.
@@ -410,19 +457,19 @@
try:
# 2) Detach logits processors so that they can be pickled
# separately (may require cloudpickle which is slower)
if sampling_params.logits_processors:
if isinstance(params, SamplingParams) and params.logits_processors:
# Defensive shallow copy
sampling_params = copy.copy(sampling_params)
logits_processors = sampling_params.logits_processors
sampling_params.logits_processors = None
params = copy.copy(params)
logits_processors = params.logits_processors
params.logits_processors = None
lp_bytes = cloudpickle.dumps(logits_processors)
else:
lp_bytes = None

request_bytes = pickle.dumps(
RPCGenerateRequest(
RPCProcessRequest(
inputs=inputs,
sampling_params=sampling_params,
params=params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
@@ -452,8 +499,3 @@ async def generate(
await self.abort(request_id)
finally:
self.output_queues.pop(request_id)

async def encode(self, *args,
**kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
raise NotImplementedError(
"Embeddings not supported with multiprocessing backend")
23 changes: 12 additions & 11 deletions vllm/engine/multiprocessing/engine.py
@@ -6,7 +6,7 @@
import cloudpickle
import zmq

from vllm import AsyncEngineArgs, LLMEngine
from vllm import AsyncEngineArgs, LLMEngine, SamplingParams
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
# yapf conflicts with isort for this block
@@ -15,8 +15,8 @@
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCGenerateRequest,
RPCHealthRequest, RPCStartupRequest,
RPCError, RPCHealthRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse)
# yapf: enable
from vllm.logger import init_logger
@@ -39,8 +39,8 @@ class MQLLMEngine:
in concurrent manner. It runs a background loop and uses zeromq to
receive new requests and stream outputs incrementally via ipc.

The :class:`LLMEngine.generate` is kicked off when a new
RPCGenerateRequest is received by the input_socket.
The :class:`LLMEngine` generate or encode process is kicked off when a new
RPCProcessRequest is received by the input_socket.

The self.engine_loop checks the input_socket for new requests,
adds them to the LLMEngine if there are any, calls the internal
@@ -213,12 +213,13 @@ def handle_new_input(self):
frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer)

if isinstance(request, RPCGenerateRequest):
if isinstance(request, RPCProcessRequest):
if len(frames) > 1:
# Use cloudpickle for logits processors
assert isinstance(request.params, SamplingParams)
lprocs = cloudpickle.loads(frames[1].buffer)
request.sampling_params.logits_processors = lprocs
self._handle_generate_request(request)
request.params.logits_processors = lprocs
self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request)
elif isinstance(request, RPCHealthRequest):
@@ -231,8 +232,8 @@
self._send_unhealthy(e)
raise e

def _handle_generate_request(self, request: RPCGenerateRequest):
"""Handle RPCGenerateRequest by adding it to the LLMEngine."""
def _handle_process_request(self, request: RPCProcessRequest):
"""Handle RPCProcessRequest by adding it to the LLMEngine."""
request_id = request.request_id

if self._errored_with is not None:
@@ -245,7 +246,7 @@ def _handle_generate_request(self, request: RPCGenerateRequest):
self.engine.add_request(
request_id=request_id,
inputs=request.inputs,
params=request.sampling_params,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request)
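The two-frame handling in handle_new_input above pairs with the client-side detach in _process_request: the request itself travels as a plain pickle frame, while logits processors (arbitrary callables, so only meaningful for SamplingParams) go into an optional second cloudpickle frame and are re-attached by the engine process. A standalone sketch of the round trip, with FakeSamplingParams and FakeRequest as simplified stand-ins for the real vLLM classes, not the actual implementation:

```python
import copy
import pickle
from dataclasses import dataclass
from typing import Callable, List, Optional

import cloudpickle  # same third-party dependency the diff relies on


@dataclass
class FakeSamplingParams:  # stand-in for vllm.SamplingParams
    logits_processors: Optional[List[Callable]] = None


@dataclass
class FakeRequest:  # stand-in for RPCProcessRequest
    inputs: str
    params: FakeSamplingParams
    request_id: str


def client_send(req: FakeRequest) -> List[bytes]:
    """Detach logits processors so the main frame stays plain-picklable."""
    extra_frames: List[bytes] = []
    if req.params.logits_processors:
        params = copy.copy(req.params)  # defensive shallow copy, as in the diff
        extra_frames.append(cloudpickle.dumps(params.logits_processors))
        params.logits_processors = None
        req = FakeRequest(req.inputs, params, req.request_id)
    return [pickle.dumps(req)] + extra_frames


def engine_receive(frames: List[bytes]) -> FakeRequest:
    """Re-attach logits processors only when a second frame is present."""
    req = pickle.loads(frames[0])
    if len(frames) > 1:
        req.params.logits_processors = cloudpickle.loads(frames[1])
    return req


request = FakeRequest("hi", FakeSamplingParams([lambda ids, logits: logits]), "r0")
restored = engine_receive(client_send(request))
assert restored.params.logits_processors  # the callable survived the round trip
```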