From 5404c19f92bb378af930702d24e8158290cb45d9 Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Wed, 18 Sep 2024 11:38:43 -0400 Subject: [PATCH] [Core] *Prompt* logprobs support in Multi-step (#8199) --- tests/conftest.py | 84 +++++++++++------- tests/models/utils.py | 108 +++++++++++++++++++++-- tests/multi_step/test_correctness_llm.py | 92 +++++++++++++++++++ tests/utils.py | 3 +- vllm/worker/multi_step_model_runner.py | 72 ++++++++++----- 5 files changed, 300 insertions(+), 59 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e9c7fc7bf9c67..c2616bcf7091c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,8 @@ BatchFeature) from transformers.models.auto.auto_factory import _BaseAutoModelClass +from tests.models.utils import (TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs) from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset @@ -33,7 +35,6 @@ to_enc_dec_tuple_list, zip_enc_dec_prompts) from vllm.logger import init_logger from vllm.outputs import RequestOutput -from vllm.sequence import SampleLogprobs from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, identity, is_cpu) @@ -469,7 +470,7 @@ def generate_greedy_logprobs_limit( audios: Optional[PromptAudioInput] = None, videos: Optional[List[np.ndarray]] = None, **kwargs: Any, - ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + ) -> List[TokensTextLogprobs]: all_logprobs: List[List[Dict[int, float]]] = [] all_output_ids: List[List[int]] = [] all_output_strs: List[str] = [] @@ -525,7 +526,7 @@ def generate_encoder_decoder_greedy_logprobs_limit( max_tokens: int, num_logprobs: int, **kwargs: Any, - ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + ) -> List[TokensTextLogprobs]: ''' Greedy logprobs generation for vLLM encoder/decoder models ''' @@ -653,14 +654,16 @@ def generate( @staticmethod def _final_steps_generate_w_logprobs( req_outputs: List[RequestOutput], - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] + ) -> List[TokensTextLogprobsPromptLogprobs]: + outputs: List[TokensTextLogprobsPromptLogprobs] = [] for req_output in req_outputs: + assert len(req_output.outputs) > 0 for sample in req_output.outputs: output_str = sample.text output_ids = list(sample.token_ids) output_logprobs = sample.logprobs - outputs.append((output_ids, output_str, output_logprobs)) + outputs.append((output_ids, output_str, output_logprobs, + req_output.prompt_logprobs)) return outputs def generate_w_logprobs( @@ -670,7 +673,8 @@ def generate_w_logprobs( images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: assert sampling_params.logprobs is not None if images is not None: @@ -695,13 +699,20 @@ def generate_w_logprobs( req_outputs = self.model.generate(inputs, sampling_params=sampling_params) - return self._final_steps_generate_w_logprobs(req_outputs) + + toks_str_logsprobs_prompt_logprobs = ( + self._final_steps_generate_w_logprobs(req_outputs)) + # Omit prompt logprobs if not required by sampling params + return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] + if sampling_params.prompt_logprobs is None else + toks_str_logsprobs_prompt_logprobs) def generate_encoder_decoder_w_logprobs( self, encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], sampling_params: SamplingParams, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: ''' Logprobs generation for vLLM encoder/decoder models ''' @@ -709,7 +720,12 @@ def generate_encoder_decoder_w_logprobs( assert sampling_params.logprobs is not None req_outputs = self.model.generate(encoder_decoder_prompts, sampling_params=sampling_params) - return self._final_steps_generate_w_logprobs(req_outputs) + toks_str_logsprobs_prompt_logprobs = ( + self._final_steps_generate_w_logprobs(req_outputs)) + # Omit prompt logprobs if not required by sampling params + return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] + if sampling_params.prompt_logprobs is None else + toks_str_logsprobs_prompt_logprobs) def generate_greedy( self, @@ -727,44 +743,48 @@ def generate_greedy_logprobs( prompts: List[str], max_tokens: int, num_logprobs: int, + num_prompt_logprobs: Optional[int] = None, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, stop_token_ids: Optional[List[int]] = None, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - greedy_logprobs_params = SamplingParams(temperature=0.0, - max_tokens=max_tokens, - logprobs=num_logprobs, - stop_token_ids=stop_token_ids) - outputs = self.generate_w_logprobs(prompts, - greedy_logprobs_params, - images=images, - audios=audios, - videos=videos) - - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: + greedy_logprobs_params = SamplingParams( + temperature=0.0, + max_tokens=max_tokens, + logprobs=num_logprobs, + prompt_logprobs=(num_prompt_logprobs), + stop_token_ids=stop_token_ids) + + return self.generate_w_logprobs(prompts, + greedy_logprobs_params, + images=images, + audios=audios, + videos=videos) def generate_encoder_decoder_greedy_logprobs( self, encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, num_logprobs: int, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - greedy_logprobs_params = SamplingParams(temperature=0.0, - use_beam_search=False, - max_tokens=max_tokens, - logprobs=num_logprobs) + num_prompt_logprobs: Optional[int] = None, + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: + greedy_logprobs_params = SamplingParams( + temperature=0.0, + use_beam_search=False, + max_tokens=max_tokens, + logprobs=num_logprobs, + prompt_logprobs=(num_prompt_logprobs), + ) ''' Greedy logprobs generation for vLLM encoder/decoder models ''' - outputs = self.generate_encoder_decoder_w_logprobs( + return self.generate_encoder_decoder_w_logprobs( encoder_decoder_prompts, greedy_logprobs_params) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] - def generate_beam_search( self, prompts: List[str], diff --git a/tests/models/utils.py b/tests/models/utils.py index 93ec03995094b..8e31a1d6eefed 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,7 +1,7 @@ import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union -from vllm.sequence import Logprob, SampleLogprobs +from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs TokensText = Tuple[List[int], str] @@ -34,20 +34,47 @@ def check_outputs_equal( assert output_ids_0 == output_ids_1, fail_msg +# Representation of generated sequence as a tuple of +# * Token ID list +# * String +# * List of top sample logprobs for each sampled token +# +# Assumes prompt logprobs were not requested. TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]]] -# Allow for tokens to be represented as str's rather than IDs +# Allow for tokens to be represented as str's rather than IDs; +# tuple of +# * Token string representations list +# * String +# * Optional list of top sample logprobs for each sampled token +# +# Assumes prompt logprobs were not requested. TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]], List[Dict[str, Logprob]]]]] +# Representation of generated sequence as a tuple of +# * Token ID list +# * String +# * Optional list of top sample logprobs for each sampled token +# * Optional list of top prompt logprobs for each prompt token +# +# Allows prompt logprobs to be requested. +TokensTextLogprobsPromptLogprobs = Tuple[ + List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]], + Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]] + def check_logprobs_close( *, - outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]], - outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]], + outputs_0_lst: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs, + TextTextLogprobs]], + outputs_1_lst: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs, + TextTextLogprobs]], name_0: str, name_1: str, num_outputs_0_skip_tokens: int = 0, @@ -57,6 +84,18 @@ def check_logprobs_close( """Compare the logprobs of two sequences generated by different models, which should be similar but not necessarily equal. + How sample logprobs are compared: + * `always_check_logprobs == True`: set of highest-logprob token ids + must match between seq0 and seq1 at all sampled token offsets + * `always_check_logprobs == False`: highest-logprob token ids are + only compared at sampled token offsets for which generated token + ids don't match + + Prompt logprobs must be provided either for both input sequences, or + for neither. If prompt logprobs are provided, then highest-logprob + prompt token ids must match between seq0 and seq1 at all prompt token + offsets. + Args: outputs_0_lst: First sequence to compare outputs_0_lst: Second sequence to compare @@ -78,8 +117,65 @@ def check_logprobs_close( for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)): - output_ids_0, output_str_0, logprobs_0 = outputs_0 - output_ids_1, output_str_1, logprobs_1 = outputs_1 + assert len(outputs_0) == len(outputs_1) + if len(outputs_0) == 3: + assert len(outputs_1) == 3 + # Break out tokens, text & sample logprobs + # (prompt logprobs were not provided) + output_ids_0, output_str_0, logprobs_0 = outputs_0 + output_ids_1, output_str_1, logprobs_1 = outputs_1 + elif len(outputs_0) == 4: + assert len(outputs_1) == 4 + # Break out tokens, text, sample logprobs & prompt logprobs + ( + output_ids_0, + output_str_0, + logprobs_0, + prompt_logprobs_0, + ) = outputs_0 + ( + output_ids_1, + output_str_1, + logprobs_1, + prompt_logprobs_1, + ) = outputs_1 + + # Test prompt logprobs closeness + if (prompt_logprobs_0 is not None + and prompt_logprobs_1 is not None): + # Both sequences' prompt logprobs lists are not `None`` + # (although individual list elements may be `None`); + # for each token's logprobs: + for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate( + zip(prompt_logprobs_0, prompt_logprobs_1)): + fail_msg = ( + f"Prompt logprobs test:" + f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}" + f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}") + + if logprobs_elem_0 is None: + # If the seq 0 token's logprobs are `None`, + # the seq 1 token's logprobs must be `None` + assert logprobs_elem_1 is None, fail_msg + else: + # If the seq 0 token's logprobs are not `None`, + # the seq 1 token's logprobs must not be `None` + assert logprobs_elem_1 is not None, fail_msg + # Logprobs check: top-k token choices must be the same + assert (set(logprobs_elem_0.keys()) == set( + logprobs_elem_1.keys())), fail_msg + else: + # Both sequence logprobs lists must be `None` + fail_msg = (f"Prompt logprobs test:" + f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}" + f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}") + + assert (prompt_logprobs_0 is None + and prompt_logprobs_1 is None), fail_msg + else: + raise ValueError(f"Outputs tuple must have 3 or 4 elements but " + f"{len(outputs_0)} elements were provided: " + f"{outputs_0}") if logprobs_0 is None: logprobs_0 = [None] * len(output_ids_0) diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index 24ebb60a9cbfd..c5dc81cc25622 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -100,3 +100,95 @@ def test_multi_step_llm( name_0="hf", name_1="vllm", ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("tp_size", [1]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) +@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)]) +def test_multi_step_llm_w_prompt_logprobs( + vllm_runner, + example_prompts, + model: str, + dtype: str, + tp_size: int, + max_tokens: int, + enforce_eager: int, + num_scheduler_steps: int, + num_prompts: int, + num_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], +) -> None: + """Test prompt logprobs with multi-step scheduling via sync LLM Engine. + + Set up a vLLM engine instance w/ single-step scheduling as a ground-truth + reference. + + Prompt them with the same example prompts. + + Validate: + * All generated logprobs are all very close + + Args: + hf_runner: HF transformers model runner fixture + vllm_runner: vLLM model runner fixture + example_prompts: test fixture providing example prompts + model: model under test (same for single- and multi-step engines) + dtype: tensor datatype for engine to utilize + tp_size: degree of tensor-parallelism + max_tokens: the maximum number of tokens to generate + enforce_eager + num_scheduler_steps: for multi-step scheduling, GPU-side steps per + GPU -> CPU output transfer + num_prompts: number of example prompts under test + num_logprobs: corresponds to the `logprobs` argument to the OpenAI + completions endpoint; `None` -> no logprobs + num_prompt_logprobs: number of logprobs to return for each prompt token; + note that this argument is not supported by the + OpenAI completions endpoint. + """ + + prompts = example_prompts + if len(prompts) < num_prompts: + prompts = prompts * ((num_prompts // len(prompts)) + 1) + prompts = prompts[:num_prompts] + assert len(prompts) == num_prompts + + with vllm_runner( + model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + use_v2_block_manager=True, + num_scheduler_steps=num_scheduler_steps, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs, + num_prompt_logprobs=num_prompt_logprobs) + + with vllm_runner( + model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + ) as vllm_model: + single_step_vllm_outputs = vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs, + num_prompt_logprobs=num_prompt_logprobs) + + check_logprobs_close( + outputs_0_lst=single_step_vllm_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/utils.py b/tests/utils.py index 81442cad78da2..43825e8138362 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -493,6 +493,7 @@ async def completions_with_server_args( ''' outputs = None + max_wait_seconds = 240 * 3 # 240 is default with RemoteOpenAIServer(model_name, server_cli_args, max_wait_seconds=max_wait_seconds) as server: @@ -503,7 +504,7 @@ async def completions_with_server_args( stream=False, max_tokens=5, logprobs=num_logprobs) - assert outputs is not None + assert outputs is not None, "Completion API call failed." return outputs diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index b900eb5a610ff..ebcafbbab119a 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -614,34 +614,66 @@ def _pythonize_sampler_output( frozen_model_input = model_input.frozen_model_input assert frozen_model_input.sampling_metadata is not None + sampling_metadata = frozen_model_input.sampling_metadata # samples generation should have been skipped assert not output.outputs pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries] - # CPU GPU sync - pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False) + # We guarantee output tensors are ready, so it is safe to + # pythonize the sampler output & obtain CPU-side logprobs. + # + # However we should check whether logprobs pythonization may + # be skipped entirely, i.e. because no logprobs were requested + # or pythonization was not deferred. To that end, + # + # * `prompt_logprobs_are_requested_for_prefill` signals that + # there are *any* prefill-phase requests which specify that + # prompt logprobs should be returned. + # + # * `any_logprobs_are_requested` signals that there are any + # requests which (1) specify that sample logprobs should be + # returned, or (2) are in the prefill phase AND specify that + # prompt logprobs should be returned. + # + # Later on, these flags cause adjustments to the pythonization + # process to accommodate logprobs. + + seq_groups = sampling_metadata.seq_groups + prompt_logprobs_are_requested_for_prefill = any([ + sg.sampling_params.prompt_logprobs is not None and sg.is_prompt + for sg in seq_groups + ]) + any_logprobs_are_requested = ( + prompt_logprobs_are_requested_for_prefill + or any([sg.sampling_params.logprobs is not None for sg in seq_groups])) + + if prompt_logprobs_are_requested_for_prefill: + # CPU GPU sync, after gathering *only* sampled tokens (since + # requesting prompt logprobs leads `sampled_token_ids` to + # include prompt token ids in addition to sampled token ids.) + sample_idx_tensor = torch.tensor( + [sdx for sg in seq_groups for sdx in sg.sample_indices]) + pinned_buffer = pinned_buffer.copy_( + sampled_token_ids[sample_idx_tensor, :], non_blocking=False) + else: + # CPU GPU sync + pinned_buffer = pinned_buffer.copy_(sampled_token_ids, + non_blocking=False) # this will not block as the tensors are already on CPU samples_list = pinned_buffer.tolist() - sampling_metadata = frozen_model_input.sampling_metadata - skip_sampler_cpu_output = ( frozen_model_input.sampling_metadata.skip_sampler_cpu_output) - # We are guaranteed output tensors are ready, so it is safe to - # pythonize the sampler output & obtain CPU-side logprobs. - # - # However this computation may be skipped entirely - # if no pythonization was deferred. - seq_groups = sampling_metadata.seq_groups - logprobs_are_requested = any([ - sg.sampling_params.logprobs is not None - or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups - ]) + # *Don't* skip logprobs pythonization *if*: + # * Any requests require logprobs to be returned in this + # iteration AND + # * These requests are being scheduled in a fashion which + # defers pythonization (i.e. multi-step scheduling.) do_pythonize_logprobs = (skip_sampler_cpu_output - and logprobs_are_requested) + and any_logprobs_are_requested) ( prompt_logprobs, sample_logprobs, @@ -666,7 +698,7 @@ def _pythonize_sampler_output( prompt_logprobs[sgdx], sample_logprobs[sgdx], ) - elif logprobs_are_requested: + elif any_logprobs_are_requested: ( group_prompt_logprobs, group_sample_logprobs, @@ -696,7 +728,7 @@ def _pythonize_sampler_output( seq_output.parent_seq_id = seq_ids[parent_id] seq_output.output_token = next_token_id - if logprobs_are_requested: + if any_logprobs_are_requested: seq_output.logprobs = group_sample_logprobs[tdx] else: logprobs = next(iter(seq_output.logprobs.values())) @@ -714,7 +746,7 @@ def _pythonize_sampler_output( seq_outputs.append( SequenceOutput(seq_ids[parent_id], next_token_id, (group_sample_logprobs[tdx] - if logprobs_are_requested else { + if any_logprobs_are_requested else { next_token_id: Logprob(logprob=float('inf'), rank=None, @@ -722,12 +754,12 @@ def _pythonize_sampler_output( }))) if cache is not None: completion_seq_group_output.prompt_logprobs = \ - group_prompt_logprobs if logprobs_are_requested else None + group_prompt_logprobs if any_logprobs_are_requested else None output.outputs.append(completion_seq_group_output) else: output.outputs.append( CompletionSequenceGroupOutput( seq_outputs, (group_prompt_logprobs - if logprobs_are_requested else None))) + if any_logprobs_are_requested else None))) assert len(output.outputs) > 0