This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit 445b0d3

[CI/Build] Reuse code for checking output consistency (vllm-project#5988)
DarkLight1337 authored and robertgshaw2-neuralmagic committed Jul 1, 2024
1 parent d0b7111 commit 445b0d3
Showing 11 changed files with 122 additions and 76 deletions.
15 changes: 8 additions & 7 deletions tests/basic_correctness/test_basic_correctness.py
@@ -9,6 +9,8 @@
 from tests.nm_utils.utils_skip import should_skip_test_group
 from vllm import LLM
 
+from ..models.utils import check_outputs_equal
+
 if should_skip_test_group(group_name="TEST_BASIC_CORRECTNESS"):
     pytest.skip(
         "TEST_BASIC_CORRECTNESS=DISABLE, skipping basic correctness test group",
@@ -52,10 +54,9 @@ def test_models(
                      gpu_memory_utilization=0.7) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
15 changes: 8 additions & 7 deletions tests/basic_correctness/test_chunked_prefill.py
@@ -8,6 +8,8 @@
 """
 import pytest
 
+from ..models.utils import check_outputs_equal
+
 from tests.nm_utils.utils_skip import should_skip_test_group
 
 if should_skip_test_group(group_name="TEST_BASIC_CORRECTNESS"):
@@ -64,10 +66,9 @@ def test_models(
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
16 changes: 9 additions & 7 deletions tests/basic_correctness/test_preemption.py
@@ -13,6 +13,8 @@
 from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
                                  ENABLE_ARTIFICIAL_PREEMPT)
 
+from ..models.utils import check_outputs_equal
+
 if should_skip_test_group(group_name="TEST_BASIC_CORRECTNESS"):
     pytest.skip(
         "TEST_BASIC_CORRECTNESS=DISABLE, skipping basic correctness test group",
@@ -100,13 +102,13 @@ def test_preemption(
         total_preemption = (
             vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
+
     assert ("is preempted by PreemptionMode.RECOMPUTE mode because there "
             "is not enough KV cache space." in caplog_vllm.text)
     # Ensure the count bucket of request-level histogram metrics matches
14 changes: 7 additions & 7 deletions tests/distributed/test_basic_distributed_correctness.py
@@ -22,6 +22,7 @@
 import torch
 
 from tests.nm_utils.utils_skip import should_skip_test_group
+from ..models.utils import check_outputs_equal
 
 if should_skip_test_group(group_name="TEST_DISTRIBUTED"):
     pytest.skip("TEST_DISTRIBUTED=DISABLE, skipping distributed test group",
@@ -61,10 +62,9 @@ def test_models(
                      ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
15 changes: 8 additions & 7 deletions tests/distributed/test_chunked_prefill_distributed.py
@@ -20,6 +20,8 @@
 import pytest
 import torch
 
+from ..models.utils import check_outputs_equal
+
 from tests.nm_utils.utils_skip import should_skip_test_group
 
 if should_skip_test_group(group_name="TEST_DISTRIBUTED"):
@@ -72,10 +74,9 @@ def test_models(
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
14 changes: 7 additions & 7 deletions tests/models/test_big_models.py
@@ -11,6 +11,7 @@
 import torch
 
 from tests.nm_utils.utils_skip import should_skip_test_group
+from .utils import check_outputs_equal
 
 if should_skip_test_group(group_name="TEST_MODELS"):
     pytest.skip("TEST_MODELS=DISABLE, skipping model test group",
@@ -77,13 +78,12 @@ def test_models(
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
 
 
 @pytest.mark.skip("Slow and not useful (just prints model).")
18 changes: 10 additions & 8 deletions tests/models/test_llava.py
@@ -7,6 +7,7 @@
 from vllm.config import VisionLanguageConfig
 
 from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from .utils import check_outputs_equal
 
 if should_skip_test_group(group_name="TEST_MODELS"):
     pytest.skip("TEST_MODELS=DISABLE, skipping model test group",
@@ -114,14 +115,15 @@ def run_test(
                                                   max_tokens,
                                                   images=vllm_images)
 
-    for i in range(len(HF_IMAGE_PROMPTS)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
-            vllm_outputs[i], vlm_config, model_id)
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        hf_outputs,
+        [
+            vllm_to_hf_output(vllm_output, vlm_config, model_id)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+    )
 
 
 @pytest.mark.parametrize("model_and_config", model_and_vl_config)
18 changes: 10 additions & 8 deletions tests/models/test_llava_next.py
@@ -7,6 +7,7 @@
 from vllm.config import VisionLanguageConfig
 
 from ..conftest import IMAGE_ASSETS
+from .utils import check_outputs_equal
 
 pytestmark = pytest.mark.vlm
 
@@ -122,11 +123,12 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
                                                   max_tokens,
                                                   images=vllm_images)
 
-    for i in range(len(HF_IMAGE_PROMPTS)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
-            vllm_outputs[i], vlm_config, model_id)
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        hf_outputs,
+        [
+            vllm_to_hf_output(vllm_output, vlm_config, model_id)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+    )
15 changes: 7 additions & 8 deletions tests/models/test_models.py
@@ -10,6 +10,7 @@
 import pytest
 
 from tests.nm_utils.utils_skip import should_skip_test_group
+from .utils import check_outputs_equal
 
 if should_skip_test_group(group_name="TEST_MODELS"):
     pytest.skip("TEST_MODELS=DISABLE, skipping model test group",
@@ -51,14 +52,12 @@ def test_models(
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
-
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
 
 @pytest.mark.skip("Slow and not useful (just prints model).")
 @pytest.mark.parametrize("model", MODELS)
18 changes: 10 additions & 8 deletions tests/models/test_phi3v.py
@@ -8,6 +8,7 @@
 from vllm.utils import is_cpu
 
 from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from .utils import check_outputs_equal
 
 if should_skip_test_group(group_name="TEST_MODELS"):
     pytest.skip("TEST_MODELS=DISABLE, skipping models test group",
@@ -129,14 +130,15 @@ def run_test(
                                                   max_tokens,
                                                   images=vllm_images)
 
-    for i in range(len(HF_IMAGE_PROMPTS)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
-            vllm_outputs[i], vlm_config, model_id)
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        hf_outputs,
+        [
+            vllm_to_hf_output(vllm_output, vlm_config, model_id)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+    )
 
 
 # Since we use _attn_implementation="eager" for hf_runner, here is
40 changes: 38 additions & 2 deletions tests/models/utils.py
@@ -1,7 +1,43 @@
-def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1):
-    """Compare the logprobs of two sequences generated by different models,
+from typing import Dict, List, Tuple
+
+TokensText = Tuple[List[int], str]
+
+
+def check_outputs_equal(outputs_0_lst: List[TokensText],
+                        outputs_1_lst: List[TokensText], name_0: str,
+                        name_1: str):
+    """
+    Compare the two sequences generated by different models,
+    which should be equal.
+    """
+    assert len(outputs_0_lst) == len(outputs_1_lst)
+
+    for prompt_idx, (outputs_0,
+                     outputs_1) in enumerate(zip(outputs_0_lst,
+                                                 outputs_1_lst)):
+        output_ids_0, output_str_0 = outputs_0
+        output_ids_1, output_str_1 = outputs_1
+
+        assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
+                                              f"\n{name_0}:\t{output_str_0!r}"
+                                              f"\n{name_1}:\t{output_str_1!r}")
+        assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
+                                              f"\n{name_0}:\t{output_str_0!r}"
+                                              f"\n{name_1}:\t{output_str_1!r}")
+
+
+TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
+
+
+def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
+                         outputs_1_lst: List[TokensTextLogprobs], name_0: str,
+                         name_1: str):
+    """
+    Compare the logprobs of two sequences generated by different models,
     which should be similar but not necessarily equal.
     """
     assert len(outputs_0_lst) == len(outputs_1_lst)
 
     # Loop through responses to each prompt.
     for prompt_idx, (outputs_0,
                      outputs_1) in enumerate(zip(outputs_0_lst,
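For reference, here is a minimal usage sketch (not part of the commit) of the shared check_outputs_equal helper defined above. It assumes the helper is importable as tests.models.utils, matching the paths in this diff; the token IDs and strings are invented purely for illustration.

import pytest

from tests.models.utils import check_outputs_equal

# Each entry is a TokensText tuple: (generated token IDs, decoded text).
hf_outputs = [([101, 7, 42], "Paris is the capital of France"),
              ([55, 9], "The sky is blue")]
vllm_outputs = [([101, 7, 42], "Paris is the capital of France"),
                ([55, 9], "The sky is blue")]

# Identical outputs pass silently.
check_outputs_equal(
    outputs_0_lst=hf_outputs,
    outputs_1_lst=vllm_outputs,
    name_0="hf",
    name_1="vllm",
)

# Any divergence in text or token IDs raises an AssertionError whose message
# names the prompt index and both runners.
with pytest.raises(AssertionError):
    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=[([101, 7, 42], "Paris is the capital of France"),
                       ([55, 8], "The sky is red")],
        name_0="hf",
        name_1="vllm",
    )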
