defer return_hidden_states speculation methods
NickLucche committed Oct 14, 2024
1 parent 2b472ea commit 8b88b8a
Showing 3 changed files with 12 additions and 5 deletions.
8 changes: 8 additions & 0 deletions vllm/config.py
@@ -1292,6 +1292,14 @@ def maybe_create_spec_config(
                     "speculative_model unless the draft model config contains an "
                     "n_predict parameter.")
 
+            if enable_chunked_prefill and draft_hf_config.model_type in [
+                    "medusa", "mlp_speculator", "eagle"
+            ]:
+                raise ValueError(
+                    "Chunked prefill and hidden-state based draft models are not "
+                    "yet compatible."
+                )
+
             if typical_acceptance_sampler_posterior_threshold is None:
                 typical_acceptance_sampler_posterior_threshold = 0.09
             if typical_acceptance_sampler_posterior_alpha is None:
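
The guard added above rejects the chunked-prefill plus hidden-state draft model combination at configuration time. For readers following along outside the diff, here is a minimal standalone sketch of the same check; the function and parameter names are illustrative, not vLLM's actual config API:

HIDDEN_STATE_DRAFT_TYPES = {"medusa", "mlp_speculator", "eagle"}

def check_spec_decode_compat(enable_chunked_prefill: bool,
                             draft_model_type: str) -> None:
    # These draft models consume the target model's hidden states, which
    # chunked prefill does not yet hand over in full, so refuse the combo.
    if enable_chunked_prefill and draft_model_type in HIDDEN_STATE_DRAFT_TYPES:
        raise ValueError(
            "Chunked prefill and hidden-state based draft models are not "
            "yet compatible.")

check_spec_decode_compat(False, "eagle")     # ok: chunked prefill disabled
# check_spec_decode_compat(True, "eagle")    # would raise ValueError
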
4 changes: 0 additions & 4 deletions vllm/spec_decode/batch_expansion.py
@@ -195,10 +195,6 @@ def _contract_batch(
         else:
             all_hidden_states = None
 
-        # TODO fix with `return_hidden_states=True` where hidden states are full size,
-        # and we'll need all indices prior to selecting `do_sample=True`,
-        # while logits are indexed by `selected_token_indices` True
-
         # Rule out prefills that are in `non_spec_indices` but produce no tokens.
         non_spec_indices = [
             idx for idx in non_spec_indices
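
For context on the comment retained above: `non_spec_indices` marks the batch entries scored without speculation (prefills), and an unfinished prefill chunk produces no sampled token, so it has nothing to contribute to the contracted batch. A toy sketch of that kind of filter, using made-up bookkeeping rather than vLLM's actual sequence metadata:

# Hypothetical step: three prefill entries, of which index 2 is a prefill
# chunk that sampled no token and must be ruled out.
non_spec_indices = [0, 2, 3]
sampled_token_count = {0: 1, 2: 0, 3: 1}  # illustrative only

non_spec_indices = [
    idx for idx in non_spec_indices if sampled_token_count[idx] > 0
]
print(non_spec_indices)  # [0, 3]
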
5 changes: 4 additions & 1 deletion vllm/spec_decode/spec_decode_worker.py
@@ -584,6 +584,9 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
         hidden_states = sampler_output.hidden_states
         if hidden_states is not None:
             # remove hidden_states for prompt tokens
+            # TODO Enable `return_hidden_states`: prefill chunks hidden states are
+            # pruned by the logits processor. Also, they should be arranged back into
+            # full-prefill latent. Address it to enable MLPSpeculator.
             if any(seq.is_prompt
                    for seq in execute_model_req.seq_group_metadata_list):
                 hidden_states = hidden_states[
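
The slice above keeps hidden states only for positions that actually feed the next draft step; the new TODO notes that with chunked prefill the logits processor has already pruned prompt positions, so the hidden states would also need to be stitched back into a full-prefill layout. A toy illustration of the kind of row selection involved, with invented shapes and indices rather than vLLM's actual tensors:

import torch

# Two prompts of 3 and 2 tokens; the target model returned one hidden state
# per prompt token, flattened to [num_tokens, hidden_size].
hidden_states = torch.randn(5, 8)

# Only each sequence's last prompt position is carried forward to the draft
# model, i.e. rows 2 and 4 (cumulative lengths minus one).
last_token_indices = torch.tensor([2, 4])
hidden_states = hidden_states.index_select(0, last_token_indices)
print(hidden_states.shape)  # torch.Size([2, 8])
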
@@ -698,7 +701,7 @@ def _run_speculative_decoding_step(
         # TODO skip this if chunking is not enabled
         if len(non_spec_indices):
             all_hidden_states = proposal_scores.hidden_states
-            # TODO fix `return_hidden_states`
+            # TODO fix `return_hidden_states`, same as in `_run_no_spec`
             if all_hidden_states is not None:
                 prefill_hidden_states = all_hidden_states[non_spec_indices]
                 execute_model_req.previous_hidden_states = prepare_prefill_hidden_states(
