feat: enable using fp8 kv and prefix caching with chunked prefill #668

Merged · 3 commits · Sep 9, 2024
3 changes: 3 additions & 0 deletions aphrodite/attention/backends/rocm_flash_attn.py
@@ -460,6 +460,7 @@ def forward(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
@@ -469,6 +470,8 @@
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
k_scale,
v_scale,
)

if decode_meta := attn_metadata.decode_metadata:
3 changes: 3 additions & 0 deletions aphrodite/attention/backends/xformers.py
@@ -603,6 +603,7 @@ def forward(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
@@ -612,6 +613,8 @@
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window,
k_scale,
v_scale,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
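Both GPU backends above make the same change: the chunked-prefill path now tells `PagedAttention.forward_prefix` which dtype the KV cache holds and passes the per-tensor `k_scale` / `v_scale` needed to dequantize it. As a rough illustration of what those scales mean (a minimal sketch assuming per-tensor scaling and PyTorch >= 2.1 for the float8 dtypes; this is not code from the PR), the round trip through an e4m3 cache looks like this:

```python
import torch

# Illustrative key tensor; the real caches hold paged blocks of K/V.
k_fp16 = torch.randn(4, 8, dtype=torch.float16)

# Pick a per-tensor scale so values fit the e4m3 range (max magnitude ~448).
k_scale = (k_fp16.abs().max().float() / 448.0).item()

# What an fp8 KV cache stores.
k_fp8 = (k_fp16.float() / k_scale).to(torch.float8_e4m3fn)

# What the updated Triton kernel does when it reads a cached block:
#   k = (k_load.to(tl.float32) * k_scale).to(q.dtype)
k_dequant = (k_fp8.to(torch.float32) * k_scale).to(torch.float16)

print((k_fp16 - k_dequant).abs().max())  # small quantization error
```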
1 change: 1 addition & 0 deletions aphrodite/attention/ops/ipex_attn.py
@@ -90,6 +90,7 @@ def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache_dtype: str,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
6 changes: 6 additions & 0 deletions aphrodite/attention/ops/paged_attn.py
@@ -194,6 +194,7 @@ def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache_dtype: str,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
@@ -203,13 +204,16 @@
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
k_scale: float,
v_scale: float,
) -> torch.Tensor:
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
kv_cache_dtype,
key_cache,
value_cache,
block_tables,
@@ -218,6 +222,8 @@
seq_lens_tensor,
context_lens,
max_query_len,
k_scale,
v_scale,
alibi_slopes,
sliding_window,
)
90 changes: 71 additions & 19 deletions aphrodite/attention/ops/prefix_prefill.py
@@ -18,6 +18,8 @@ def _fwd_kernel(
V_cache,
B_Loc,
sm_scale,
k_scale,
v_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
@@ -117,10 +119,15 @@ def _fwd_kernel(
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0) # [D,N]
k_load = tl.load(K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0) # [D,N]

if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * k_scale).to(q.dtype)
else:
k = k_load

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N]
qk += tl.dot(q, k)
@@ -161,10 +168,14 @@
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0) # [N,D]
v_load = tl.load(V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0) # [N,D]
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * v_scale).to(q.dtype)
else:
v = v_load

p = p.to(v.dtype)
acc += tl.dot(p, v)
@@ -442,6 +453,8 @@ def _fwd_kernel_alibi(
V_cache,
B_Loc,
sm_scale,
k_scale,
v_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
@@ -537,10 +550,15 @@ def _fwd_kernel_alibi(
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0) # [D,N]
k_load = tl.load(K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0) # [D,N]

if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * k_scale).to(q.dtype)
else:
k = k_load

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
@@ -573,10 +591,14 @@ def _fwd_kernel_alibi(
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0)
v_load = tl.load(V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * v_scale).to(q.dtype)
else:
v = v_load

p = p.to(v.dtype)
acc += tl.dot(p, v, allow_tf32=False)
@@ -675,18 +697,45 @@ def context_attention_fwd(q,
k,
v,
o,
kv_cache_dtype: str,
k_cache,
v_cache,
b_loc,
b_start_loc,
b_seq_len,
b_ctx_len,
max_input_len,
k_scale: float = 1.0,
v_scale: float = 1.0,
alibi_slopes=None,
sliding_window=None):

cap = current_platform.get_device_capability()
BLOCK = 128 if cap[0] >= 8 else 64
NUM_WARPS = 8
if q.dtype is torch.float32:
BLOCK = BLOCK // 2

# Conversion of FP8 Tensor from uint8 storage to
# appropriate torch.dtype for interpretation by Triton
if "fp8" in kv_cache_dtype:
assert (k_cache.dtype == torch.uint8)
assert (v_cache.dtype == torch.uint8)

if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = torch.float8_e4m3fn
elif kv_cache_dtype == "fp8_e5m2":
target_dtype = torch.float8_e5m2
else:
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)

k_cache = k_cache.view(target_dtype)
v_cache = v_cache.view(target_dtype)

if (k_cache.dtype == torch.uint8
or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"):
raise ValueError("kv_cache_dtype='auto' unsupported for\
FP8 KV Cache prefill kernel")
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
@@ -703,7 +752,6 @@ def context_attention_fwd(q,
if sliding_window is None or sliding_window <= 0:
sliding_window = 0

num_warps = 8 if Lk <= 64 else 8
if alibi_slopes is not None:
_fwd_kernel_alibi[grid](
q,
@@ -713,6 +761,8 @@
v_cache,
b_loc,
sm_scale,
k_scale,
v_scale,
b_start_loc,
b_seq_len,
b_ctx_len,
@@ -751,7 +801,7 @@ def context_attention_fwd(q,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_warps=NUM_WARPS,
num_stages=1,
)
return
@@ -764,6 +814,8 @@ def context_attention_fwd(q,
v_cache,
b_loc,
sm_scale,
k_scale,
v_scale,
b_start_loc,
b_seq_len,
b_ctx_len,
@@ -801,7 +853,7 @@
BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
SLIDING_WINDOW=sliding_window,
num_warps=num_warps,
num_warps=NUM_WARPS,
num_stages=1,
)
return
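Triton cannot treat a raw `uint8` buffer as FP8, so `context_attention_fwd` now reinterprets the byte-backed cache with `Tensor.view(...)` before launching the kernel, and rejects `kv_cache_dtype="auto"` when it is handed uint8 storage. A standalone sketch of that reinterpretation (same storage convention as the diff; the shape is made up for illustration and a PyTorch build with float8 dtypes is assumed):

```python
import torch

# The fp8 KV cache is allocated as raw bytes (uint8); shape is illustrative.
k_cache_u8 = torch.randint(0, 256, (2, 16, 64), dtype=torch.uint8)

kv_cache_dtype = "fp8_e4m3"  # the kernel also accepts "fp8" and "fp8_e5m2"
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
    target_dtype = torch.float8_e4m3fn
elif kv_cache_dtype == "fp8_e5m2":
    target_dtype = torch.float8_e5m2
else:
    raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")

# Zero-copy reinterpretation: same storage and shape, new element type,
# so the kernel sees real float8 elements instead of opaque bytes.
k_cache_fp8 = k_cache_u8.view(target_dtype)
assert k_cache_fp8.shape == k_cache_u8.shape
assert k_cache_fp8.data_ptr() == k_cache_u8.data_ptr()
```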
14 changes: 11 additions & 3 deletions aphrodite/common/config.py
@@ -16,6 +16,7 @@
print_warning_once)
from aphrodite.distributed import get_current_tp_rank_partition_size
from aphrodite.modeling.models import ModelRegistry
from aphrodite.platforms import current_platform
from aphrodite.quantization import QUANTIZATION_METHODS
from aphrodite.transformers_utils.config import get_config, get_hf_text_config

@@ -589,10 +590,17 @@ def _verify_prefix_caching(self) -> None:
raise NotImplementedError(
"Prefix caching is not supported with sliding window. "
"Run with --disable-sliding-window to use prefix caching.")

if self.cache_dtype == "fp8":
raise NotImplementedError(
"Prefix caching is not supported for fp8 cache_dtype. "
"Run with --kv-cache-dtype auto to use prefix caching.")
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < 89:
raise NotImplementedError(
"FP8 KV cache with prefix caching is only supported on "
"GPUs with compute capability 8.9 or higher (e.g., "
"4090, H100). Your GPU has compute capability "
f"{capability}")


def verify_with_parallel_config(
self,
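The relaxed config check above allows fp8 prefix caching only on GPUs with compute capability 8.9 or newer, which corresponds to the first NVIDIA architectures with hardware FP8 support (Ada/Hopper). A hedged standalone equivalent of that gate (assumes a CUDA-enabled PyTorch build; the real code queries `current_platform.get_device_capability()` rather than `torch.cuda` directly):

```python
import torch

def fp8_prefix_caching_supported(device: int = 0) -> bool:
    """Standalone approximation of the gate added to _verify_prefix_caching."""
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(device)
    # FP8 KV cache + prefix caching needs compute capability >= 8.9
    # (e.g. RTX 4090, H100), matching the check added in config.py.
    return major * 10 + minor >= 89
```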