Skip to content

Commit

Permalink
[Bugfix] Mamba cache Cuda Graph padding (#6214)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomeras91 authored Jul 8, 2024
1 parent 185ad31 commit ddc369f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
28 changes: 28 additions & 0 deletions tests/models/test_jamba.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest

from vllm.worker.model_runner import _get_graph_batch_size

MODELS = ["ai21labs/Jamba-tiny-random"]


Expand Down Expand Up @@ -32,6 +34,32 @@ def test_models(
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
example_prompts.append(example_prompts[0])

try:
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
except RuntimeError:
pytest.fail(
"Couldn't run batch size which is not equal to a Cuda Graph "
"captured batch size. "
"Could be related to mamba cache not padded correctly")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,12 +788,12 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
batch_size = len(request_ids_to_seq_ids)
cg_batch_size = input_buffers['input_ids'].shape[0]
(
current_mamba_cache,
indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
batch_size)
cg_batch_size)
self.current_indices = indices
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
Expand Down

0 comments on commit ddc369f

Please sign in to comment.