Skip to content

Commit

Permalink
[Bugfix] bugfix and add model test for flashinfer fp8 kv cache. (vllm…
Browse files Browse the repository at this point in the history
  • Loading branch information
pavanimajety authored Aug 31, 2024
1 parent f5ba231 commit 2005bb2
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 5 deletions.
96 changes: 96 additions & 0 deletions tests/models/test_fp8kv_flashinfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# flake8: noqa
"""Tests fp8 models against ground truth generation
This verifies the flashinfer backend with fp8
quantization and fp8 KV Cache without scaling
factors Note: these tests will only pass on H100 GPU.
"""
import os
from typing import List

import pytest
from transformers import AutoTokenizer

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams

os.environ["TOKENIZERS_PARALLELISM"] = "true"

MAX_MODEL_LEN = 1024

MODELS = [
"nm-testing/Meta-Llama-3-8B-Instruct-FP8",
]

EXPECTED_STRS_MAP = {
"nm-testing/Meta-Llama-3-8B-Instruct-FP8": {
"auto": [
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
'In the sterile, metallic halls of the robotics lab, a peculiar phenomenon occurred. Zeta-5',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki no tori, mushi o',
],
"fp8": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o',
]
}
}


# This test compares against golden strings for exact match since
# there is no baseline implementation to compare against
# and is unstable w.r.t specifics of the fp8 implementation or
# the hardware being run on.
# No assert to prevent it from breaking the build
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("backend", ["XFORMERS", "FLASHINFER"])
def test_models(example_prompts, model_name, kv_cache_dtype, backend) -> None:
# Note that the golden strings may not work for FLASHINFER Backend.
# The intention is to test the path
os.environ["VLLM_ATTENTION_BACKEND"] = backend
model = LLM(model=model_name,
max_model_len=MAX_MODEL_LEN,
trust_remote_code=True,
quantization="fp8",
kv_cache_dtype=kv_cache_dtype)

tokenizer = AutoTokenizer.from_pretrained(model_name)
formatted_prompts = [
tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
tokenize=False,
add_generation_prompt=True)
for prompt in example_prompts
]

params = SamplingParams(max_tokens=20, temperature=0)
generations: List[str] = []
# Note: these need to be run 1 at a time due to numerical precision,
# since the expected strs were generated this way.
for prompt in formatted_prompts:
outputs = model.generate(prompt, params)
generations.append(outputs[0].outputs[0].text)
del model

print(f"Testing: {model_name} with kv_cache_dtype: {kv_cache_dtype}")
expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
for i in range(len(example_prompts)):
generated_str = generations[i]
expected_str = expected_strs[i]
print(f"generated_str\n: {generated_str}")
print(f"expected_str\n: {expected_str}")
18 changes: 13 additions & 5 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,13 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer, "NHD",
use_tensor_cores)
if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.runner.kv_cache_dtype)
else:
kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)

kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.runner.kv_cache_dtype)
paged_kv_indptr_tensor_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)
Expand Down Expand Up @@ -349,7 +353,7 @@ def begin_forward(self):
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
)
data_type=self.data_type)

def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
Expand Down Expand Up @@ -586,8 +590,12 @@ def build(self, seq_lens: List[int], query_lens: List[int],
paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None

kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)
if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.runner.kv_cache_dtype)
else:
kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)

return FlashInferMetadata(
num_prefills=self.num_prefills,
Expand Down

0 comments on commit 2005bb2

Please sign in to comment.