Skip to content

Commit

Permalink
Use Flash attention instead of CUTLASS FMHA for Gemma prefill (apache#49
Browse files Browse the repository at this point in the history
)
  • Loading branch information
masahi authored Feb 27, 2024
1 parent f4b0c28 commit 4d56c71
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
20 changes: 17 additions & 3 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,12 +764,18 @@ def get_batch_on_arg(arg_name, arg_shape):

is_mqa = annotations["num_q_heads"] != annotations["num_kv_heads"]

use_flash = (
is_flash_supported = (
annotations["ret_dtype"] == "float16"
and "bias" not in attrs
and int(attrs["head_dim"]) <= 256
and int(attrs["head_dim"]) % 8 == 0
and int(attrs["head_dim"]) == int(attrs["head_dim_value"])
# Flash v2 is currently not supported for sm < 80
and int(annotations["arch"]) >= 80
)

use_flash = (
is_flash_supported
# For the causal case (custom mask = "BottomRight"), only use flash for multi-query
# attention workloads. Otherwise, CUTLASS fMHA seems faster for causal attention
# with a single query.
Expand All @@ -779,10 +785,18 @@ def get_batch_on_arg(arg_name, arg_shape):
or (int(annotations["custom_mask_type"]) == 2 and is_mqa)
or (int(annotations["custom_mask_type"]) == 2 and "window_size" in annotations)
)
# Flash v2 is currently not supported for sm < 80
and int(annotations["arch"]) >= 80
)

if (
is_flash_supported
and not use_flash
and int(annotations["custom_mask_type"]) == 2
and int(attrs["head_dim"]) == 256
):
# This is a workaround for prefill inference in Gemma. CUTLASS FMHA raises an error
# at runtime when the batch size is big. Flash Attention has no issue.
use_flash = True

# See https://github.com/Dao-AILab/flash-attention/blob/
# 92dd5703ecdb99aa4a4aee9817f28557907403a2/csrc/flash_attn/flash_api.cpp#L111-L116
if "window_size" in annotations:
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/contrib/cutlass/flash_decoding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.flash_attn.flash_decoding_with_paged_kvcache")
int max_num_blocks_per_seq = block_tables->shape[1];
float softmax_scale = 1.0 / sqrt(static_cast<float>(head_dim));

ICHECK(block_size % 128 == 0) << "Block size needs to be a multiple of 128.";
ICHECK(block_size % 64 == 0) << "Block size needs to be a multiple of 64.";

auto block_table_ptr = static_cast<int*>(block_tables->data);
auto seqlens_k_ptr = static_cast<int*>(context_lens->data);
Expand Down

0 comments on commit 4d56c71

Please sign in to comment.