Skip to content

Commit

Permalink
Remove n > 1 test for now, need to check why it fails on L4
Browse files Browse the repository at this point in the history
  • Loading branch information
mzusman committed Jul 28, 2024
1 parent aaaacbb commit e36ee93
Showing 1 changed file with 0 additions and 37 deletions.
37 changes: 0 additions & 37 deletions tests/models/test_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,43 +64,6 @@ def test_batching(
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [15])
def test_n_lt_1(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
# assert dtype == "float"

with vllm_runner(model, dtype=dtype) as vllm_model:
for_loop_outputs = []
for _ in range(10):
for_loop_outputs.append(
vllm_model.generate_greedy([example_prompts[1]],
max_tokens)[0])
sampling_params = SamplingParams(n=10,
temperature=0.001,
seed=0,
max_tokens=max_tokens)
n_lt_1_outputs = vllm_model.generate([example_prompts[1]],
sampling_params)
token_ids, texts = n_lt_1_outputs[0]
n_lt_1_outputs = [(token_id, text)
for token_id, text in zip(token_ids, texts)]

check_outputs_equal(
outputs_0_lst=n_lt_1_outputs,
outputs_1_lst=for_loop_outputs,
name_0="vllm_n_lt_1_outputs",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [20])
Expand Down

0 comments on commit e36ee93

Please sign in to comment.