diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py
index 84b634316cb46..7f2636d44a08d 100644
--- a/tests/v1/engine/test_llm_engine.py
+++ b/tests/v1/engine/test_llm_engine.py
@@ -1,10 +1,93 @@
 # SPDX-License-Identifier: Apache-2.0
+import random
+from typing import Dict, List, Optional, Tuple
+
 import pytest
 
 from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
 from vllm import LLM, SamplingParams
 
+MODEL = "facebook/opt-125m"
+DTYPE = "half"
+
+
+@pytest.fixture(
+    scope="module",
+    # Prefix caching
+    params=[False, True])
+def vllm_model(vllm_runner, request):
+    """VllmRunner test fixture parameterized by APC."""
+    enable_prefix_caching = request.param
+    with vllm_runner(
+            MODEL,
+            dtype=DTYPE,
+            max_model_len=128,
+            enforce_eager=True,
+            enable_prefix_caching=enable_prefix_caching,
+            gpu_memory_utilization=0.5,
+    ) as vllm_model:
+        # VllmRunner instance is cleaned up after the test.
+        yield vllm_model
+
+
+def _get_test_sampling_params(
+    prompt_list: List[str],
+    seed: Optional[int] = 42,
+) -> Tuple[List[SamplingParams], List[int]]:
+    """Generate random sampling params for a batch."""
+
+    def get_mostly_n_gt1() -> int:
+        """Mostly n in [2, 20]; ~1/3 of the time n=1."""
+        x = random.randint(0, 28)
+        if x < 10:
+            return 1
+        else:
+            return x - 8
+
+    n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
+    # High temperature to maximize the chance of unique completions.
+    return [
+        SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed)
+        for n in n_list
+    ], n_list
+
+
+def test_parallel_sampling(monkeypatch, vllm_model, example_prompts) -> None:
+    """Test passes if parallel sampling `n>1` yields `n` unique completions.
+
+    Args:
+      monkeypatch: test fixture for modifying the test env, scoped to the test.
+      vllm_model: VllmRunner instance under test.
+      example_prompts: test fixture providing prompts for testing.
+    """
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    sampling_params_list, n_list = _get_test_sampling_params(example_prompts)
+    model: LLM = vllm_model.model
+    outputs = model.generate(example_prompts, sampling_params_list)
+
+    # Validate each request response
+    for out, n in zip(outputs, n_list):
+        completion_counts: Dict[str, int] = {}
+        # Assert correct number of completions
+        assert len(out.outputs) == n, (
+            f"{len(out.outputs)} completions; {n} expected.")
+        for idx in range(n):
+            comp = out.outputs[idx]
+            # Assert correct completion indices
+            assert comp.index == idx, (f"Index {comp.index}; expected {idx}.")
+            text = comp.text
+            completion_counts[text] = completion_counts.get(text, 0) + 1
+        # Assert unique completions
+        if len(completion_counts) != n:
+            repeats = {
+                txt: num
+                for (txt, num) in completion_counts.items() if num > 1
+            }
+            raise AssertionError(
+                f"{len(completion_counts)} unique completions; expected"
+                f" {n}. Repeats: {repeats}")
+
+
 def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch):
     """Test passes if LLMEngine raises an exception when it is configured
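# Aside: a minimal, standalone sketch (not part of the patch) of the
# per-request uniqueness check that test_parallel_sampling performs above,
# expressed with collections.Counter. The helper name
# `assert_unique_completions` is hypothetical.
from collections import Counter
from typing import List


def assert_unique_completions(texts: List[str], n: int) -> None:
    """Fail if the n sampled completions are not all unique."""
    counts = Counter(texts)
    repeats = {txt: num for txt, num in counts.items() if num > 1}
    assert len(counts) == n, (
        f"{len(counts)} unique completions; expected {n}. Repeats: {repeats}")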
@@ -15,7 +98,7 @@ def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch):
     # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now.
     monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
     with pytest.raises(ValueError) as excinfo:
-        LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate(
+        LLM(model=MODEL, enable_prefix_caching=True).generate(
             "Hello, my name is",
             SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))
diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py
index ef46a16ef3447..35e059ccb5480 100644
--- a/tests/v1/entrypoints/openai/test_completion.py
+++ b/tests/v1/entrypoints/openai/test_completion.py
@@ -250,6 +250,108 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
     assert "".join(chunks) == single_output
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
+                                     model_name: str):
+    """Parallel sampling without streaming.
+    A single request output contains a list of completions.
+    """
+
+    prompt = "What is an LLM?"
+    n = 3
+    max_tokens = 5
+
+    # High temperature to maximize chance of unique completions.
+    completion = await client.completions.create(model=model_name,
+                                                 prompt=prompt,
+                                                 max_tokens=max_tokens,
+                                                 n=n,
+                                                 temperature=0.95,
+                                                 stream=False,
+                                                 seed=42)
+
+    # Assert `n` completions
+    num_completions = len(completion.choices)
+    assert num_completions == n, (
+        f"Num completions {num_completions} but expected {n}.")
+    completion_repeats: Dict[str, int] = {}
+    for idx, choice in enumerate(completion.choices):
+        # Assert correct completion index & some finish reason.
+        assert choice.index == idx, (
+            f"Index {choice.index} but expected {idx}.")
+        assert choice.finish_reason is not None, (
+            "None finish_reason is invalid.")
+        text = choice.text
+        completion_repeats[text] = completion_repeats.get(text, 0) + 1
+    # Assert `n` unique completions
+    num_unique = len(completion_repeats)
+    if num_unique != n:
+        repeats = {
+            txt: num
+            for (txt, num) in completion_repeats.items() if num > 1
+        }
+        raise AssertionError(
+            f"Expected {n} unique completions, got {num_unique};"
+            f" repeats: {repeats}.")
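# Aside: a sketch (not part of the patch) of the JSON body that the
# non-streaming request above sends to the OpenAI-compatible /v1/completions
# endpoint; the server returns one `choices` entry per sample, indexed 0..n-1.
# The literal model name is a placeholder for MODEL_NAME.
request_body = {
    "model": "<MODEL_NAME>",
    "prompt": "What is an LLM?",
    "max_tokens": 5,
    "n": 3,  # number of parallel samples requested
    "temperature": 0.95,
    "seed": 42,
    "stream": False,
}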
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
+    """Streaming for parallel sampling.
+    The tokens from multiple samples are flattened into a single stream,
+    with an index to indicate which sample each token belongs to.
+    """
+
+    prompt = "What is an LLM?"
+    n = 3
+    max_tokens = 5
+
+    stream = await client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             max_tokens=max_tokens,
+                                             n=n,
+                                             temperature=0.95,
+                                             stream=True,
+                                             seed=42)
+    chunks: List[List[str]] = [[] for _ in range(n)]
+    finish_reason_count = 0
+    async for chunk in stream:
+        index = chunk.choices[0].index
+        text = chunk.choices[0].text
+        chunks[index].append(text)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    # Assert `n` completions with correct finish reasons
+    assert finish_reason_count == n, (
+        f"Expected {n} completions with valid indices and finish_reason.")
+    completion_repeats: Dict[str, int] = {}
+    for chunk in chunks:
+        chunk_len = len(chunk)
+        # Assert correct number of completion tokens
+        assert chunk_len == max_tokens, (
+            f"max_tokens={max_tokens} but chunk len is {chunk_len}.")
+        text = "".join(chunk)
+        completion_repeats[text] = completion_repeats.get(text, 0) + 1
+        print(text)
+    # Assert `n` unique completions
+    num_unique = len(completion_repeats)
+    if num_unique != n:
+        repeats = {
+            txt: num
+            for (txt, num) in completion_repeats.items() if num > 1
+        }
+        raise AssertionError(f"{num_unique} unique completions, expected {n};"
+                             f" repeats: {repeats}")
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
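# Aside: a minimal sketch (not part of the patch) of the demux step the
# streaming test above relies on: chunks from all n samples arrive on one
# stream and are regrouped into per-sample completions via choice.index. The
# (index, text) tuples and the `demux_stream` name are illustrative stand-ins
# for the streamed chunks.
from typing import Iterable, List, Tuple


def demux_stream(chunks: Iterable[Tuple[int, str]], n: int) -> List[str]:
    """Group flattened (sample_index, text) chunks into n completion strings."""
    buffers: List[List[str]] = [[] for _ in range(n)]
    for index, text in chunks:
        buffers[index].append(text)
    return ["".join(buf) for buf in buffers]


# Example: two samples interleaved on one stream.
assert demux_stream([(0, "A"), (1, "B"), (0, "C")], n=2) == ["AC", "B"]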
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 1920dbf7a7dc5..a079e721a1d28 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -21,9 +21,10 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import cdiv, kill_process_tree
+from vllm.utils import cdiv, kill_process_tree, merge_async_iterators
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
+from vllm.v1.engine.parallel_sampling import ParallelSamplingRequestManager
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
@@ -50,6 +51,8 @@ def __init__(
         assert start_engine_loop
 
         self.model_config = vllm_config.model_config
+        self.enable_prefix_caching = (
+            vllm_config.cache_config.enable_prefix_caching)
         self.log_requests = log_requests
         self.log_stats = log_stats
@@ -170,7 +173,7 @@ async def add_request(
     # requests we don't need to send multiple messages to core proc,
     # and so we don't need multiple streams which then get
     # re-multiplexed in the API server anyhow.
-    async def generate(
+    async def _generate(
         self,
         prompt: PromptType,
         sampling_params: SamplingParams,
@@ -241,6 +244,56 @@ async def generate(
             await self.abort(request_id)
             raise
 
+    async def _generate_parallel_sampling(
+        self,
+        prompt: PromptType,
+        sampling_params: SamplingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
+    ) -> AsyncGenerator[RequestOutput, None]:
+        """Generate completions for parallel sampling requests."""
+        req_mgr = ParallelSamplingRequestManager(request_id, sampling_params)
+        n = req_mgr.n
+
+        # Aggregate generators for n child requests
+        gens: List[AsyncGenerator[RequestOutput, None]] = []
+        for idx in range(n):
+            c_sampling_params = req_mgr.get_child_sampling_params(idx)
+            child_gen = self._generate(
+                prompt=prompt,
+                sampling_params=c_sampling_params,
+                request_id=req_mgr.get_child_request_id(idx),
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                prompt_adapter_request=prompt_adapter_request,
+                priority=priority,
+            )
+            gen = req_mgr.parallel_sampling_child_gen(child_gen, idx)
+            gens.append(gen)
+
+        # Merge generators
+        async for out in merge_async_iterators(*gens):
+            yield out[1]  # out[0] is index
+
+    def generate(
+        self,
+        prompt: PromptType,
+        sampling_params: SamplingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
+    ) -> AsyncGenerator[RequestOutput, None]:
+        n = sampling_params.n
+        _generate = self._generate if n is None or n == 1 \
+            else self._generate_parallel_sampling  # handle parallel sampling
+        return _generate(prompt, sampling_params, request_id, lora_request,
+                         trace_headers, prompt_adapter_request, priority)
+
     async def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
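# Aside: a self-contained sketch (not part of the patch) of the fan-out /
# fan-in shape used by _generate_parallel_sampling above. vLLM itself merges
# the child generators with vllm.utils.merge_async_iterators; the toy `merge`
# below is a simplified stand-in that tags each item with its child index,
# mirroring the (index, item) tuples the real helper yields.
import asyncio
from typing import AsyncGenerator, Tuple


async def merge(*gens: AsyncGenerator[str, None]
                ) -> AsyncGenerator[Tuple[int, str], None]:
    queue: asyncio.Queue = asyncio.Queue()
    done = object()  # sentinel marking the end of one child generator

    async def pump(i: int, gen: AsyncGenerator[str, None]) -> None:
        async for item in gen:
            await queue.put((i, item))
        await queue.put((i, done))

    tasks = [asyncio.create_task(pump(i, g)) for i, g in enumerate(gens)]
    remaining = len(tasks)
    while remaining:
        i, item = await queue.get()
        if item is done:
            remaining -= 1
        else:
            yield i, item


async def child(tag: str) -> AsyncGenerator[str, None]:
    for t in range(2):
        await asyncio.sleep(0)
        yield f"{tag}-{t}"


async def main() -> None:
    # Outputs from two "child requests" arrive interleaved on one stream.
    async for index, item in merge(child("a"), child("b")):
        print(index, item)


asyncio.run(main())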
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index c9a4c5369dfd8..cb8802b351a3c 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Dict, List, Mapping, Optional, Type, Union
+from typing import Dict, List, Mapping, Optional, Tuple, Type, Union
 
 from typing_extensions import TypeVar
 
@@ -21,6 +21,7 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
+from vllm.v1.engine.parallel_sampling import ParallelSamplingRequestManager
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
 
@@ -47,6 +48,16 @@ def __init__(
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
 
+        # Bookkeeping for parallel sampling requests
+        # - parent req ID -> parent request manager
+        self.parallel_parent_reqs: Dict[str,
+                                        ParallelSamplingRequestManager] = {}
+        # - child req ID -> (child req index, parent req ID)
+        self.parallel_child_reqs: Dict[str, Tuple[int, str]] = {}
+        # - flag to reset parallel sampling bookkeeping logic
+        #   between engine runs
+        self._do_reset_parallel_sampling = False
+
         # Tokenizer (+ ensure liveness if running in another process).
         self.tokenizer = init_tokenizer_from_configs(
             model_config=vllm_config.model_config,
@@ -103,7 +114,10 @@ def from_engine_args(
             multiprocess_mode=enable_multiprocessing)
 
     def get_num_unfinished_requests(self) -> int:
-        return self.output_processor.get_num_unfinished_requests()
+        num_core_reqs = self.output_processor.get_num_unfinished_requests()
+        num_child_reqs = self._num_parallel_sampling_child_requests()
+        num_parent_reqs = self._num_parallel_sampling_requests()
+        return num_core_reqs + num_parent_reqs - num_child_reqs
 
     def has_unfinished_requests(self) -> bool:
         return self.output_processor.has_unfinished_requests()
@@ -118,6 +132,12 @@ def abort_request(self, request_ids: List[str]) -> None:
         self.engine_core.abort_requests(request_ids)
         self.output_processor.abort_requests(request_ids)
 
+    def _reset_parallel_sampling(self) -> None:
+        """Reset parallel sampling logic."""
+        self.parallel_parent_reqs.clear()
+        self.parallel_child_reqs.clear()
+        self._do_reset_parallel_sampling = False
+
     def add_request(
         self,
         request_id: str,
@@ -129,7 +149,63 @@
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> None:
-
+        """Add request."""
+        if self._do_reset_parallel_sampling:
+            # Reset parallel sampling logic between
+            # LLM.generate() calls
+            self._reset_parallel_sampling()
+        # Handle parallel sampling requests differently.
+        _add_request = (self._add_request if params is None
+                        or isinstance(params, PoolingParams) or params.n == 1
+                        else self._add_request_parallel_sampling)
+        return _add_request(request_id=request_id,
+                            prompt=prompt,
+                            params=params,
+                            arrival_time=arrival_time,
+                            lora_request=lora_request,
+                            trace_headers=trace_headers,
+                            prompt_adapter_request=prompt_adapter_request,
+                            priority=priority)
+
+    def _add_request_parallel_sampling(
+        self,
+        request_id: str,
+        prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
+    ) -> None:
+        """Add request, `n>1`."""
+        req_mgr = ParallelSamplingRequestManager(request_id, params)
+        self.parallel_parent_reqs[request_id] = req_mgr
+        # Add n child requests with unique request IDs, derived seeds and n=1
+        for idx in range(req_mgr.n):
+            c_request_id = req_mgr.get_child_request_id(idx)
+            self.parallel_child_reqs[c_request_id] = (idx, request_id)
+            self._add_request(request_id=c_request_id,
+                              prompt=prompt,
+                              params=req_mgr.get_child_sampling_params(idx),
+                              arrival_time=arrival_time,
+                              lora_request=lora_request,
+                              trace_headers=trace_headers,
+                              prompt_adapter_request=prompt_adapter_request,
+                              priority=priority)
+
+    def _add_request(
+        self,
+        request_id: str,
+        prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
+    ) -> None:
+        """Add request, `n=1`."""
         # 1) Process raw inputs into the request.
         request = self.processor.process_inputs(request_id, prompt, params,
                                                 arrival_time, lora_request,
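# Aside: a minimal sketch (not part of the patch) of the fan-out bookkeeping
# that add_request / _add_request_parallel_sampling maintain above: a parent
# request with n>1 becomes n child requests whose IDs embed the child index
# ("<index>_<parent id>", matching get_child_request_id in
# parallel_sampling.py), and the child -> (index, parent) map is what step()
# later uses to re-associate child outputs. Plain dicts stand in for the
# engine state; `parent_id` is an arbitrary example value.
from typing import Dict, Tuple

parent_id = "req-42"
n = 3

parallel_child_reqs: Dict[str, Tuple[int, str]] = {}
for idx in range(n):
    child_id = f"{idx}_{parent_id}"
    parallel_child_reqs[child_id] = (idx, parent_id)

assert parallel_child_reqs["2_req-42"] == (2, "req-42")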
@@ -143,7 +219,59 @@ def add_request(
         # 3) Add the request to EngineCore.
         self.engine_core.add_request(request)
 
+    def _aggregate_parallel_sampling_outputs(
+        self,
+        outputs: List[RequestOutput],
+    ) -> List[RequestOutput]:
+        """Build parallel sampling request outputs.
+
+        Extract child request outputs, aggregate them
+        into the parent request output, and return the parent
+        output when complete.
+
+        Do not modify `n=1` requests.
+
+        Args:
+          outputs: step request outputs. Mix of child request
+                   outputs & `n=1` request outputs.
+
+        Returns:
+          List of parallel sampling parent request outputs &
+          unmodified `n=1` request outputs passed through from the input.
+        """
+        agg_outputs = []
+        for c_out in outputs:
+            c_req_id = c_out.request_id
+            if cdx_req_id := self.parallel_child_reqs.get(c_req_id, None):
+                # For each parallel sampling child request output:
+                (cdx, req_id) = cdx_req_id
+                req_mgr = self.parallel_parent_reqs[req_id]
+                # Update the parallel sampling request
+                if out := req_mgr._process_output(c_out, cdx):
+                    # Return the parent request output if complete;
+                    # clean up parent request bookkeeping.
+                    agg_outputs.append(out)
+                    del self.parallel_parent_reqs[req_id]
+                # Clean up child request bookkeeping.
+                del self.parallel_child_reqs[c_req_id]
+            else:
+                # Not a parallel sampling request output
+                agg_outputs.append(c_out)
+        return agg_outputs
+
+    def _num_parallel_sampling_requests(self) -> int:
+        return len(self.parallel_parent_reqs)
+
+    def _num_parallel_sampling_child_requests(self) -> int:
+        return len(self.parallel_child_reqs)
+
     def step(self) -> List[RequestOutput]:
+        num_parallel_reqs = self._num_parallel_sampling_requests()
+
+        # Ensure that parallel sampling logic gets reset after the
+        # engine finishes processing this batch
+        self._do_reset_parallel_sampling = (True if num_parallel_reqs > 0 else
+                                            self._do_reset_parallel_sampling)
 
         # 1) Get EngineCoreOutput from the EngineCore.
         outputs = self.engine_core.get_output()
 
@@ -155,7 +283,12 @@ def step(self) -> List[RequestOutput]:
         # 3) Abort any reqs that finished due to stop strings.
         self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
 
-        return processed_outputs.request_outputs
+        request_outputs = processed_outputs.request_outputs
+        if num_parallel_reqs > 0 and len(request_outputs) > 0:
+            # Process parallel sampling child request outputs
+            return self._aggregate_parallel_sampling_outputs(request_outputs)
+        else:
+            return request_outputs
 
     def get_model_config(self):
         return self.model_config
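# Aside: a simplified sketch (not part of the patch) of what
# _aggregate_parallel_sampling_outputs does per step in the non-streaming
# case: child outputs arrive over several steps, and a parent-level result is
# emitted only once all n children have reported. Plain strings stand in for
# RequestOutput objects; `ParentAccumulator` is a hypothetical name.
from typing import List, Optional, Tuple


class ParentAccumulator:

    def __init__(self, n: int) -> None:
        self.n = n
        self.completions: List[Tuple[int, str]] = []

    def add(self, index: int, text: str) -> Optional[List[str]]:
        """Return all completions, sorted by index, once n children reported."""
        self.completions.append((index, text))
        if len(self.completions) < self.n:
            return None
        return [text for _, text in sorted(self.completions)]


acc = ParentAccumulator(n=2)
assert acc.add(1, "second") is None
assert acc.add(0, "first") == ["first", "second"]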
diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py
new file mode 100644
index 0000000000000..eb16404b35f81
--- /dev/null
+++ b/vllm/v1/engine/parallel_sampling.py
@@ -0,0 +1,188 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from copy import copy
+from typing import AsyncGenerator, Optional
+
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import RequestOutputKind, SamplingParams
+
+
+class ParallelSamplingRequestManager:
+    """Info, state & processing for a parallel sampling request.
+
+    Store the parent request ID and sampling params.
+    Facilitate generating child request sampling params.
+    Transform child request outputs into parent request
+    outputs.
+    When stream mode is disabled, `self.request_output`
+    aggregates the child request completions.
+    """
+
+    request_id: str
+    sampling_params: SamplingParams
+    cached_child_sampling_params: Optional[SamplingParams]
+    request_output: Optional[RequestOutput] = None
+
+    def __init__(self, request_id: str,
+                 sampling_params: SamplingParams) -> None:
+        self.request_id = request_id
+        self.sampling_params = sampling_params
+        self.cached_child_sampling_params = None
+
+    def get_child_sampling_params(
+        self,
+        index: int,
+    ) -> SamplingParams:
+        """Efficiently obtain a child `sampling_params`.
+
+        If `sampling_params.seed` is not `None`, then each child
+        request requires a unique clone of the parent
+        `sampling_params` with a unique seed. Otherwise, all
+        children can share a single cached `n=1` copy.
+
+        Args:
+          index: index within the `n` child requests
+
+        Returns:
+          Child `sampling_params` instance.
+        """
+        seed = self.sampling_params.seed
+        if seed is None and self.cached_child_sampling_params:
+            # Reuse the cached child sampling_params data structure
+            return self.cached_child_sampling_params
+        # Build child sampling_params
+        c_sampling_params = copy(self.sampling_params)
+        c_sampling_params.n = 1
+        if seed is None:
+            # Cache child sampling_params for later reuse
+            self.cached_child_sampling_params = c_sampling_params
+        else:
+            # Each child gets a clone with a unique seed
+            c_sampling_params.seed = seed + index
+        return c_sampling_params
+
+    def _add_output(
+        self,
+        child_req_output: RequestOutput,
+        index: int,
+    ) -> None:
+        """Aggregate a parallel sampling child request output.
+
+        Non-stream-mode (`output_kind == FINAL_ONLY`) only.
+        Inject the correct parent request ID and completion index.
+
+        Args:
+          child_req_output: a single request output from a
+                            parallel sampling child request.
+          index: index within the `n` child requests
+        """
+        new_completion = child_req_output.outputs[0]
+        new_completion.index = index
+        if self.request_output is None:
+            # Save the first request output; reinstate the
+            # original request ID; metrics are not
+            # supported for parallel sampling
+            child_req_output.request_id = self.request_id
+            child_req_output.metrics = None
+            self.request_output = child_req_output
+        else:
+            # Aggregate the additional completion into the request output
+            # Note: will be sorted by index later
+            self.request_output.outputs.append(new_completion)
+
+    def _get_parent_request_output(self) -> RequestOutput:
+        """Invariant: parent completion outputs are sorted by index."""
+        assert self.request_output is not None
+        self.request_output.outputs = sorted(self.request_output.outputs,
+                                             key=lambda x: x.index)
+        return self.request_output
+
+    def get_child_request_id(
+        self,
+        index: int,
+    ) -> str:
+        return str(index) + "_" + self.request_id
+
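# Aside: a minimal sketch (not part of the patch) of the child SamplingParams
# derivation implemented by get_child_sampling_params above: every child gets
# n=1; with a seeded parent, child i gets seed + i so the n samples differ,
# while an unseeded parent lets all children share one cached copy.
from copy import copy

from vllm import SamplingParams

parent = SamplingParams(n=3, temperature=0.95, seed=42)
children = []
for i in range(parent.n):
    child = copy(parent)
    child.n = 1
    if parent.seed is not None:
        child.seed = parent.seed + i
    children.append(child)

assert [c.seed for c in children] == [42, 43, 44]
assert all(c.n == 1 for c in children)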
+    def _process_output(
+        self,
+        child_req_output: RequestOutput,
+        index: int,
+    ) -> Optional[RequestOutput]:
+        """Filter, aggregate and transform parallel sampling
+        child request outputs.
+
+        If the parent request has `stream=false`
+        (`output_kind == FINAL_ONLY`), each child will also have
+        `output_kind == FINAL_ONLY`. All child request outputs
+        must be aggregated into a single request output, with
+        multiple completions. This request output is only returned
+        once `n` completions are aggregated.
+
+        If the parent request has `stream=true`
+        (`output_kind == DELTA`), each child will also have
+        `output_kind == DELTA`. All child request outputs
+        must be streamed directly to the caller.
+
+        Args:
+          child_req_output: a single child request output
+          index: index within the `n` child requests
+
+        Returns:
+          `None`, unless a processed request output is ready to
+          send back to the caller.
+        """
+        if self.output_kind != RequestOutputKind.FINAL_ONLY:
+            # stream=true: return child completions immediately
+            child_req_output.request_id = self.request_id
+            child_req_output.outputs[0].index = index
+            return child_req_output
+
+        # stream=false: aggregate child completions
+        self._add_output(child_req_output, index)
+        if self.num_completions == self.n:
+            # Return the aggregated request output after obtaining
+            # all completions
+            return self._get_parent_request_output()
+        return None
+
+    async def parallel_sampling_child_gen(
+        self,
+        child_gen: AsyncGenerator[RequestOutput, None],
+        index: int,
+    ) -> AsyncGenerator[RequestOutput, None]:
+        """Output generator for a single parallel sampling child request.
+
+        Each parallel sampling request triggers at least two
+        child requests. This generator yields zero or more
+        request outputs to return to the caller, as they
+        become available.
+
+        Args:
+          child_gen: generator for child request outputs.
+          index: index within the `n` child requests
+
+        Yields:
+          Request outputs to return to the caller, as they
+          become available.
+        """
+        async for out in child_gen:
+            if req_out := self._process_output(out, index):
+                yield req_out
+
+    @property
+    def num_completions(self) -> int:
+        assert self.request_output is not None
+        return len(self.request_output.outputs)
+
+    @property
+    def n(self) -> int:
+        return self.sampling_params.n
+
+    @property
+    def output_kind(self) -> RequestOutputKind:
+        return self.sampling_params.output_kind
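# Aside: a small usage sketch of the new module above, assuming the module
# path vllm/v1/engine/parallel_sampling.py introduced by this patch.
from vllm import SamplingParams
from vllm.v1.engine.parallel_sampling import ParallelSamplingRequestManager

params = SamplingParams(n=3, temperature=0.95, seed=42)
mgr = ParallelSamplingRequestManager("parent-req", params)

for i in range(mgr.n):
    child_params = mgr.get_child_sampling_params(i)
    # Child requests are ordinary n=1 requests with derived seeds and
    # index-prefixed IDs: "0_parent-req", "1_parent-req", "2_parent-req".
    print(mgr.get_child_request_id(i), child_params.n, child_params.seed)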