From 4d56c719bc7754da7042c35c22bf8c5cce2c1384 Mon Sep 17 00:00:00 2001
From: masahi
Date: Tue, 27 Feb 2024 16:21:32 +0900
Subject: [PATCH] Use Flash attention instead of CUTLASS FMHA for Gemma prefill (#49)

---
 python/tvm/contrib/cutlass/gen_tensor_op.py   | 20 ++++++++++++++++---
 src/runtime/contrib/cutlass/flash_decoding.cu |  2 +-
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 298d7895722c..c65e8d1ef649 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -764,12 +764,18 @@ def get_batch_on_arg(arg_name, arg_shape):
 
         is_mqa = annotations["num_q_heads"] != annotations["num_kv_heads"]
 
-        use_flash = (
+        is_flash_supported = (
             annotations["ret_dtype"] == "float16"
             and "bias" not in attrs
             and int(attrs["head_dim"]) <= 256
             and int(attrs["head_dim"]) % 8 == 0
             and int(attrs["head_dim"]) == int(attrs["head_dim_value"])
+            # Flash v2 is currently not supported for sm < 80
+            and int(annotations["arch"]) >= 80
+        )
+
+        use_flash = (
+            is_flash_supported
             # For the causal case (custom mask = "BottomRight"), only use flash for multi-query
             # attention workloads. Otherwise, CUTLASS fMHA seems faster for causal attention
             # with a single query.
@@ -779,10 +785,18 @@ def get_batch_on_arg(arg_name, arg_shape):
                 or (int(annotations["custom_mask_type"]) == 2 and is_mqa)
                 or (int(annotations["custom_mask_type"]) == 2 and "window_size" in annotations)
             )
-            # Flash v2 is currently not supported for sm < 80
-            and int(annotations["arch"]) >= 80
         )
 
+        if (
+            is_flash_supported
+            and not use_flash
+            and int(annotations["custom_mask_type"]) == 2
+            and int(attrs["head_dim"]) == 256
+        ):
+            # This is a workaround for prefill inference in Gemma. CUTLASS FMHA raises an error
+            # at runtime when the batch size is big. Flash Attention has no issue.
+            use_flash = True
+
         # See https://github.com/Dao-AILab/flash-attention/blob/
         # 92dd5703ecdb99aa4a4aee9817f28557907403a2/csrc/flash_attn/flash_api.cpp#L111-L116
         if "window_size" in annotations:

diff --git a/src/runtime/contrib/cutlass/flash_decoding.cu b/src/runtime/contrib/cutlass/flash_decoding.cu
index 8c2f2a095b52..592b576b9077 100644
--- a/src/runtime/contrib/cutlass/flash_decoding.cu
+++ b/src/runtime/contrib/cutlass/flash_decoding.cu
@@ -154,7 +154,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.flash_attn.flash_decoding_with_paged_kvcache")
       int max_num_blocks_per_seq = block_tables->shape[1];
       float softmax_scale = 1.0 / sqrt(static_cast<float>(head_dim));
 
-      ICHECK(block_size % 128 == 0) << "Block size needs to be a multiple of 128.";
+      ICHECK(block_size % 64 == 0) << "Block size needs to be a multiple of 64.";
 
      auto block_table_ptr = static_cast<int*>(block_tables->data);
      auto seqlens_k_ptr = static_cast<int*>(context_lens->data);
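
Illustrative sketch (not part of the patch above): the kernel-selection heuristic that the gen_tensor_op.py hunks introduce, restated as a standalone Python function. The function name pick_attention_backend, the example input dictionaries, and the omission of the non-causal branch (whose lines fall between the two hunks and are not shown in the patch) are assumptions for illustration only.

# Hypothetical helper for illustration; mirrors the patched logic in
# python/tvm/contrib/cutlass/gen_tensor_op.py but is not part of that file.
def pick_attention_backend(attrs, annotations):
    """Return "flash" or "cutlass_fmha" following the patched heuristic."""
    is_mqa = annotations["num_q_heads"] != annotations["num_kv_heads"]

    # Shape/dtype/arch constraints under which Flash Attention v2 can run at all.
    is_flash_supported = (
        annotations["ret_dtype"] == "float16"
        and "bias" not in attrs
        and int(attrs["head_dim"]) <= 256
        and int(attrs["head_dim"]) % 8 == 0
        and int(attrs["head_dim"]) == int(attrs["head_dim_value"])
        and int(annotations["arch"]) >= 80  # Flash v2 needs sm80 or newer
    )

    # Causal ("BottomRight", custom_mask_type == 2) cases visible in the hunks:
    # prefer Flash for multi-query or sliding-window workloads. The full
    # condition in gen_tensor_op.py also covers other mask types (those lines
    # sit between the hunks), which this sketch omits.
    use_flash = is_flash_supported and (
        (int(annotations["custom_mask_type"]) == 2 and is_mqa)
        or (int(annotations["custom_mask_type"]) == 2 and "window_size" in annotations)
    )

    # Gemma prefill workaround added by this patch: causal attention with
    # head_dim == 256 falls back to Flash because CUTLASS FMHA can fail at
    # runtime for large batch sizes.
    if (
        is_flash_supported
        and not use_flash
        and int(annotations["custom_mask_type"]) == 2
        and int(attrs["head_dim"]) == 256
    ):
        use_flash = True

    return "flash" if use_flash else "cutlass_fmha"


# Example: a Gemma-7B-style prefill (causal mask, equal q/kv heads, head_dim 256)
# now selects Flash via the workaround branch rather than CUTLASS FMHA.
attrs = {"head_dim": 256, "head_dim_value": 256}
annotations = {
    "num_q_heads": 16,
    "num_kv_heads": 16,
    "ret_dtype": "float16",
    "arch": 80,
    "custom_mask_type": 2,
}
print(pick_attention_backend(attrs, annotations))  # -> "flash"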