From 445b0d35624eaf8a60c8629c18806234638e3712 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Sun, 30 Jun 2024 11:44:25 +0800
Subject: [PATCH] [CI/Build] Reuse code for checking output consistency (#5988)

---
 .../test_basic_correctness.py                 | 15 +++----
 .../basic_correctness/test_chunked_prefill.py | 15 +++----
 tests/basic_correctness/test_preemption.py    | 16 ++++----
 .../test_basic_distributed_correctness.py     | 14 +++----
 .../test_chunked_prefill_distributed.py       | 15 +++----
 tests/models/test_big_models.py               | 14 +++----
 tests/models/test_llava.py                    | 18 +++++----
 tests/models/test_llava_next.py               | 18 +++++----
 tests/models/test_models.py                   | 15 ++++---
 tests/models/test_phi3v.py                    | 18 +++++----
 tests/models/utils.py                         | 40 ++++++++++++++++++-
 11 files changed, 122 insertions(+), 76 deletions(-)

diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 71e7b6b94f84..244004916c64 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -9,6 +9,8 @@
 from tests.nm_utils.utils_skip import should_skip_test_group
 from vllm import LLM
 
+from ..models.utils import check_outputs_equal
+
 if should_skip_test_group(group_name="TEST_BASIC_CORRECTNESS"):
     pytest.skip(
         "TEST_BASIC_CORRECTNESS=DISABLE, skipping basic correctness test group",
@@ -52,10 +54,9 @@ def test_models(
                     gpu_memory_utilization=0.7) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py
index 8cb033edd25a..bebd088a2280 100644
--- a/tests/basic_correctness/test_chunked_prefill.py
+++ b/tests/basic_correctness/test_chunked_prefill.py
@@ -8,6 +8,8 @@
 """
 import pytest
 
+from ..models.utils import check_outputs_equal
+
 from tests.nm_utils.utils_skip import should_skip_test_group
 
 if should_skip_test_group(group_name="TEST_BASIC_CORRECTNESS"):
@@ -64,10 +66,9 @@ def test_models(
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py
index f16c1fb48b83..fa4b76947393 100644
--- a/tests/basic_correctness/test_preemption.py
+++ b/tests/basic_correctness/test_preemption.py
@@ -13,6 +13,8 @@
 from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
                                  ENABLE_ARTIFICIAL_PREEMPT)
 
+from ..models.utils import check_outputs_equal
+
 if should_skip_test_group(group_name="TEST_BASIC_CORRECTNESS"):
     pytest.skip(
         "TEST_BASIC_CORRECTNESS=DISABLE, skipping basic correctness test group",
@@ -100,13 +102,13 @@ def test_preemption(
     total_preemption = (
         vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
+
     assert ("is preempted by PreemptionMode.RECOMPUTE mode because there "
             "is not enough KV cache space." in caplog_vllm.text)
     # Ensure the count bucket of request-level histogram metrics matches
diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py
index 3a88a7d110ba..90fd10d7f416 100644
--- a/tests/distributed/test_basic_distributed_correctness.py
+++ b/tests/distributed/test_basic_distributed_correctness.py
@@ -22,6 +22,7 @@
 import torch
 
 from tests.nm_utils.utils_skip import should_skip_test_group
+from ..models.utils import check_outputs_equal
 
 if should_skip_test_group(group_name="TEST_DISTRIBUTED"):
     pytest.skip("TEST_DISTRIBUTED=DISABLE, skipping distributed test group",
@@ -61,10 +62,9 @@ def test_models(
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py
index 5e931284f1c4..3686d8975034 100644
--- a/tests/distributed/test_chunked_prefill_distributed.py
+++ b/tests/distributed/test_chunked_prefill_distributed.py
@@ -20,6 +20,8 @@
 import pytest
 import torch
 
+from ..models.utils import check_outputs_equal
+
 from tests.nm_utils.utils_skip import should_skip_test_group
 
 if should_skip_test_group(group_name="TEST_DISTRIBUTED"):
@@ -72,10 +74,9 @@ def test_models(
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py
index 962af078618f..7b6b89ebb687 100644
--- a/tests/models/test_big_models.py
+++ b/tests/models/test_big_models.py
@@ -11,6 +11,7 @@
 import torch
 
 from tests.nm_utils.utils_skip import should_skip_test_group
+from .utils import check_outputs_equal
 
 if should_skip_test_group(group_name="TEST_MODELS"):
     pytest.skip("TEST_MODELS=DISABLE, skipping model test group",
@@ -77,13 +78,12 @@ def test_models(
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
 
 
 @pytest.mark.skip("Slow and not useful (just prints model).")
diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py
index cfd08c3c2811..cf0e4232d58d 100644
--- a/tests/models/test_llava.py
+++ b/tests/models/test_llava.py
@@ -7,6 +7,7 @@
 from vllm.config import VisionLanguageConfig
 
 from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from .utils import check_outputs_equal
 
 if should_skip_test_group(group_name="TEST_MODELS"):
     pytest.skip("TEST_MODELS=DISABLE, skipping model test group",
@@ -114,14 +115,15 @@ def run_test(
                                                   max_tokens,
                                                   images=vllm_images)
 
-    for i in range(len(HF_IMAGE_PROMPTS)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
-            vllm_outputs[i], vlm_config, model_id)
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        hf_outputs,
+        [
+            vllm_to_hf_output(vllm_output, vlm_config, model_id)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+    )
 
 
 @pytest.mark.parametrize("model_and_config", model_and_vl_config)
diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py
index 4a84458b9684..38f4d9872408 100644
--- a/tests/models/test_llava_next.py
+++ b/tests/models/test_llava_next.py
@@ -7,6 +7,7 @@
 from vllm.config import VisionLanguageConfig
 
 from ..conftest import IMAGE_ASSETS
+from .utils import check_outputs_equal
 
 pytestmark = pytest.mark.vlm
 
@@ -122,11 +123,12 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
                                                   max_tokens,
                                                   images=vllm_images)
 
-    for i in range(len(HF_IMAGE_PROMPTS)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
-            vllm_outputs[i], vlm_config, model_id)
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        hf_outputs,
+        [
+            vllm_to_hf_output(vllm_output, vlm_config, model_id)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/tests/models/test_models.py b/tests/models/test_models.py
index efa7eea2e9ae..03630d4ef2cb 100644
--- a/tests/models/test_models.py
+++ b/tests/models/test_models.py
@@ -10,6 +10,7 @@
 import pytest
 
 from tests.nm_utils.utils_skip import should_skip_test_group
+from .utils import check_outputs_equal
 
 if should_skip_test_group(group_name="TEST_MODELS"):
     pytest.skip("TEST_MODELS=DISABLE, skipping model test group",
@@ -51,14 +52,12 @@ def test_models(
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
-
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
 
 @pytest.mark.skip("Slow and not useful (just prints model).")
 @pytest.mark.parametrize("model", MODELS)
diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py
index 1b63be6c15dd..34925efea202 100644
--- a/tests/models/test_phi3v.py
+++ b/tests/models/test_phi3v.py
@@ -8,6 +8,7 @@
 from vllm.utils import is_cpu
 
 from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from .utils import check_outputs_equal
 
 if should_skip_test_group(group_name="TEST_MODELS"):
     pytest.skip("TEST_MODELS=DISABLE, skipping models test group",
@@ -129,14 +130,15 @@ def run_test(
                                                   max_tokens,
                                                   images=vllm_images)
 
-    for i in range(len(HF_IMAGE_PROMPTS)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
-            vllm_outputs[i], vlm_config, model_id)
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        hf_outputs,
+        [
+            vllm_to_hf_output(vllm_output, vlm_config, model_id)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+    )
 
 
 # Since we use _attn_implementation="eager" for hf_runner, here is
diff --git a/tests/models/utils.py b/tests/models/utils.py
index 3e49dfb33117..0d5e304d8446 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -1,7 +1,43 @@
-def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1):
-    """Compare the logprobs of two sequences generated by different models,
+from typing import Dict, List, Tuple
+
+TokensText = Tuple[List[int], str]
+
+
+def check_outputs_equal(outputs_0_lst: List[TokensText],
+                        outputs_1_lst: List[TokensText], name_0: str,
+                        name_1: str):
+    """
+    Compare the two sequences generated by different models,
+    which should be equal.
+    """
+    assert len(outputs_0_lst) == len(outputs_1_lst)
+
+    for prompt_idx, (outputs_0,
+                     outputs_1) in enumerate(zip(outputs_0_lst,
+                                                 outputs_1_lst)):
+        output_ids_0, output_str_0 = outputs_0
+        output_ids_1, output_str_1 = outputs_1
+
+        assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
+                                              f"\n{name_0}:\t{output_str_0!r}"
+                                              f"\n{name_1}:\t{output_str_1!r}")
+        assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
+                                              f"\n{name_0}:\t{output_ids_0}"
+                                              f"\n{name_1}:\t{output_ids_1}")
+
+
+TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
+
+
+def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
+                         outputs_1_lst: List[TokensTextLogprobs], name_0: str,
+                         name_1: str):
+    """
+    Compare the logprobs of two sequences generated by different models,
     which should be similar but not necessarily equal.
     """
+    assert len(outputs_0_lst) == len(outputs_1_lst)
+
     # Loop through responses to each prompt.
     for prompt_idx, (outputs_0,
                      outputs_1) in enumerate(zip(outputs_0_lst,
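
Usage sketch (editorial note, not part of the patch): every refactored test now
funnels its comparison through the new helper, called as below. The
(token_ids, text) tuples normally come from HfRunner/VllmRunner
generate_greedy(); the token IDs and strings here are hypothetical.

    # Assumes the vLLM repo root is on sys.path; data is made up.
    from tests.models.utils import check_outputs_equal

    hf_outputs = [([5159, 836, 318], "My name is")]    # one (token_ids, text) per prompt
    vllm_outputs = [([5159, 836, 318], "My name is")]

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
    # Passes silently; on mismatch it raises AssertionError naming the prompt
    # index and showing both labeled outputs.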