Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] V1 engine implements parallel sampling, 1/2: AsyncLLM support #10980

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions tests/v1/entrypoints/openai/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,65 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
assert "".join(chunks) == single_output


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
model_name: str):
"""Parallel sampling without streaming.
A single request output contains a list of completions.
"""

prompt = "What is an LLM?"
n = 3
max_tokens = 5

completion = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=max_tokens,
n=n,
stream=False)

for choice in completion.choices:
assert choice.finish_reason is not None


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
"""Streaming for parallel sampling.
The tokens from multiple samples, are flattened into a single stream,
with an index to indicate which sample the token belongs to.
"""

prompt = "What is an LLM?"
n = 3
max_tokens = 5

stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=max_tokens,
n=n,
stream=True)
chunks: List[List[str]] = [[] for i in range(n)]
finish_reason_count = 0
async for chunk in stream:
index = chunk.choices[0].index
text = chunk.choices[0].text
chunks[index].append(text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert finish_reason_count == n
for chunk in chunks:
assert len(chunk) == max_tokens
print("".join(chunk))


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
Expand Down
133 changes: 131 additions & 2 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import os
from typing import AsyncGenerator, List, Mapping, Optional, Type, Union
from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union

import numpy as np

Expand All @@ -24,6 +24,8 @@
from vllm.utils import cdiv, kill_process_tree
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import (ParallelSamplingOutputProcessor,
ParentRequestState)
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
Expand All @@ -50,6 +52,8 @@ def __init__(
assert start_engine_loop

self.model_config = vllm_config.model_config
self.enable_prefix_caching = (
vllm_config.cache_config.enable_prefix_caching)

self.log_requests = log_requests
self.log_stats = log_stats
Expand Down Expand Up @@ -170,7 +174,7 @@ async def add_request(
# requests we don't need to send multiple messages to core proc,
# and so we don't need multiple streams which then get
# re-multiplexed in the API server anyhow.
async def generate(
async def _generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
Expand Down Expand Up @@ -241,6 +245,131 @@ async def generate(
await self.abort(request_id)
raise

async def _parallel_sampling_child_gen(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest to move this to a method on ParentRequestState, it doesn't need to have the output_processor arg then either.

self,
child_gen: AsyncGenerator[RequestOutput, None],
output_processor: ParallelSamplingOutputProcessor,
index: int,
) -> AsyncGenerator[RequestOutput, None]:
"""A single parallel sampling child request
output generator.

Each parallel sampling request triggers at
least two child requests. This generator
yields zero or more request outputs to
return to the caller, as they become
available.

Args:
child_gen: generator for child request
outputs.
output_processor: transform child request
outputs into parent
request outputs
index: index within the `n` child requests

Returns:
Yields zero or more request outputs to return
to the caller.
"""
async for out in child_gen:
if req_out := output_processor.process_output(out, index):
yield req_out

async def _generate_parallel_sampling(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generation completes for parallel sampling requests."""

parent_state = ParentRequestState(request_id, sampling_params)
output_processor = ParallelSamplingOutputProcessor(parent_state)
n = parent_state.n

# Adapted from sglang:
# https://github.com/sgl-project/sglang/blob/
# 4fe92bfca5517f3cf5ca967fc5fcfdb7cf335f30/
# python/sglang/srt/managers/
# tokenizer_manager.py#L456-L532

if self.enable_prefix_caching:
# If engine uses APC, generate a “warmup request” with
# max_tokens=1 which populates the APC
w_sampling_params = parent_state.get_warmup_sampling_params()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think "prefill" or "priming" instead of "warmup" might be a better term to use for this.

Comment on lines +302 to +304
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, why do we need this? I think this should be avoided.

async for _ in self._generate(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be good to do this exhaustion in an async task. Then the subsequent setup can happen in parallel.

prompt,
w_sampling_params,
parent_state.get_warmup_request_id(),
lora_request,
trace_headers,
prompt_adapter_request,
priority,
Comment on lines +306 to +312
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Best to include kwarg names here lora_request=lora_request, etc.

Also could make a dict with the common ones and reuse that below.

):
# Exhaust the generator
pass

# Aggregate generators for n child requests
gens: List[AsyncGenerator[RequestOutput, None]] = []
active: Dict[asyncio.Task, int] = {}
seed = sampling_params.seed
for idx in range(n):
c_sampling_params = parent_state.get_child_sampling_params(seed)
if seed is not None:
seed += 1
Comment on lines +322 to +324
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a bit cleaner to just pass the index to this method and handle the seed logic inside.

Actually why not have a single method that returns the child sampling params and request id for given index.

child_gen = self._generate(
prompt,
c_sampling_params,
parent_state.get_child_request_id(idx),
lora_request,
trace_headers,
prompt_adapter_request,
priority,
)
gen = self._parallel_sampling_child_gen(child_gen,
output_processor, idx)
gens.append(gen)
active[asyncio.create_task(gen.__anext__())] = idx # type: ignore

try:
while active:
done, _ = await asyncio.wait(
active.keys(), return_when=asyncio.FIRST_COMPLETED)
for task in done:
idx = active.pop(task)
try:
result = task.result()
yield result
# Schedule the next result
active[asyncio.create_task(
gens[idx].__anext__())] = idx # type: ignore
except StopAsyncIteration:
continue
finally:
for task in active:
task.cancel()
Comment on lines +339 to +355
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the merge_async_generators util function can be used here.


def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
n = sampling_params.n
_generate = self._generate if n is None or n == 1 \
else self._generate_parallel_sampling
return _generate(prompt, sampling_params, request_id, lora_request,
trace_headers, prompt_adapter_request, priority)

async def _run_output_handler(self):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""

Expand Down
Loading