Skip to content

Commit

Permalink
Granite test.
Browse files Browse the repository at this point in the history
  • Loading branch information
shawntan committed Aug 19, 2024
1 parent 9164b8e commit 40b2891
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions tests/models/test_granite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Compare the outputs of HF and vLLM for Granite models using greedy sampling.
Run `pytest tests/models/test_granite.py`.
"""
import pytest

from .utils import check_logprobs_close

MODELS = [
"mayank-mishra/granite-3b-mup",
]


@pytest.mark.parametrize("model", MODELS)
# @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
# TODO(sang): Sliding window should be tested separately.
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
# print("hf_outputs ", hf_outputs)
# print("vllm_outputs", vllm_outputs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)

0 comments on commit 40b2891

Please sign in to comment.