From 444e4046f127926e0da2c3d18dea5f9b915ccddc Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 21 Oct 2024 17:31:44 -0700
Subject: [PATCH] [core] move parallel sampling out from vllm core (#9302)

Signed-off-by: charlifu
---
 tests/entrypoints/openai/test_completion.py |  34 ++++++
 vllm/engine/llm_engine.py                   |  52 +++++++--
 vllm/outputs.py                             |  43 ++++---
 vllm/sequence.py                            | 122 +++++++++++++++++++-
 4 files changed, 222 insertions(+), 29 deletions(-)

diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py
index cc72a49ebbbda..f03bdb045f640 100644
--- a/tests/entrypoints/openai/test_completion.py
+++ b/tests/entrypoints/openai/test_completion.py
@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
     assert "".join(chunks) == single_output
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
+)
+async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
+    """Streaming for parallel sampling.
+    The tokens from multiple samples are flattened into a single stream,
+    with an index to indicate which sample the token belongs to.
+    """
+
+    prompt = "What is an LLM?"
+    n = 3
+    max_tokens = 5
+
+    stream = await client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             max_tokens=max_tokens,
+                                             n=n,
+                                             stream=True)
+    chunks: List[List[str]] = [[] for i in range(n)]
+    finish_reason_count = 0
+    async for chunk in stream:
+        index = chunk.choices[0].index
+        text = chunk.choices[0].text
+        chunks[index].append(text)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    assert finish_reason_count == n
+    for chunk in chunks:
+        assert len(chunk) == max_tokens
+        print("".join(chunk))
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index a90bfce8491fb..25c4e76d9b159 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -44,8 +44,10 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
-                           Sequence, SequenceGroup, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceStatus)
+                           ParallelSampleSequenceGroup, Sequence,
+                           SequenceGroup, SequenceGroupBase,
+                           SequenceGroupMetadata, SequenceGroupOutput,
+                           SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
@@ -474,6 +476,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             ),
         ))
 
+        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
 
@@ -642,7 +646,10 @@ def _add_processed_request(
         prompt_adapter_request: Optional[PromptAdapterRequest],
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> SequenceGroup:
+        """Add a processed request to the engine's request pool.
+        Return the created sequence group.
+        """
         self._validate_model_inputs(processed_inputs)
         # Create the sequences.
         block_size = self.cache_config.block_size
@@ -696,6 +703,8 @@ def _add_processed_request(
         min_cost_scheduler = self.scheduler[costs.index(min(costs))]
         min_cost_scheduler.add_seq_group(seq_group)
 
+        return seq_group
+
     def stop_remote_worker_execution_loop(self) -> None:
         self.model_executor.stop_remote_worker_execution_loop()
 
@@ -711,7 +720,7 @@ def add_request(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...
 
     @overload
@@ -725,7 +734,7 @@ def add_request(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...
 
     @deprecate_kwargs(
@@ -744,7 +753,7 @@ def add_request(
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         """Add a request to the engine's request pool.
 
         The request is added to the request pool and will be processed by the
@@ -788,6 +797,22 @@ def add_request(
             >>> # continue the request processing
             >>> ...
         """
+
+        if isinstance(params, SamplingParams) and params.n > 1:
+            ParallelSampleSequenceGroup.add_request(
+                request_id,
+                self,
+                params,
+                prompt=prompt,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                prompt_adapter_request=prompt_adapter_request,
+                priority=priority,
+                inputs=inputs,
+            )
+            return None
+
         if inputs is not None:
             prompt = inputs
         assert prompt is not None and params is not None
@@ -818,7 +843,7 @@ def add_request(
             processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
                 "mm_processor_kwargs")
 
-        self._add_processed_request(
+        return self._add_processed_request(
             request_id=request_id,
             processed_inputs=processed_inputs,
             params=params,
@@ -1135,7 +1160,9 @@ def _process_model_outputs(self,
                 seq_group = scheduled_seq_group.seq_group
                 seq_group.maybe_set_first_token_time(now)
                 request_output = RequestOutputFactory.create(
-                    seq_group, use_cache=self.use_cached_outputs)
+                    seq_group,
+                    self.seq_id_to_seq_group,
+                    use_cache=self.use_cached_outputs)
                 if request_output:
                     ctx.request_outputs.append(request_output)
 
@@ -1175,7 +1202,9 @@ def _process_model_outputs(self,
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)
 
@@ -1194,7 +1223,10 @@ def _process_model_outputs(self,
                 continue
 
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs,
+            )
             if request_output:
                 ctx.request_outputs.append(request_output)
 
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 07650241cb638..951976310e7ae 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -1,13 +1,13 @@
 import time
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
 
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
-                           SequenceGroup, SequenceStatus)
+                           SequenceGroup, SequenceGroupBase, SequenceStatus)
 
 
 @dataclass
@@ -114,14 +114,28 @@ def __init__(
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
 
     @classmethod
-    def from_seq_group(cls, seq_group: SequenceGroup,
-                       use_cache: bool) -> Optional["RequestOutput"]:
+    def from_seq_group(
+            cls, seq_group: SequenceGroup, use_cache: bool,
+            seq_id_to_seq_group: Dict[str, SequenceGroupBase]
+    ) -> Optional["RequestOutput"]:
+        finished = seq_group.is_finished()
+
+        if seq_group.request_id in seq_id_to_seq_group:
+            group: SequenceGroupBase = seq_id_to_seq_group[
+                seq_group.request_id]
+            if finished:
+                group.finish_seq(seq_group)
+            assembled_seq_group = group.maybe_assemble_group(seq_group)
+            if assembled_seq_group is None:
+                return None
+            return cls.from_seq_group(assembled_seq_group, use_cache,
+                                      seq_id_to_seq_group)
+
         sampling_params = seq_group.sampling_params
         if sampling_params is None:
             raise ValueError(
                 "Sampling parameters are missing for a CompletionRequest.")
 
-        finished = seq_group.is_finished()
         if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
                 not finished):
             return None
@@ -136,15 +150,7 @@ def from_seq_group(cls, seq_group: SequenceGroup,
                        outputs=[],
                        finished=False)
 
-        seqs = seq_group.get_seqs()
-        if len(seqs) == 1:
-            top_n_seqs = seqs
-        else:
-            # Get the top-n sequences.
-            n = sampling_params._real_n or sampling_params.n
-            sorting_key = lambda seq: seq.get_cumulative_logprob()
-            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
-            top_n_seqs = sorted_seqs[:n]
+        top_n_seqs = seq_group.get_seqs()
 
         # Create the outputs.
         # NOTE: We need omit logprobs here explicitly because the sequence
@@ -208,7 +214,7 @@ def from_seq_group(cls, seq_group: SequenceGroup,
             else:
                 output = CompletionOutput(
-                    seqs.index(seq), output_text, [output_token_ids]
+                    top_n_seqs.index(seq), output_text, [output_token_ids]
                     if isinstance(output_token_ids, int) else output_token_ids,
                     seq.get_cumulative_logprob() if include_logprobs else None,
                     output_logprobs,
@@ -309,10 +315,13 @@ def __repr__(self):
 class RequestOutputFactory:
 
     @staticmethod
-    def create(seq_group: SequenceGroup, use_cache: bool = False):
+    def create(seq_group: SequenceGroup,
+               seq_id_to_seq_group: Dict[str, SequenceGroupBase],
+               use_cache: bool = False):
         # Determine the type based on a condition, for example:
         if hasattr(seq_group,
                    'embeddings') and seq_group.embeddings is not None:
             return EmbeddingRequestOutput.from_seq_group(seq_group)
         else:
-            return RequestOutput.from_seq_group(seq_group, use_cache)
+            return RequestOutput.from_seq_group(seq_group, use_cache,
+                                                seq_id_to_seq_group)
diff --git a/vllm/sequence.py b/vllm/sequence.py
index e580d69ec5afb..93f58f00ef77b 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from array import array
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import cached_property, reduce
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
 from typing import Sequence as GenericSequence
@@ -17,7 +17,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
 
 if TYPE_CHECKING:
@@ -1401,3 +1401,121 @@ def clone(
             last_sampled_token_ids=self.last_sampled_token_ids.clone()
             if self.last_sampled_token_ids is not None else None,
             async_callback=self.async_callback)
+
+
+@dataclass
+class SequenceGroupBase:
+    group_id: str  # the original request id before splitting
+
+    assembled_seq_group: Optional[SequenceGroup] = None
+
+    # seq id to a unique index inside this group
+    seq_id_to_index: Dict[str, int] = field(default_factory=dict)
+
+    # seq ids to be finished
+    to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict)
+
+    # seq id to finished sequences
+    finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict)
+
+    streaming: bool = False
+
+    output_produced: bool = False
+
+    @staticmethod
+    def add_request(request_id: str, engine, params, *args, **kwargs):
+        """When we are ready to add a request with request_id and params
+        into the engine, we can split the request into multiple requests.
+        """
+        raise NotImplementedError
+
+    def finish_seq(self, seq: SequenceGroup):
+        """The sequence `seq` has finished; record the information.
+        """
+        del self.to_be_finished[seq.request_id]
+        self.finished_reqs[seq.request_id] = seq
+
+    def maybe_assemble_group(
+            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
+        """Assemble the sequence group, either to produce the final
+        output or to add the request to the engine again.
+        """
+        raise NotImplementedError
+
+
+class ParallelSampleSequenceGroup(SequenceGroupBase):
+
+    @staticmethod
+    def add_request(request_id: str, engine, params, **kwargs):
+        original_params = params
+        params = copy.deepcopy(original_params)
+        params.n = 1
+        group = ParallelSampleSequenceGroup(request_id)
+        seqs = []
+        for i in range(original_params.n):
+            request_id_i = f"{request_id}_parallel_sample_{i}"
+            group.seq_id_to_index[request_id_i] = i
+            seq_group = engine.add_request(
+                request_id_i,
+                params=params,
+                **kwargs,
+            )  # type: ignore
+            assert seq_group is not None
+            engine.seq_id_to_seq_group[request_id_i] = group
+            group.to_be_finished[request_id_i] = seq_group
+            seqs.append(seq_group.seqs[0])
+
+        # for parallel sampling, the `assembled_seq_group` is always
+        # available, since we have all the sequences ready, and they
+        # will not change.
+        group.assembled_seq_group = SequenceGroup(
+            request_id=request_id,
+            seqs=seqs,
+            arrival_time=seq_group.arrival_time,
+            sampling_params=original_params,
+            lora_request=seq_group.lora_request,
+            embeddings=seq_group.embeddings,
+            pooling_params=seq_group.pooling_params,
+            encoder_seq=seq_group.encoder_seq,
+            trace_headers=seq_group.trace_headers,
+            prompt_adapter_request=seq_group.prompt_adapter_request,
+            priority=seq_group.priority,
+        )
+
+        group.streaming = params.output_kind == RequestOutputKind.DELTA
+        group.output_produced = False
+
+    def maybe_assemble_group(
+            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
+
+        # in the streaming mode, we will return the assembled sequence
+        # for the first sequence, and then return None for the rest of
+        # the sequences
+        if self.streaming:
+            if self.seq_id_to_index[seq_group.request_id] == 0:
+                return self.assembled_seq_group
+            return None
+
+        # in the non-streaming mode, we will return the assembled sequence
+        # once after all sequences finish, and then return None for the
+        # rest of the time
+
+        if len(self.to_be_finished) > 0:
+            return None
+
+        assert self.assembled_seq_group is not None
+        params = self.assembled_seq_group.sampling_params
+        assert isinstance(params, SamplingParams)
+        if not self.output_produced:
+            self.output_produced = True
+            if params._real_n is not None:
+                # Get the top-n sequences.
+                n = params._real_n or params.n
+                seqs = self.assembled_seq_group.seqs
+                sorting_key = lambda seq: seq.get_cumulative_logprob()
+                sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
+                top_n_seqs = sorted_seqs[:n]
+                self.assembled_seq_group.seqs = top_n_seqs
+            return self.assembled_seq_group
+        if self.output_produced:
+            return None
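
For context, a minimal client-side sketch (not part of the patch) of the behavior the new
test_parallel_streaming test exercises: with this change, a request with n > 1 is fanned out
into single-sample requests inside the engine and reassembled before output, so the
OpenAI-compatible API is unchanged and streamed chunks are still demultiplexed by
choices[0].index. The base URL, API key, and model name below are placeholder assumptions,
not values taken from this patch.

    import asyncio
    from collections import defaultdict

    import openai


    async def main() -> None:
        # Placeholder endpoint and model: point these at a running vLLM
        # OpenAI-compatible server (assumed values, not part of the patch).
        client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                    api_key="EMPTY")

        stream = await client.completions.create(model="my-model",
                                                 prompt="What is an LLM?",
                                                 max_tokens=5,
                                                 n=3,
                                                 stream=True)

        # Tokens from the n samples arrive interleaved in one stream;
        # choices[0].index tells which sample each chunk belongs to.
        samples = defaultdict(list)
        async for chunk in stream:
            choice = chunk.choices[0]
            samples[choice.index].append(choice.text)

        for index, pieces in sorted(samples.items()):
            print(index, "".join(pieces))


    if __name__ == "__main__":
        asyncio.run(main())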