Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel][Model] logits_soft_cap for Gemma2 with flashinfer #6051

Merged
7 changes: 5 additions & 2 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,15 @@ steps:

- label: Kernels Test %N
#mirror_hardwares: [amd]
command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4

- label: Models Test
#mirror_hardwares: [amd]
commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- pytest -v -s models -m \"not vlm\"

- label: Vision Language Models Test
Expand Down Expand Up @@ -234,7 +237,7 @@ steps:
- pytest -v -s distributed/test_custom_all_reduce.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s -x lora/test_mixtral.py
248 changes: 248 additions & 0 deletions tests/kernels/test_flashinfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
from typing import List, Optional, Tuple

import flashinfer
import pytest
import torch

NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.


def ref_paged_attn(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
query_lens: List[int],
kv_lens: List[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
soft_cap: Optional[float] = None,
) -> torch.Tensor:
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape

outputs: List[torch.Tensor] = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
q = query[start_idx:start_idx + query_len]
q *= scale

num_kv_blocks = (kv_len + block_size - 1) // block_size
block_indices = block_tables[i, :num_kv_blocks]

k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
k = k[:kv_len]
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
v = v[:kv_len]

if q.shape[1] != k.shape[1]:
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
attn = torch.einsum("qhd,khd->hqk", q, k).float()
empty_mask = torch.ones(query_len, kv_len)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
if sliding_window is not None:
sliding_window_mask = torch.triu(empty_mask,
diagonal=kv_len -
(query_len + sliding_window) +
1).bool().logical_not()
mask |= sliding_window_mask
if soft_cap is not None:
attn = soft_cap * torch.tanh(attn / soft_cap)
attn.masked_fill_(mask, float("-inf"))
attn = torch.softmax(attn, dim=-1).to(v.dtype)
out = torch.einsum("hqk,khd->qhd", attn, v)

outputs.append(out)
start_idx += query_len

return torch.cat(outputs, dim=0)


@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
num_heads: Tuple[int,
int], head_size: int,
dtype: torch.dtype, block_size: int,
soft_cap: Optional[float]) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5

query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_value_cache = torch.randn(NUM_BLOCKS,
2,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)

kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
wrapper.begin_forward(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
data_type=dtype)

output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap)

ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"


@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int],
head_size: int, dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float]) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5

query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
key_value_cache = torch.randn(NUM_BLOCKS,
2,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)

# Normalize the scale of the key and value caches to mitigate
# numerical instability.
key_cache /= head_size**0.5
value_cache /= head_size**0.5

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

qo_indptr = [0]
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
qo_indptr.append(qo_indptr[-1] + query_lens[i])

qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD")
wrapper.begin_forward(
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
)

output = wrapper.forward(
query,
key_value_cache,
logits_soft_cap=soft_cap,
)

ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
12 changes: 8 additions & 4 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ class FlashInferMetadata(AttentionMetadata):
# The data type of the paged kv cache
data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
# Only used by gemma2 model
logits_soft_cap: Optional[float] = None

def __post_init__(self):
# Refer to
Expand Down Expand Up @@ -271,15 +273,17 @@ def forward(
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
output = prefill_meta.prefill_wrapper.forward(query,
kv_cache,
causal=True)
output = prefill_meta.prefill_wrapper.forward(
query,
kv_cache,
logits_soft_cap=attn_metadata.logits_soft_cap,
causal=True)
else:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
kv_cache,
sm_scale=self.scale,
)
logits_soft_cap=attn_metadata.logits_soft_cap)
return output.view(num_tokens, hidden_size)
6 changes: 3 additions & 3 deletions vllm/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def get_attn_backend(
return IpexAttnBackend
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning(("Flashinfer will be stuck on llma-2-7b,"
" please avoid using Flashinfer as the"
"backend when running on llma-2-7b."))
logger.warning(("Flashinfer will be stuck on llama-2-7b,"
" please avoid using Flashinfer as the "
"backend when running on llama-2-7b."))
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
elif backend == _Backend.PALLAS:
Expand Down
7 changes: 0 additions & 7 deletions vllm/model_executor/models/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once

from .interfaces import SupportsLoRA

Expand Down Expand Up @@ -137,12 +136,6 @@ def __init__(self,
dtype=torch.get_default_dtype(),
)

if self.config.attn_logit_softcapping is not None:
print_warning_once(
"Gemma 2 normally uses attention logit soft-capping; "
"soft-capping is currently incompatible with the flash "
"attention kernels, so vLLM removes it to enable speed and "
"efficiency gains of flash attention.")
# FIXME(woosuk): While Gemma 2 uses sliding window attention for every
# odd layer, vLLM currently ignores it and uses global attention for
# all layers.
Expand Down
19 changes: 15 additions & 4 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
Expand Down Expand Up @@ -683,6 +683,16 @@ def _prepare_model_input_tensors(
dtype=torch.long,
device=self.device)

logits_soft_cap = getattr(self.model_config.hf_config,
'attn_logit_softcapping', None)
if logits_soft_cap is not None and self.attn_backend.get_name(
) != "flashinfer":
raise ValueError("Please use Flashinfer backend for models with"
"logits_soft_cap (i.e., Gemma-2)."
" Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")

if self.attn_backend.get_name() == "flashinfer":
if len(paged_kv_indptr) > 0:
paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
Expand All @@ -700,7 +710,6 @@ def _prepare_model_input_tensors(

kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
self.model_config.dtype)

attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
Expand All @@ -721,7 +730,8 @@ def _prepare_model_input_tensors(
query_start_loc=query_start_loc,
device=self.device,
data_type=kv_cache_dtype,
use_cuda_graph=use_captured_graph)
use_cuda_graph=use_captured_graph,
logits_soft_cap=logits_soft_cap)

else:
attn_metadata = self.attn_backend.make_metadata(
Expand Down Expand Up @@ -1196,7 +1206,8 @@ def execute_model(
if model_input.attn_metadata.use_cuda_graph:
batch_size = model_input.input_tokens.shape[0]
model_input.attn_metadata.decode_wrapper = self.graph_runners[
batch_size].flashinfer_decode_wrapper
model_input.
virtual_engine][batch_size].flashinfer_decode_wrapper
else:
model_input.attn_metadata.decode_wrapper = \
self.flashinfer_decode_wrapper
Expand Down
Loading