From 8b88b8a366689458c2c024ce5b5e658f41f0da05 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 14 Oct 2024 10:30:27 +0000 Subject: [PATCH] defer return_hidden_states speculation methods --- vllm/config.py | 8 ++++++++ vllm/spec_decode/batch_expansion.py | 4 ---- vllm/spec_decode/spec_decode_worker.py | 5 ++++- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index f80c51f2d1a7..7672e8b02a2c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1292,6 +1292,14 @@ def maybe_create_spec_config( "speculative_model unless the draft model config contains an " "n_predict parameter.") + if enable_chunked_prefill and draft_hf_config.model_type in [ + "medusa", "mlp_speculator", "eagle" + ]: + raise ValueError( + "Chunked prefill and hidden-state based draft models are not " + "yet compatible." + ) + if typical_acceptance_sampler_posterior_threshold is None: typical_acceptance_sampler_posterior_threshold = 0.09 if typical_acceptance_sampler_posterior_alpha is None: diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 67d6a6ad79b2..5a4b072c68f8 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -195,10 +195,6 @@ def _contract_batch( else: all_hidden_states = None - # TODO fix with `return_hidden_states=True` where hidden states are full size, - # and we'll need all indices prior to selecting `do_sample=True`, - # while logits are indexed by `selected_token_indices` True - # Rule out prefills that are in `non_spec_indices` but produce no tokens. 
non_spec_indices = [ idx for idx in non_spec_indices diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index d23665751cec..0901d3bdb3f0 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -584,6 +584,9 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, hidden_states = sampler_output.hidden_states if hidden_states is not None: # remove hidden_states for prompt tokens + # TODO Enable `return_hidden_states`: the hidden states of prefill + # chunks are pruned by the logits processor, and they must also be + # rearranged back into the full-prefill latent. Fix this to enable MLPSpeculator. if any(seq.is_prompt for seq in execute_model_req.seq_group_metadata_list): hidden_states = hidden_states[ @@ -698,7 +701,7 @@ def _run_speculative_decoding_step( # TODO skip this if chunking is not enabled if len(non_spec_indices): all_hidden_states = proposal_scores.hidden_states - # TODO fix `return_hidden_states` + # TODO fix `return_hidden_states`, same as in `_run_no_spec` if all_hidden_states is not None: prefill_hidden_states = all_hidden_states[non_spec_indices] execute_model_req.previous_hidden_states = prepare_prefill_hidden_states(