Skip to content

Commit

Permalink
[Bugfix] Remove hardcoded head_size=256 for Deepseek v2 and v3 (vll…
Browse files Browse the repository at this point in the history
…m-project#12067)

Signed-off-by: Isotr0py <2037008807@qq.com>
  • Loading branch information
Isotr0py committed Feb 2, 2025
1 parent 409b228 commit c993711
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 40 deletions.
6 changes: 3 additions & 3 deletions tests/kernels/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 120, 256]
# This should be sync with get_supported_head_sizes() in
# vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
Expand Down
9 changes: 6 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,9 +733,12 @@ def get_head_size(self) -> int:
if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
in ('deepseek_v2', 'deepseek_v3')):
# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return 256
qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
0)
qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim",
0)
if qk_rope_head_dim and qk_nope_head_dim:
return qk_rope_head_dim + qk_nope_head_dim

if self.is_attention_free:
return 0
Expand Down
24 changes: 7 additions & 17 deletions vllm/model_executor/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,8 @@ def __init__(
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale

# self.attn = Attention(self.num_heads,
# self.qk_head_dim,
# self.scaling,
# num_kv_heads=self.num_heads)

# TODO, support head_size 192
self.attn = Attention(self.num_local_heads,
256,
self.qk_head_dim,
self.scaling,
num_kv_heads=self.num_local_heads,
cache_config=cache_config,
Expand Down Expand Up @@ -319,18 +313,14 @@ def forward(
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim:] = k_pe
q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
# padding value to qk_head_dim for alignment
v = torch.nn.functional.pad(
v, [0, self.qk_head_dim - self.v_head_dim],
value=0).view(-1, self.num_local_heads * self.qk_head_dim)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = attn_output.view(
-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
-1, self.num_local_heads,
self.qk_head_dim)[..., :self.v_head_dim].reshape(
-1, self.num_local_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output)
return output
Expand Down
24 changes: 7 additions & 17 deletions vllm/model_executor/models/deepseek_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,8 @@ def __init__(
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale

# self.attn = Attention(self.num_heads,
# self.qk_head_dim,
# self.scaling,
# num_kv_heads=self.num_heads)

# TODO, support head_size 192
self.attn = Attention(self.num_local_heads,
256,
self.qk_head_dim,
self.scaling,
num_kv_heads=self.num_local_heads,
cache_config=cache_config,
Expand Down Expand Up @@ -326,18 +320,14 @@ def forward(
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim:] = k_pe
q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
# padding value to qk_head_dim for alignment
v = torch.nn.functional.pad(
v, [0, self.qk_head_dim - self.v_head_dim],
value=0).view(-1, self.num_local_heads * self.qk_head_dim)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = attn_output.view(
-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
-1, self.num_local_heads,
self.qk_head_dim)[..., :self.v_head_dim].reshape(
-1, self.num_local_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output)
return output
Expand Down

0 comments on commit c993711

Please sign in to comment.