[core] move parallel sampling out from vllm core (vllm-project#9302)
Signed-off-by: charlifu <charlifu@amd.com>
youkaichao authored and charlifu committed Oct 23, 2024
1 parent fcc8295 commit 444e404
Showing 4 changed files with 222 additions and 29 deletions.
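With this change, a request whose `SamplingParams.n > 1` is fanned out at the engine front-end: `LLMEngine.add_request` hands it to `ParallelSampleSequenceGroup.add_request` and returns `None`, the child groups are tracked in `seq_id_to_seq_group`, and `RequestOutput.from_seq_group` reassembles the finished children into a single output (see the diffs below). From the caller's side nothing changes; the snippet below is a hedged usage sketch, not part of this commit, and the model name and constructor arguments are illustrative only.

```python
# Hedged usage sketch (not part of this commit): parallel sampling from the
# caller's perspective is unchanged -- ask for n samples, get n outputs back.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")        # illustrative model choice
params = SamplingParams(n=3, max_tokens=5)  # n > 1 takes the new fan-out path
outputs = llm.generate(["What is an LLM?"], params)

for completion in outputs[0].outputs:       # one CompletionOutput per sample
    print(completion.index, completion.text)
```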
34 changes: 34 additions & 0 deletions tests/entrypoints/openai/test_completion.py
@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
    assert "".join(chunks) == single_output


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
    """Streaming for parallel sampling.
    The tokens from multiple samples are flattened into a single stream,
    with an index to indicate which sample each token belongs to.
    """

    prompt = "What is an LLM?"
    n = 3
    max_tokens = 5

    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=max_tokens,
                                             n=n,
                                             stream=True)
    chunks: List[List[str]] = [[] for i in range(n)]
    finish_reason_count = 0
    async for chunk in stream:
        index = chunk.choices[0].index
        text = chunk.choices[0].text
        chunks[index].append(text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    assert finish_reason_count == n
    for chunk in chunks:
        assert len(chunk) == max_tokens
        print("".join(chunk))


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
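The new `test_parallel_streaming` above covers the streaming path. For comparison, a non-streaming request with `n > 1` exercises the same parallel-sampling fan-out; the sketch below is illustrative only (it is not part of this commit) and assumes the same `client` and `MODEL_NAME` fixtures used by the existing tests.

```python
@pytest.mark.asyncio
async def test_parallel_no_streaming_sketch(client: openai.AsyncOpenAI):
    # Illustrative sketch, not part of this commit: with n > 1 and the default
    # stream=False, the server returns one choice per sample, each carrying
    # its own index.
    completion = await client.completions.create(model=MODEL_NAME,
                                                 prompt="What is an LLM?",
                                                 max_tokens=5,
                                                 n=3)
    assert len(completion.choices) == 3
    assert sorted(choice.index for choice in completion.choices) == [0, 1, 2]
```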
52 changes: 42 additions & 10 deletions vllm/engine/llm_engine.py
@@ -44,8 +44,10 @@
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
                           Sequence, SequenceGroup, SequenceGroupMetadata,
                           SequenceGroupOutput, SequenceStatus)
                           ParallelSampleSequenceGroup, Sequence,
                           SequenceGroup, SequenceGroupBase,
                           SequenceGroupMetadata, SequenceGroupOutput,
                           SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
@@ -474,6 +476,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
                ),
            ))

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).
@@ -642,7 +646,10 @@ def _add_processed_request(
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> None:
    ) -> SequenceGroup:
        """Add a processed request to the engine's request pool.
        Return the created sequence group.
        """
        self._validate_model_inputs(processed_inputs)
        # Create the sequences.
        block_size = self.cache_config.block_size
@@ -696,6 +703,8 @@ def _add_processed_request(
            min_cost_scheduler = self.scheduler[costs.index(min(costs))]
            min_cost_scheduler.add_seq_group(seq_group)

        return seq_group

    def stop_remote_worker_execution_loop(self) -> None:
        self.model_executor.stop_remote_worker_execution_loop()

@@ -711,7 +720,7 @@ def add_request(
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
    ) -> Optional[SequenceGroup]:
        ...

    @overload
@@ -725,7 +734,7 @@ def add_request(
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
    ) -> Optional[SequenceGroup]:
        ...

    @deprecate_kwargs(
@@ -744,7 +753,7 @@ def add_request(
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
    ) -> Optional[SequenceGroup]:
"""Add a request to the engine's request pool.
The request is added to the request pool and will be processed by the
Expand Down Expand Up @@ -788,6 +797,22 @@ def add_request(
>>> # continue the request processing
>>> ...
"""

if isinstance(params, SamplingParams) and params.n > 1:
ParallelSampleSequenceGroup.add_request(
request_id,
self,
params,
prompt=prompt,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
inputs=inputs,
)
return None

if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
@@ -818,7 +843,7 @@ def add_request(
processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
"mm_processor_kwargs")

self._add_processed_request(
return self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
params=params,
@@ -1135,7 +1160,9 @@ def _process_model_outputs(self,
            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group, use_cache=self.use_cached_outputs)
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

@@ -1175,7 +1202,9 @@ def _process_model_outputs(self,
            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group, use_cache=self.use_cached_outputs)
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)

@@ -1194,7 +1223,10 @@ def _process_model_outputs(self,
                continue

            request_output = RequestOutputFactory.create(
                seq_group, use_cache=self.use_cached_outputs)
                seq_group,
                self.seq_id_to_seq_group,
                use_cache=self.use_cached_outputs,
            )
            if request_output:
                ctx.request_outputs.append(request_output)

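The engine-side pattern above is: when `params.n > 1`, `add_request` delegates to `ParallelSampleSequenceGroup.add_request`, the per-sample child groups are keyed by request id in `seq_id_to_seq_group`, and output assembly happens outside the core (see `vllm/outputs.py` below). `ParallelSampleSequenceGroup` itself is imported from `vllm.sequence`, whose diff is not shown on this page, so the following is only a simplified, self-contained sketch of the fan-out/reassemble bookkeeping; the class and method names are illustrative, not vllm's.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class ParallelGroupSketch:
    """Illustrative stand-in for a parent that tracks n child requests."""
    request_id: str
    n: int
    finished_texts: Dict[str, str] = field(default_factory=dict)

    def finish_child(self, child_id: str, text: str) -> None:
        # Record one finished child sample.
        self.finished_texts[child_id] = text

    def maybe_assemble(self) -> Optional[List[str]]:
        # Emit a combined result only once every child has finished;
        # until then the caller gets None and produces no output.
        if len(self.finished_texts) < self.n:
            return None
        return [
            self.finished_texts[f"{self.request_id}_{i}"] for i in range(self.n)
        ]


def add_parallel_request(request_id: str, n: int,
                         registry: Dict[str, ParallelGroupSketch]) -> List[str]:
    # Fan out: one child request per sample, all registered under the same parent.
    group = ParallelGroupSketch(request_id, n)
    child_ids = [f"{request_id}_{i}" for i in range(n)]
    for child_id in child_ids:
        registry[child_id] = group
    return child_ids


# Tiny usage example of the bookkeeping.
registry: Dict[str, ParallelGroupSketch] = {}
children = add_parallel_request("req-0", 3, registry)
for i, child_id in enumerate(children):
    registry[child_id].finish_child(child_id, f"sample {i}")
assert registry[children[0]].maybe_assemble() == ["sample 0", "sample 1", "sample 2"]
```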
43 changes: 26 additions & 17 deletions vllm/outputs.py
@@ -1,13 +1,13 @@
import time
from dataclasses import dataclass
from typing import List, Optional
from typing import Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Union

from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
                           SequenceGroup, SequenceStatus)
                           SequenceGroup, SequenceGroupBase, SequenceStatus)


@dataclass
@@ -114,14 +114,28 @@ def __init__(
        self.encoder_prompt_token_ids = encoder_prompt_token_ids

    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup,
                       use_cache: bool) -> Optional["RequestOutput"]:
    def from_seq_group(
            cls, seq_group: SequenceGroup, use_cache: bool,
            seq_id_to_seq_group: Dict[str, SequenceGroupBase]
    ) -> Optional["RequestOutput"]:
        finished = seq_group.is_finished()

        if seq_group.request_id in seq_id_to_seq_group:
            group: SequenceGroupBase = seq_id_to_seq_group[
                seq_group.request_id]
            if finished:
                group.finish_seq(seq_group)
            assembled_seq_group = group.maybe_assemble_group(seq_group)
            if assembled_seq_group is None:
                return None
            return cls.from_seq_group(assembled_seq_group, use_cache,
                                      seq_id_to_seq_group)

        sampling_params = seq_group.sampling_params
        if sampling_params is None:
            raise ValueError(
                "Sampling parameters are missing for a CompletionRequest.")

        finished = seq_group.is_finished()
        if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
                not finished):
            return None
@@ -136,15 +150,7 @@ def from_seq_group(cls, seq_group: SequenceGroup,
                outputs=[],
                finished=False)

        seqs = seq_group.get_seqs()
        if len(seqs) == 1:
            top_n_seqs = seqs
        else:
            # Get the top-n sequences.
            n = sampling_params._real_n or sampling_params.n
            sorting_key = lambda seq: seq.get_cumulative_logprob()
            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
            top_n_seqs = sorted_seqs[:n]
        top_n_seqs = seq_group.get_seqs()

        # Create the outputs.
        # NOTE: We need omit logprobs here explicitly because the sequence
@@ -208,7 +214,7 @@ def from_seq_group(cls, seq_group: SequenceGroup,

            else:
                output = CompletionOutput(
                    seqs.index(seq), output_text, [output_token_ids]
                    top_n_seqs.index(seq), output_text, [output_token_ids]
                    if isinstance(output_token_ids, int) else output_token_ids,
                    seq.get_cumulative_logprob() if include_logprobs else None,
                    output_logprobs,
@@ -309,10 +315,13 @@ def __repr__(self):
class RequestOutputFactory:

    @staticmethod
    def create(seq_group: SequenceGroup, use_cache: bool = False):
    def create(seq_group: SequenceGroup,
               seq_id_to_seq_group: Dict[str, SequenceGroupBase],
               use_cache: bool = False):
        # Determine the type based on a condition, for example:
        if hasattr(seq_group,
                   'embeddings') and seq_group.embeddings is not None:
            return EmbeddingRequestOutput.from_seq_group(seq_group)
        else:
            return RequestOutput.from_seq_group(seq_group, use_cache)
            return RequestOutput.from_seq_group(seq_group, use_cache,
                                                seq_id_to_seq_group)
