
Change tested model to a trained one; the tests are now more reliable
mzusman committed Jul 24, 2024
1 parent 7d4fed3 commit bda9876
Showing 1 changed file with 28 additions and 8 deletions: tests/models/test_jamba.py
@@ -4,11 +4,11 @@
 from vllm.sampling_params import SamplingParams
 from vllm.worker.model_runner import _get_graph_batch_size

-MODELS = ["ai21labs/Jamba-tiny-random"]
+MODELS = ["pszemraj/jamba-900M-v0.13-KIx2"]


 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [20])
 def test_models(
     hf_runner,
@@ -18,8 +18,6 @@ def test_models(
     dtype: str,
     max_tokens: int,
 ) -> None:
-    # To pass the small model tests, we need full precision.
-    assert dtype == "float"

     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
@@ -139,7 +137,7 @@ def test_models_preemption_recompute(
     max_tokens: int,
 ) -> None:
     # Tests that outputs are identical with and w/o preemptions (recompute)
-    assert dtype == "float"
+    # assert dtype == "float"

     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_model.model.llm_engine.scheduler[
@@ -160,7 +158,7 @@


 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
 def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
     vllm_runner,
     model: str,
@@ -182,7 +180,29 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(


 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+def test_cleanup_upon_aborted_requests(
+    vllm_runner,
+    model: str,
+    dtype: str,
+    example_prompts,
+) -> None:
+    # This test verifies that the Jamba inner state management doesn't
+    # collapse in cases where the number of incoming requests and
+    # finished_requests_ids is larger than the maximum mamba block
+    # capacity. This can generally happen because Jamba supports a
+    # statelessness mechanism that lets it clean up new incoming
+    # requests in a single step.
+    try:
+        with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
+            vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
+    except ValueError:
+        pytest.fail("Jamba inner state wasn't cleaned up properly between "
+                    "steps; finished requests were registered unnecessarily")
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
 def test_state_cleanup(
     vllm_runner,
     model: str,
@@ -201,7 +221,7 @@ def test_state_cleanup(


 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
 def test_model_print(
     vllm_runner,
     model: str,
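To try the updated tests locally (a sketch, assuming a vLLM development checkout with pytest and the test dependencies installed; the file path and test names come from the diff above):

    pytest tests/models/test_jamba.py -v -k "cleanup or preemption"

Since every dtype parametrization now uses "bfloat16", running these generally calls for a GPU with native bfloat16 support (e.g. NVIDIA Ampere or newer).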
