From 2005bb23947540617392b51c38e11df539681d9a Mon Sep 17 00:00:00 2001 From: Pavani Majety Date: Fri, 30 Aug 2024 22:18:50 -0700 Subject: [PATCH] [Bugfix] bugfix and add model test for flashinfer fp8 kv cache. (#8013) --- tests/models/test_fp8kv_flashinfer.py | 96 +++++++++++++++++++++++++++ vllm/attention/backends/flashinfer.py | 18 +++-- 2 files changed, 109 insertions(+), 5 deletions(-) create mode 100644 tests/models/test_fp8kv_flashinfer.py diff --git a/tests/models/test_fp8kv_flashinfer.py b/tests/models/test_fp8kv_flashinfer.py new file mode 100644 index 0000000000000..ff2a44162b6c3 --- /dev/null +++ b/tests/models/test_fp8kv_flashinfer.py @@ -0,0 +1,96 @@ +# flake8: noqa +"""Tests fp8 models against ground truth generation +This verifies the flashinfer backend with fp8 +quantization and fp8 KV Cache without scaling +factors Note: these tests will only pass on H100 GPU. +""" +import os +from typing import List + +import pytest +from transformers import AutoTokenizer + +from tests.quantization.utils import is_quant_method_supported +from vllm import LLM, SamplingParams + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +MAX_MODEL_LEN = 1024 + +MODELS = [ + "nm-testing/Meta-Llama-3-8B-Instruct-FP8", +] + +EXPECTED_STRS_MAP = { + "nm-testing/Meta-Llama-3-8B-Instruct-FP8": { + "auto": [ + 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', + 'In the sterile, metallic halls of the robotics lab, a peculiar phenomenon occurred. Zeta-5', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, mushi o', + ], + "fp8": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', + ] + } +} + + +# This test compares against golden strings for exact match since +# there is no baseline implementation to compare against +# and is unstable w.r.t specifics of the fp8 implementation or +# the hardware being run on. +# No assert to prevent it from breaking the build +@pytest.mark.skipif(not is_quant_method_supported("fp8"), + reason="fp8 is not supported on this GPU type.") +@pytest.mark.parametrize("model_name", MODELS) +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) +@pytest.mark.parametrize("backend", ["XFORMERS", "FLASHINFER"]) +def test_models(example_prompts, model_name, kv_cache_dtype, backend) -> None: + # Note that the golden strings may not work for FLASHINFER Backend. + # The intention is to test the path + os.environ["VLLM_ATTENTION_BACKEND"] = backend + model = LLM(model=model_name, + max_model_len=MAX_MODEL_LEN, + trust_remote_code=True, + quantization="fp8", + kv_cache_dtype=kv_cache_dtype) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + formatted_prompts = [ + tokenizer.apply_chat_template([{ + "role": "user", + "content": prompt + }], + tokenize=False, + add_generation_prompt=True) + for prompt in example_prompts + ] + + params = SamplingParams(max_tokens=20, temperature=0) + generations: List[str] = [] + # Note: these need to be run 1 at a time due to numerical precision, + # since the expected strs were generated this way. + for prompt in formatted_prompts: + outputs = model.generate(prompt, params) + generations.append(outputs[0].outputs[0].text) + del model + + print(f"Testing: {model_name} with kv_cache_dtype: {kv_cache_dtype}") + expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype] + for i in range(len(example_prompts)): + generated_str = generations[i] + expected_str = expected_strs[i] + print(f"generated_str\n: {generated_str}") + print(f"expected_str\n: {expected_str}") diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index f554fa2805bd2..aa9d4a71dbf87 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -186,9 +186,13 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): self._graph_decode_workspace_buffer, _indptr_buffer, self._graph_indices_buffer, _last_page_len_buffer, "NHD", use_tensor_cores) + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) - kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - self.runner.kv_cache_dtype) paged_kv_indptr_tensor_host = torch.arange(0, batch_size + 1, dtype=torch.int32) @@ -349,7 +353,7 @@ def begin_forward(self): self.page_size, # Disable flashinfer's pos encoding and use vllm's rope. pos_encoding_mode="NONE", - ) + data_type=self.data_type) def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None @@ -586,8 +590,12 @@ def build(self, seq_lens: List[int], query_lens: List[int], paged_kv_indptr_tensor = None paged_kv_last_page_len_tensor = None - kv_cache_dtype = get_kv_cache_torch_dtype( - self.runner.kv_cache_dtype, self.runner.model_config.dtype) + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) return FlashInferMetadata( num_prefills=self.num_prefills,