diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py
index d7e3a2fc4a71..0a5fe19f80ec 100644
--- a/tests/models/test_jamba.py
+++ b/tests/models/test_jamba.py
@@ -1,5 +1,7 @@
 import pytest
 
+from vllm.worker.model_runner import _get_graph_batch_size
+
 MODELS = ["ai21labs/Jamba-tiny-random"]
 
 
@@ -32,6 +34,32 @@ def test_models(
             f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
 
 
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [20])
+def test_mamba_cache_cg_padding(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    # This test is for verifying that mamba cache is padded to CG captured
+    # batch size. If it's not, a torch RuntimeError will be raised because
+    # tensor dimensions aren't compatible
+    while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
+        example_prompts.append(example_prompts[0])
+
+    try:
+        with vllm_runner(model, dtype=dtype) as vllm_model:
+            vllm_model.generate_greedy(example_prompts, max_tokens)
+    except RuntimeError:
+        pytest.fail(
+            "Couldn't run batch size which is not equal to a Cuda Graph "
+            "captured batch size. "
+            "Could be related to mamba cache not padded correctly")
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
 def test_state_cleanup(
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index bf330c7770d1..4524d8df86b9 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -788,12 +788,12 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
             key in kwargs
             for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
-        batch_size = len(request_ids_to_seq_ids)
+        cg_batch_size = input_buffers['input_ids'].shape[0]
         (
            current_mamba_cache,
            indices,
        ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  batch_size)
+                                                  cg_batch_size)
         self.current_indices = indices
         finished_requests_ids = kwargs["finished_requests_ids"]
         self._release_mamba_cache(finished_requests_ids)
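
For context, the relationship the new test exercises can be sketched as a small standalone snippet: the prompt list is grown until its length is not itself a CUDA-graph-captured batch size, so the decode batch must be padded up to the next captured size, and the mamba cache must be sized by that padded size (taken from input_buffers['input_ids'].shape[0] in the jamba.py change) rather than by the number of live requests. The graph_batch_size helper and its alignment constant below are illustrative assumptions only; the real function is _get_graph_batch_size in vllm.worker.model_runner, whose exact rounding rule may differ.

# Minimal sketch, not vLLM code: an assumed rounding rule standing in for
# _get_graph_batch_size (align to a multiple of 8 above a small threshold).
def graph_batch_size(batch_size: int, alignment: int = 8) -> int:
    """Round a runtime batch size up to the nearest captured CUDA graph size."""
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return (batch_size + alignment - 1) // alignment * alignment

# Mirror the test's loop: keep appending prompts until the prompt count is
# *not* already a captured size, e.g. 8 prompts -> 9, whose padded CG batch
# (16 under the assumed rule) differs from the number of live requests.
prompts = ["p"] * 8
while len(prompts) == graph_batch_size(len(prompts)):
    prompts.append(prompts[0])

assert len(prompts) != graph_batch_size(len(prompts))
print(len(prompts), graph_batch_size(len(prompts)))  # e.g. 9 16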