Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel][Model] logits_soft_cap for Gemma2 with flashinfer #6051

Merged
41 changes: 41 additions & 0 deletions tests/models/test_gemma2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Compare the outputs of HF and vLLM for Gemma2 models using greedy sampling.

Run `pytest tests/models/test_gemma2.py`.
"""
import os

import pytest

from .utils import check_logprobs_close

MODELS = ["google/gemma-2-9b"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
os.environ['VLLM_ATTENTION_BACKEND'] = "FLASHINFER"
# TODO(sang): Sliding window should be tested separately.
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
12 changes: 8 additions & 4 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ class FlashInferMetadata(AttentionMetadata):
# The data type of the paged kv cache
data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
# Only used by gemma2 model
logits_soft_cap: Optional[float] = None

def __post_init__(self):
# Refer to
Expand Down Expand Up @@ -269,15 +271,17 @@ def forward(
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
output = prefill_meta.prefill_wrapper.forward(query,
kv_cache,
causal=True)
output = prefill_meta.prefill_wrapper.forward(
query,
kv_cache,
logits_soft_cap=attn_metadata.logits_soft_cap,
causal=True)
else:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
kv_cache,
sm_scale=self.scale,
)
logits_soft_cap=attn_metadata.logits_soft_cap)
return output.view(num_tokens, hidden_size)
2 changes: 1 addition & 1 deletion vllm/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_attn_backend(
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning(("Flashinfer will be stuck on llma-2-7b,"
" please avoid using Flashinfer as the"
" please avoid using Flashinfer as the "
"backend when running on llma-2-7b."))
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
Expand Down
14 changes: 12 additions & 2 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,16 @@ def _prepare_model_input_tensors(
dtype=torch.long,
device=self.device)

logits_soft_cap = getattr(self.model_config.hf_config,
'final_logit_softcapping', None)
if logits_soft_cap is not None and self.attn_backend.get_name(
) != "flashinfer":
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could I check if logits_soft_cap is supposed to be the attn_logit_softcapping value instead? The two values are different in the Gemma2 config.

"attn_logit_softcapping": 50.0,
"final_logit_softcapping": 30.0,

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yongxb Nice catch! final_logit_softcapping is used to cap the final logits before sampling. @LiuXiaoxuanPKU Could you please fix this?

logger.warning("Please use Flashinfer backend for models with"
"logits_soft_cap (i.e., Gemma-2)."
" Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should just raise an exception IMO.


if self.attn_backend.get_name() == "flashinfer":
if len(paged_kv_indptr) > 0:
paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
Expand All @@ -676,7 +686,6 @@ def _prepare_model_input_tensors(

kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
self.model_config.dtype)

attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
Expand All @@ -697,7 +706,8 @@ def _prepare_model_input_tensors(
query_start_loc=query_start_loc,
device=self.device,
data_type=kv_cache_dtype,
use_cuda_graph=use_captured_graph)
use_cuda_graph=use_captured_graph,
logits_soft_cap=logits_soft_cap)

else:
attn_metadata = self.attn_backend.make_metadata(
Expand Down
Loading