[V1] V1 engine implements parallel sampling, 1/2: AsyncLLM support #10980
base: main
```diff
@@ -2,7 +2,7 @@
 import asyncio
 import os
-from typing import AsyncGenerator, List, Mapping, Optional, Type, Union
+from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union
 
 import numpy as np
```
```diff
@@ -24,6 +24,8 @@
 from vllm.utils import cdiv, kill_process_tree
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
+from vllm.v1.engine.parallel_sampling import (ParallelSamplingOutputProcessor,
+                                              ParentRequestState)
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
```
```diff
@@ -50,6 +52,8 @@ def __init__(
         assert start_engine_loop
 
         self.model_config = vllm_config.model_config
+        self.enable_prefix_caching = (
+            vllm_config.cache_config.enable_prefix_caching)
 
         self.log_requests = log_requests
         self.log_stats = log_stats
```
```diff
@@ -170,7 +174,7 @@ async def add_request(
     # requests we don't need to send multiple messages to core proc,
     # and so we don't need multiple streams which then get
     # re-multiplexed in the API server anyhow.
-    async def generate(
+    async def _generate(
         self,
         prompt: PromptType,
         sampling_params: SamplingParams,
```
```diff
@@ -241,6 +245,131 @@ async def generate(
             await self.abort(request_id)
             raise
 
+    async def _parallel_sampling_child_gen(
+        self,
+        child_gen: AsyncGenerator[RequestOutput, None],
+        output_processor: ParallelSamplingOutputProcessor,
+        index: int,
+    ) -> AsyncGenerator[RequestOutput, None]:
+        """A single parallel sampling child request output generator.
+
+        Each parallel sampling request triggers at least two child
+        requests. This generator yields zero or more request outputs to
+        return to the caller, as they become available.
+
+        Args:
+            child_gen: generator for child request outputs.
+            output_processor: transforms child request outputs into
+                parent request outputs.
+            index: index within the `n` child requests.
+
+        Returns:
+            Yields zero or more request outputs to return to the caller.
+        """
+        async for out in child_gen:
+            if req_out := output_processor.process_output(out, index):
+                yield req_out
+
+    async def _generate_parallel_sampling(
+        self,
+        prompt: PromptType,
+        sampling_params: SamplingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
+    ) -> AsyncGenerator[RequestOutput, None]:
+        """Generate completions for parallel sampling requests."""
+
+        parent_state = ParentRequestState(request_id, sampling_params)
+        output_processor = ParallelSamplingOutputProcessor(parent_state)
+        n = parent_state.n
+
+        # Adapted from sglang:
+        # https://github.com/sgl-project/sglang/blob/
+        # 4fe92bfca5517f3cf5ca967fc5fcfdb7cf335f30/
+        # python/sglang/srt/managers/
+        # tokenizer_manager.py#L456-L532
+
+        if self.enable_prefix_caching:
+            # If engine uses APC, generate a "warmup request" with
+            # max_tokens=1 which populates the APC
+            w_sampling_params = parent_state.get_warmup_sampling_params()
```
> **Review comment:** I think "prefill" or "priming" instead of "warmup" might be a better term to use for this.
> **Review comment on lines +302 to +304:** Again, why do we need this? I think this should be avoided.
```diff
+            async for _ in self._generate(
+                prompt,
+                w_sampling_params,
+                parent_state.get_warmup_request_id(),
+                lora_request,
+                trace_headers,
+                prompt_adapter_request,
+                priority,
+            ):
+                # Exhaust the generator
+                pass
```

> **Review comment:** May be good to do this exhaustion in an async task. Then the subsequent setup can happen in parallel.

> **Review comment on lines +306 to +312:** Best to include kwarg names here. Also could make a dict with the common ones and reuse that below.
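The first suggestion above, sketched as a self-contained pattern with stand-in coroutines (not the PR's code): drain the warmup generator in an `asyncio.Task`, do the remaining setup concurrently, and await the task only once the prefix cache actually needs to be populated.

```python
import asyncio
from typing import AsyncGenerator


async def exhaust(gen: AsyncGenerator) -> None:
    """Drain a generator, discarding its outputs."""
    async for _ in gen:
        pass


async def warmup() -> AsyncGenerator[int, None]:
    # Stand-in for the max_tokens=1 prefix-cache priming request.
    await asyncio.sleep(0.1)
    yield 0


async def main() -> None:
    # Kick off the warmup drain in the background...
    warmup_task = asyncio.create_task(exhaust(warmup()))
    # ...and build per-child state concurrently with it.
    child_ids = [f"{i}_parent-req" for i in range(4)]
    # Block until the cache is primed, then launch the children.
    await warmup_task
    print("ready to launch:", child_ids)


asyncio.run(main())
```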
```diff
+        # Aggregate generators for n child requests
+        gens: List[AsyncGenerator[RequestOutput, None]] = []
+        active: Dict[asyncio.Task, int] = {}
+        seed = sampling_params.seed
+        for idx in range(n):
+            c_sampling_params = parent_state.get_child_sampling_params(seed)
+            if seed is not None:
+                seed += 1
```
> **Review comment on lines +322 to +324:** Maybe a bit cleaner to just pass the index to this method and handle the seed logic inside. Actually, why not have a single method that returns the child sampling params and request id for a given index?
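A sketch of that proposed consolidation (the method name `get_child_info` is invented here): the caller passes only the index, and `ParentRequestState` derives both the per-child seed and the request id.

```python
# Hypothetical consolidated helper on ParentRequestState, per the review
# suggestion; the seed rule mirrors the caller's loop above.
from copy import copy
from typing import Optional, Tuple

from vllm.sampling_params import SamplingParams


class ParentRequestState:
    def __init__(self, request_id: str, sampling_params: SamplingParams):
        self.request_id = request_id
        self.sampling_params = sampling_params

    def get_child_info(self, index: int) -> Tuple[SamplingParams, str]:
        """Return (sampling_params, request_id) for the child at `index`."""
        params = copy(self.sampling_params)
        params.n = 1
        seed: Optional[int] = self.sampling_params.seed
        if seed is not None:
            # Same effect as the caller's `seed += 1` per iteration.
            params.seed = seed + index
        return params, f"{index}_{self.request_id}"
```

The loop body would then reduce to `c_sampling_params, child_req_id = parent_state.get_child_info(idx)`, with no seed state held by the caller.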
```diff
+            child_gen = self._generate(
+                prompt,
+                c_sampling_params,
+                parent_state.get_child_request_id(idx),
+                lora_request,
+                trace_headers,
+                prompt_adapter_request,
+                priority,
+            )
+            gen = self._parallel_sampling_child_gen(child_gen,
+                                                    output_processor, idx)
+            gens.append(gen)
+            active[asyncio.create_task(gen.__anext__())] = idx  # type: ignore
+
+        try:
+            while active:
+                done, _ = await asyncio.wait(
+                    active.keys(), return_when=asyncio.FIRST_COMPLETED)
+                for task in done:
+                    idx = active.pop(task)
+                    try:
+                        result = task.result()
+                        yield result
+                        # Schedule the next result
+                        active[asyncio.create_task(
+                            gens[idx].__anext__())] = idx  # type: ignore
+                    except StopAsyncIteration:
+                        continue
+        finally:
+            for task in active:
+                task.cancel()
```
> **Review comment on lines +339 to +355:** I think the …
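For reference, the `try`/`while` loop above is a fan-in over the `n` child generators: one pending `__anext__()` task per generator, re-armed after each yielded item, so outputs from all children interleave as they become available. The same pattern in isolation, as a runnable illustration:

```python
import asyncio
from typing import AsyncGenerator, Dict, List


async def numbers(tag: str, delay: float) -> AsyncGenerator[str, None]:
    for i in range(3):
        await asyncio.sleep(delay)
        yield f"{tag}:{i}"


async def merge(
    gens: List[AsyncGenerator[str, None]],
) -> AsyncGenerator[str, None]:
    # Map each pending __anext__() task back to its generator's index.
    active: Dict[asyncio.Task, int] = {
        asyncio.create_task(gen.__anext__()): idx  # type: ignore
        for idx, gen in enumerate(gens)
    }
    try:
        while active:
            done, _ = await asyncio.wait(active.keys(),
                                         return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                idx = active.pop(task)
                try:
                    item = task.result()
                except StopAsyncIteration:
                    continue  # this generator is exhausted
                yield item
                # Re-arm: schedule the next item from the same generator.
                active[asyncio.create_task(
                    gens[idx].__anext__())] = idx  # type: ignore
    finally:
        for task in active:
            task.cancel()


async def main() -> None:
    gens = [numbers("a", 0.010), numbers("b", 0.015)]
    async for item in merge(gens):
        print(item)  # items interleave as they become available


asyncio.run(main())
```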
```diff
+
+    def generate(
+        self,
+        prompt: PromptType,
+        sampling_params: SamplingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
+    ) -> AsyncGenerator[RequestOutput, None]:
+        n = sampling_params.n
+        _generate = self._generate if n is None or n == 1 \
+            else self._generate_parallel_sampling
+        return _generate(prompt, sampling_params, request_id, lora_request,
+                         trace_headers, prompt_adapter_request, priority)
+
     async def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
```
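With this dispatcher in place, callers keep a single entry point; a request with `n > 1` transparently takes the parallel sampling path, while `n == 1` (or unset) follows the original `_generate` path. A usage sketch (the `from_engine_args` construction and the model name are assumed for illustration, not prescribed by this diff):

```python
# Usage sketch: the dispatcher routes n>1 through parallel sampling.
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM


async def main() -> None:
    engine = AsyncLLM.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    # n=3 fans out into three child requests under the hood.
    params = SamplingParams(n=3, max_tokens=32, seed=42)
    async for out in engine.generate("The capital of France is",
                                     params, request_id="req-0"):
        for completion in out.outputs:
            print(completion.index, repr(completion.text))


asyncio.run(main())
```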
> **Review comment:** Suggest to move this to a method on `ParentRequestState`; it doesn't need to have the `output_processor` arg then either.
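What that suggestion might look like (hypothetical: assumes `ParentRequestState` constructs its own `ParallelSamplingOutputProcessor`, or absorbs its logic, in `__init__`):

```python
# Hypothetical refactor: _parallel_sampling_child_gen moved onto
# ParentRequestState as a method, so AsyncLLM would call
# parent_state.child_outputs(child_gen, idx) and no output_processor
# argument needs to be threaded through.
async def child_outputs(
    self,  # method on ParentRequestState
    child_gen: AsyncGenerator[RequestOutput, None],
    index: int,
) -> AsyncGenerator[RequestOutput, None]:
    async for out in child_gen:
        # self.output_processor is assumed to be created in __init__.
        if req_out := self.output_processor.process_output(out, index):
            yield req_out
```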