diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py index 9572aac7df6e..6fbe8c11d76f 100644 --- a/tests/spec_decode/e2e/test_logprobs.py +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -22,10 +22,12 @@ }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_model": "JackFram/llama-160m", - "num_speculative_tokens": 3, -}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + "disable_logprobs_during_spec_decoding": False, + }]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( "output_len", @@ -59,10 +61,12 @@ def test_logprobs_equality(baseline_llm_generator, test_llm_generator, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_model": "JackFram/llama-160m", - "num_speculative_tokens": 3, -}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + "disable_logprobs_during_spec_decoding": False, + }]) @pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize("num_logprobs", [6]) @pytest.mark.parametrize( @@ -99,13 +103,16 @@ def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_model": "JackFram/llama-160m", - "num_speculative_tokens": 3, -}, { - "speculative_model": "JackFram/llama-160m", - "num_speculative_tokens": 6, -}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + "disable_logprobs_during_spec_decoding": False, + }, { + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 6, + "disable_logprobs_during_spec_decoding": False, + }]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( "output_len", @@ -143,6 +150,7 @@ def test_logprobs_different_k(baseline_llm_generator, test_llm_generator, [{ "speculative_model": "JackFram/llama-160m", "num_speculative_tokens": 3, + "disable_logprobs_during_spec_decoding": False, # Artificially limit the draft model max model len; this forces vLLM # to skip speculation once the sequences grow beyond 32-k tokens. 
@@ -181,10 +189,12 @@ def test_logprobs_when_skip_speculation(baseline_llm_generator, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_model": "JackFram/llama-160m", - "num_speculative_tokens": 3, -}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + "disable_logprobs_during_spec_decoding": False, + }]) @pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize( "output_len", diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py index 1f3219593f96..aa49a3aee62a 100644 --- a/tests/spec_decode/test_dynamic_spec_decode.py +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -32,6 +32,7 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int, scorer_worker=target_worker, spec_decode_sampler=mock_spec_decode_sampler( acceptance_sampler_method), + disable_logprobs=False, metrics_collector=metrics_collector, disable_by_batch_size=disable_by_batch_size) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 0baac32042ef..671c9bef294f 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -381,6 +381,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool, worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, + disable_logprobs=False, metrics_collector=metrics_collector) worker.init_device() @@ -479,7 +480,8 @@ def test_k_equals_zero(k: int, batch_size: int, worker = SpecDecodeWorker( draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) + mock_spec_decode_sampler(acceptance_sampler_method), False, + metrics_collector) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -490,9 +492,10 @@ def test_k_equals_zero(k: int, batch_size: int, out = worker.execute_model(execute_model_req=execute_model_req) assert len(out) == 1, f"expected only one token output when {k=}" - assert out[0].probs is None, "expect gpu tensor references to be None" + assert out[0].sampled_token_probs is None, ( + "expect gpu tensor references to be None") assert out[ - 0].sampled_tokens is None, "expect gpu tensor references to be None" + 0].sampled_token_ids is None, "expect gpu tensor references to be None" draft_worker.execute_model.assert_called_once_with(execute_model_req) target_worker.execute_model.assert_called_once_with(execute_model_req) @@ -524,7 +527,8 @@ def test_empty_input_batch(k: int, batch_size: int, worker = SpecDecodeWorker( draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) + mock_spec_decode_sampler(acceptance_sampler_method), False, + metrics_collector) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -535,9 +539,10 @@ def test_empty_input_batch(k: int, batch_size: int, out = worker.execute_model(execute_model_req=execute_model_req) assert len(out) == 1, f"expected only one token output when {k=}" - assert out[0].probs is None, "expect gpu tensor references to be None" + assert out[0].sampled_token_probs is None, ( + "expect gpu tensor references to be None") assert out[ - 0].sampled_tokens is None, "expect gpu tensor references to be None" + 0].sampled_token_ids is None, "expect gpu tensor references to be None" 
draft_worker.execute_model.assert_called_once_with(execute_model_req) target_worker.execute_model.assert_called_once_with(execute_model_req) @@ -556,7 +561,7 @@ def test_init_device(acceptance_sampler_method: str): metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, - metrics_collector) + False, metrics_collector) worker.init_device() draft_worker.init_device.assert_called_once() @@ -707,6 +712,7 @@ def test_populate_seq_ids_with_bonus_tokens(): worker = SpecDecodeWorker(draft_worker, target_worker, mock_spec_decode_sampler("rejection_sampler"), + disable_logprobs=False, metrics_collector=metrics_collector) # Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs. # This set includes all sequence IDs in the batch as well as an additional diff --git a/vllm/config.py b/vllm/config.py index 81ef9526c8b9..46528a548de1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -894,6 +894,7 @@ def maybe_create_spec_config( draft_token_acceptance_method: str, typical_acceptance_sampler_posterior_threshold: Optional[float], typical_acceptance_sampler_posterior_alpha: Optional[float], + disable_logprobs: Optional[bool], ) -> Optional["SpeculativeConfig"]: """Create a SpeculativeConfig if possible, else return None. @@ -943,6 +944,11 @@ def maybe_create_spec_config( typical_acceptance_sampler_posterior_alpha (Optional[float]): A scaling factor for the entropy-based threshold in the TypicalAcceptanceSampler. + disable_logprobs (Optional[bool]): If set to True, token log + probabilities are not returned during speculative decoding. + If set to False, token log probabilities are returned + according to the log probability settings in SamplingParams. + If not specified, it defaults to True. Returns: Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if @@ -1055,6 +1061,8 @@ def maybe_create_spec_config( typical_acceptance_sampler_posterior_threshold = 0.09 if typical_acceptance_sampler_posterior_alpha is None: typical_acceptance_sampler_posterior_alpha = 0.3 + if disable_logprobs is None: + disable_logprobs = True return SpeculativeConfig( draft_model_config, @@ -1068,6 +1076,7 @@ def maybe_create_spec_config( typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=\ typical_acceptance_sampler_posterior_alpha, + disable_logprobs=disable_logprobs ) @staticmethod @@ -1152,6 +1161,7 @@ def __init__( draft_token_acceptance_method: str, typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_alpha: float, + disable_logprobs: bool, ): """Create a SpeculativeConfig object. @@ -1178,6 +1188,12 @@ def __init__( typical_acceptance_sampler_posterior_alpha (Optional[float]): A scaling factor for the entropy-based threshold in the TypicalAcceptanceSampler. + disable_logprobs: If set to True, token log probabilities will not + be returned even if requested by sampling parameters. This + reduces latency by skipping logprob calculation in proposal + sampling, target sampling, and after accepted tokens are + determined. If set to False, log probabilities will be + returned. 
""" self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config @@ -1191,6 +1207,7 @@ def __init__( typical_acceptance_sampler_posterior_threshold self.typical_acceptance_sampler_posterior_alpha = \ typical_acceptance_sampler_posterior_alpha + self.disable_logprobs = disable_logprobs self._verify_args() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 27a051fcbb2e..972d4e0cd994 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -110,6 +110,7 @@ class EngineArgs: typical_acceptance_sampler_posterior_threshold: Optional[float] = None typical_acceptance_sampler_posterior_alpha: Optional[float] = None qlora_adapter_name_or_path: Optional[str] = None + disable_logprobs_during_spec_decoding: Optional[bool] = None otlp_traces_endpoint: Optional[str] = None @@ -592,6 +593,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'to sqrt of --typical-acceptance-sampler-posterior-threshold ' 'i.e. 0.3') + parser.add_argument( + '--disable-logprobs-during-spec-decoding', + type=bool, + default=EngineArgs.disable_logprobs_during_spec_decoding, + help='If set to True, token log probabilities are not returned ' + 'during speculative decoding. If set to False, log probabilities ' + 'are returned according to the settings in SamplingParams. If ' + 'not specified, it defaults to True. Disabling log probabilities ' + 'during speculative decoding reduces latency by skipping logprob ' + 'calculation in proposal sampling, target sampling, and after ' + 'accepted tokens are determined.') + parser.add_argument('--model-loader-extra-config', type=nullable_str, default=EngineArgs.model_loader_extra_config, @@ -736,6 +749,7 @@ def create_engine_config(self, ) -> EngineConfig: typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=self. 
typical_acceptance_sampler_posterior_alpha, + disable_logprobs=self.disable_logprobs_during_spec_decoding, ) scheduler_config = SchedulerConfig( diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 2fafc1134b1c..8cf0aa5b8981 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -14,7 +14,7 @@ TypicalAcceptanceSampler) from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, HiddenStates, SamplerOutput, SequenceGroupMetadata, - get_all_seq_ids_and_request_ids) + get_all_seq_ids, get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, @@ -26,6 +26,7 @@ from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker +from vllm.spec_decode.target_model_runner import TargetModelRunner from vllm.spec_decode.util import (create_sequence_group_output, get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, @@ -44,9 +45,15 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": speculative_config: SpeculativeConfig = kwargs.get("speculative_config") assert speculative_config is not None + draft_worker_kwargs = kwargs.copy() + + kwargs["model_runner_cls"] = TargetModelRunner target_worker = Worker(*args, **kwargs) + # Set the disable_logprobs variable in the TargetModelRunner instance + # as per its value specified in the SpeculativeConfig. + target_worker.model_runner.disable_logprobs =\ + speculative_config.disable_logprobs - draft_worker_kwargs = kwargs.copy() # Override draft-model specific worker args. draft_worker_kwargs.update( model_config=speculative_config.draft_model_config, @@ -67,7 +74,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": typical_acceptance_sampler_posterior_threshold=speculative_config. typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=speculative_config. - typical_acceptance_sampler_posterior_alpha) + typical_acceptance_sampler_posterior_alpha, + disable_logprobs=speculative_config.disable_logprobs) return spec_decode_worker @@ -107,6 +115,7 @@ def create_worker( draft_token_acceptance_method: str, typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_alpha: float, + disable_logprobs: bool, ) -> "SpecDecodeWorker": allow_zero_draft_token_step = True @@ -161,6 +170,7 @@ def create_worker( return SpecDecodeWorker( proposer_worker, scorer_worker, + disable_logprobs=disable_logprobs, disable_by_batch_size=disable_by_batch_size, spec_decode_sampler=spec_decode_sampler, allow_zero_draft_token_step=allow_zero_draft_token_step) @@ -170,6 +180,7 @@ def __init__( proposer_worker: ProposerWorkerBase, scorer_worker: WorkerBase, spec_decode_sampler: SpecDecodeBaseSampler, + disable_logprobs: bool, metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, allow_zero_draft_token_step: Optional[bool] = True, @@ -189,6 +200,9 @@ def __init__( types of sampler namely RejectionSampler and TypicalAcceptanceSampler. 'spec_decode_sampler' is either an instance of RejectionSampler or TypicalAcceptanceSampler. 
+ disable_logprobs: If set to True, token log probabilities will + not be output in both the draft worker and the target worker. + If set to False, log probabilities will be output by both. disable_by_batch_size: If the batch size is larger than this, disable speculative decoding for new incoming requests. metrics_collector: Helper class for collecting metrics; can be set @@ -222,6 +236,7 @@ def __init__( # Hidden states from target model to pass to proposer # in the subsequent step. self.previous_hidden_states: Optional[HiddenStates] = None + self._disable_logprobs = disable_logprobs def init_device(self) -> None: """Initialize both scorer and proposer models. @@ -357,7 +372,6 @@ def execute_model( ) == 0 or disable_all_speculation: return self._run_no_spec(execute_model_req, skip_proposer=disable_all_speculation) - return self._run_speculative_decoding_step(execute_model_req, num_lookahead_slots) @@ -391,6 +405,42 @@ def _maybe_disable_speculative_tokens( # this state within spec decode worker. seq_group_metadata.num_speculative_tokens = 0 + def _serialize_sampler_output_no_logprobs( + self, execute_model_req: ExecuteModelRequest, + sampler_output: SamplerOutput) -> SamplerOutput: + """ + Creates and returns a `SamplerOutput` with only the sampled token IDs + being serialized to CPU & populated in `CompletionSequenceGroupOutput`. + All other parameters in `CompletionSequenceGroupOutput` related to log + probabilities are skipped. + + Args: + execute_model_req (ExecuteModelRequest): The model request that + was executed. + sampler_output (SamplerOutput): The output from the sampler with + only GPU tensors populated. + + Returns: + SamplerOutput: A new `SamplerOutput` instance containing a list of + `CompletionSequenceGroupOutput` objects with only sampled token + IDs populated. + """ + seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list) + sampled_token_ids_list = sampler_output.sampled_token_ids.tolist() + completion_seq_group_output_list: List[ + CompletionSequenceGroupOutput] = [] + for index, seq_id in enumerate(seq_ids): + completion_seq_group_output_list.append( + create_sequence_group_output( + token_id=sampled_token_ids_list[index][0], + token_id_logprob_rank=-1, + token_id_logprob=0.0, + seq_id=seq_id, + topk_token_ids=[], + topk_logprobs=[], + )) + return SamplerOutput(outputs=completion_seq_group_output_list) + @nvtx_range("spec_decode_worker._run_no_spec") def _run_no_spec(self, execute_model_req: ExecuteModelRequest, skip_proposer: bool) -> List[SamplerOutput]: @@ -417,12 +467,17 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, self.previous_hidden_states.update( execute_model_req.seq_group_metadata_list, hidden_states) + sampler_output_to_return = (self._serialize_sampler_output_no_logprobs( + execute_model_req=execute_model_req, sampler_output=sampler_output) + if self._disable_logprobs else + sampler_output) + # Clear device tensors from sampler output. This reduces communication # overhead when the engine runs in a different process than the workers. - sampler_output.probs = None - sampler_output.sampled_tokens = None + sampler_output.sampled_token_probs = None + sampler_output.sampled_token_ids = None sampler_output.logprobs = None - return [sampler_output] + return [sampler_output_to_return] def _run_non_driver_rank(self) -> bool: """Run proposer and verifier model in non-driver workers. 
This is used @@ -480,7 +535,6 @@ def _run_speculative_decoding_step( execute_model_req, proposals, ) - accepted_token_ids, target_logprobs = self._verify_tokens( execute_model_req.seq_group_metadata_list, proposal_scores, proposals, execute_model_req.num_lookahead_slots) @@ -601,25 +655,27 @@ def _create_output_sampler_list( the same number of outputs. """ batch_size, num_steps = accepted_token_ids.shape - - # Organize input tensors by step instead of by sequence. - target_logprobs_by_step = target_logprobs.transpose(0, 1) accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1) - - # Get the logprobs/rank of the accepted tokens. - (accepted_token_id_ranks_by_step, - accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs( - logprob_tensor=target_logprobs_by_step, - sampled_token_ids=accepted_token_ids_by_step, - ) - - # Get the top-k logprobs (which may or may not include the logprob of - # the accepted token). - (topk_logprobs_by_step, - topk_indices_by_step) = target_logprobs_by_step.topk( - k=self.scorer_worker.model_config.max_logprobs, - dim=-1, - ) + if self._disable_logprobs: + # We are skipping the logprobs. Hence don't serialize the + # logprobs related tensors from the GPU. Instead create + # empty/dummy lists. + (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step, + topk_logprobs_by_step, topk_indices_by_step) =\ + self._create_dummy_logprob_lists( + batch_size, num_steps, + self.scorer_worker.model_config.max_logprobs) + else: + # Organize input tensors by step instead of by sequence. + target_logprobs_by_step = target_logprobs.transpose(0, 1) + # Serialize all tensors into Python lists. + (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step, + topk_logprobs_by_step, topk_indices_by_step) =\ + self._create_logprob_lists_from_tensors( + target_logprobs_by_step, accepted_token_ids_by_step, + self.scorer_worker.model_config.max_logprobs) # Get the sequence ids and num_logprobs (sampling parameter) in the # batch. @@ -628,14 +684,8 @@ def _create_output_sampler_list( num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list) - # Serialize all tensors to CPU Python lists. + # Serialize tensor to CPU Python list. accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() - accepted_token_id_ranks_by_step = ( - accepted_token_id_ranks_by_step.tolist()) - accepted_token_id_logprobs_by_step = ( - accepted_token_id_logprobs_by_step.tolist()) - topk_logprobs_by_step = topk_logprobs_by_step.tolist() - topk_indices_by_step = topk_indices_by_step.tolist() # Construct the output on a per-step, per-sequence basis. sampler_output_list: List[SamplerOutput] = [] @@ -677,6 +727,108 @@ def _create_output_sampler_list( 0].spec_decode_worker_metrics = maybe_rejsample_metrics return sampler_output_list + def _create_dummy_logprob_lists( + self, + batch_size: int, + num_steps: int, + num_top_k: int, + ) -> Tuple[List[List[int]], List[List[float]], + List[List[List[Optional[float]]]], + List[List[List[Optional[int]]]]]: + """ + Creates and returns four dummy lists representing token probabilities + and their ranks. + + This method initializes and returns: + - The ranks of the accepted tokens, shaped (num_steps, batch_size) + - The log probabilities of the accepted tokens, + shaped (num_steps, batch_size) + - The log probabilities of the top k tokens, + shaped (num_steps, batch_size, num_top_k) + - The token IDs of the top k tokens, + shaped (num_steps, batch_size, num_top_k) + + Args: + batch_size (int): The size of the batch. 
+ num_steps (int): The number of steps in the sequence. + num_top_k (int): The number of top-k token log probabilities to + return. + + Returns: + A tuple containing four dummy lists as described above. + """ + accepted_token_id_ranks_by_step = [[-1] * batch_size + for _ in range(num_steps)] + accepted_token_id_logprobs_by_step = [[0.0] * batch_size + for _ in range(num_steps)] + topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[ + [None] * num_top_k for _ in range(batch_size) + ] for _ in range(num_steps)] + topk_indices_by_step: List[List[List[Optional[int]]]] = [[ + [None] * num_top_k for _ in range(batch_size) + ] for _ in range(num_steps)] + return (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step, topk_logprobs_by_step, + topk_indices_by_step) + + def _create_logprob_lists_from_tensors( + self, + target_logprobs_by_step: torch.Tensor, + accepted_token_ids_by_step: torch.Tensor, + num_top_k: int, + ) -> Tuple[List[List[int]], List[List[float]], + List[List[List[Optional[float]]]], + List[List[List[Optional[int]]]]]: + """ + Creates and returns four lists representing token probabilities and + their ranks. + + This method initializes and returns four lists containing: + - The ranks of the accepted tokens, shaped (num_steps, batch_size) + - The log probabilities of the accepted tokens, + shaped (num_steps, batch_size) + - The log probabilities of the top k tokens, + shaped (num_steps, batch_size, num_top_k) + - The token IDs of the top k tokens, + shaped (num_steps, batch_size, num_top_k) + + Args: + target_logprobs_by_step (torch.Tensor): Tensor representing the + log probabilities of the target model, + shaped (num_steps, batch_size, vocab_size) + accepted_token_ids_by_step (torch.Tensor): Tensor representing + the accepted token_ids, shaped (num_steps, batch_size) + num_top_k (int): The number of top-k token log probabilities to + return. + + Returns: + A tuple containing the lists as described above. + """ + # Serialize all tensors to CPU Python lists. + # Get the logprobs/rank of the accepted tokens. + (accepted_token_id_ranks_by_step_tensor, + accepted_token_id_logprobs_by_step_tensor + ) = get_sampled_token_logprobs( + logprob_tensor=target_logprobs_by_step, + sampled_token_ids=accepted_token_ids_by_step, + ) + # Get the top-k logprobs (which may or may not include the + # logprob of the accepted token). 
+ (topk_logprobs_by_step_tensor, + topk_indices_by_step_tensor) = target_logprobs_by_step.topk( + k=num_top_k, + dim=-1, + ) + accepted_token_id_ranks_by_step = ( + accepted_token_id_ranks_by_step_tensor.tolist()) + accepted_token_id_logprobs_by_step = ( + accepted_token_id_logprobs_by_step_tensor.tolist()) + topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist() + topk_indices_by_step = topk_indices_by_step_tensor.tolist() + return (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step, topk_logprobs_by_step, + topk_indices_by_step) + def _track_finished_requests(self, execute_model_req: ExecuteModelRequest): """ Removes the finished requests and their associated sequence ids from diff --git a/vllm/spec_decode/target_model_runner.py b/vllm/spec_decode/target_model_runner.py new file mode 100644 index 000000000000..957f2f8c8843 --- /dev/null +++ b/vllm/spec_decode/target_model_runner.py @@ -0,0 +1,69 @@ +from typing import List, Optional + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, MultiModalConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig) +from vllm.sequence import SequenceGroupMetadata +from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, + ModelRunner) + + +class TargetModelRunner(ModelRunner): + """Specialized model runner for speculative decoding target model. + In speculative decoding, the log probabilities selected finally may not + be the same ones as selected by the target model sampling. This means + that the time spent in the log probability calculation of the target model + is time wasted, since we calculate log probabilities after deciding which + tokens are accepted. For this reason disabling log probabilities in the + target model will make decode faster. The model runner sets the + SamplingMetadata parameters according to whether log probabilities are + requested or not. + """ + + def __init__(self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + multimodal_config: Optional[MultiModalConfig] = None, + return_hidden_states: bool = False): + # An internal boolean member variable to indicate if token log + # probabilities are needed or not. + self.disable_logprobs = True + super().__init__( + model_config=model_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + cache_config=cache_config, + load_config=load_config, + lora_config=lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + multimodal_config=multimodal_config, + prompt_adapter_config=prompt_adapter_config, + return_hidden_states=return_hidden_states, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForGPUWithSamplingMetadata: + model_input: ModelInputForGPUWithSamplingMetadata = super( + ).prepare_model_input(seq_group_metadata_list, virtual_engine, + finished_requests_ids) + # If token log probabilities is disabled then skip generating sampler + # CPU output. We directly serialize the GPU sampled_token_id tensors + # as needed. 
If log probabilities are enabled then synchronize all the
+        # sampling related tensors which include the logprobs tensors.
+        model_input.sampling_metadata.skip_sampler_cpu_output = (
+            self.disable_logprobs)
+        return model_input
diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py
index 80710419e602..ade546eef264 100644
--- a/vllm/spec_decode/util.py
+++ b/vllm/spec_decode/util.py
@@ -1,5 +1,5 @@
 from contextlib import contextmanager
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 
@@ -53,8 +53,8 @@ def create_sequence_group_output(
         token_id_logprob_rank: int,
         token_id_logprob: float,
         seq_id: SeqId,
-        topk_token_ids: List[int],
-        topk_logprobs: List[float],
+        topk_token_ids: List[Optional[int]],
+        topk_logprobs: List[Optional[float]],
 ) -> CompletionSequenceGroupOutput:
     """Create a SequenceGroupOutput given the sampling results.
 
@@ -68,7 +68,7 @@ def create_sequence_group_output(
     """
     # vLLM logprobs always include the sampled token. In addition, the user may
     # request topk-logprobs (where top-k varies per user up to max_logprobs).
-    logprobs: Dict[int, Logprob] = {
+    logprobs: Dict[Optional[int], Logprob] = {
         token_id: Logprob(
             logprob=token_id_logprob,
             rank=token_id_logprob_rank,
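
The flag added above flows from EngineArgs into SpeculativeConfig.disable_logprobs and from there into SpecDecodeWorker and TargetModelRunner. A minimal usage sketch (not taken from the patch) with the offline LLM entrypoint; the draft model is the one used in the tests above, the target model name is only a placeholder, and use_v2_block_manager is assumed because speculative decoding requires it at this point:

from vllm import LLM, SamplingParams

# Keep logprobs during speculative decoding. If the flag is left unspecified
# it defaults to True, i.e. logprob calculation is skipped for lower latency.
llm = LLM(
    model="JackFram/llama-68m",  # placeholder target model
    speculative_model="JackFram/llama-160m",
    num_speculative_tokens=3,
    use_v2_block_manager=True,
    disable_logprobs_during_spec_decoding=False,
)

sampling_params = SamplingParams(max_tokens=32, logprobs=6)
outputs = llm.generate(["The future of AI is"], sampling_params)

# With the flag set to False these logprobs honor SamplingParams; with the
# default (True) the worker substitutes placeholders (rank -1, logprob 0.0).
print(outputs[0].outputs[0].logprobs)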
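
When logprobs are disabled, _serialize_sampler_output_no_logprobs and _create_dummy_logprob_lists emit fixed placeholder values instead of serializing the logprob tensors from the GPU. A short sketch of what a single placeholder sequence output looks like, built with the create_sequence_group_output helper patched above; the token and sequence ids are hypothetical:

from vllm.spec_decode.util import create_sequence_group_output

# Placeholder output for one sequence, mirroring the disabled-logprobs path:
# rank -1, logprob 0.0, and no top-k entries.
output = create_sequence_group_output(
    token_id=42,              # hypothetical sampled token id
    token_id_logprob_rank=-1,
    token_id_logprob=0.0,
    seq_id=0,                 # hypothetical sequence id
    topk_token_ids=[],
    topk_logprobs=[],
)

# The sampled token is still present; its Logprob entry carries the
# placeholder rank/logprob rather than real values.
print(output.samples[0].output_token, output.samples[0].logprobs)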