diff --git a/examples/offline_inference_mlpspeculator.py b/examples/offline_inference_mlpspeculator.py
index 5448ec1f6208c..5dec4a76afb2f 100644
--- a/examples/offline_inference_mlpspeculator.py
+++ b/examples/offline_inference_mlpspeculator.py
@@ -52,7 +52,6 @@ def time_generation(llm: LLM, prompts: List[str],
         speculative_model="ibm-fms/llama-13b-accelerator",
         # These are currently required for MLPSpeculator decoding
         use_v2_block_manager=True,
-        enforce_eager=True,
     )
 
     print("With speculation")
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 9fdb2ea5dd4e4..ac820bbcbca33 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1020,10 +1020,13 @@ def execute_model(
 
         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
+            assert model_input.sampling_metadata is not None
+            indices = model_input.sampling_metadata.selected_token_indices
             if model_input.is_prompt:
-                assert model_input.sampling_metadata is not None
-                hidden_states = hidden_states.index_select(
-                    0, model_input.sampling_metadata.selected_token_indices)
+                hidden_states = hidden_states.index_select(0, indices)
+            elif decode_meta.use_cuda_graph:
+                hidden_states = hidden_states[:len(indices)]
+
             output.hidden_states = hidden_states
 
         return output