diff --git a/third_party/flash-attention b/third_party/flash-attention
index bdf733be55..f86e3dd919 160000
--- a/third_party/flash-attention
+++ b/third_party/flash-attention
@@ -1 +1 @@
-Subproject commit bdf733be55f0b323a8cf7cc6745a81c3f43cd7f0
+Subproject commit f86e3dd9192e41dee3814c4cfd8bbce4792e6753
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index f598dbb74d..736a413b4f 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -63,7 +63,7 @@

     FLASH_VERSION = flash_attn.__version__
     FLASH_VER_MIN = (2, 6, 3)
-    FLASH_VER_LAST = (2, 6, 3)  # last supported, inclusive
+    FLASH_VER_LAST = (2, 7, 2)  # last supported, inclusive
     flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
     if (
         flash_ver_parsed < FLASH_VER_MIN or flash_ver_parsed > FLASH_VER_LAST
diff --git a/xformers/ops/fmha/flash3.py b/xformers/ops/fmha/flash3.py
index 88e8e29fce..900b9fdd53 100644
--- a/xformers/ops/fmha/flash3.py
+++ b/xformers/ops/fmha/flash3.py
@@ -50,6 +50,7 @@
         # We end up here is arch is not 90a
         _C_flashattention3 = None

+
 if _C_flashattention3 is not None:
     # returns: out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p
     @torch.library.custom_op(
@@ -68,7 +69,9 @@ def mha_fwd(
         softmax_scale: float,
         is_causal: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor,]:
+        win_left = win_right = -1
         if cu_seqlens_q is None:
+            use_gqa_packing = False
             assert cu_seqlens_k is None
             assert seqused_k is None
             (
@@ -80,9 +83,21 @@ def mha_fwd(
                 softmax_lse,
                 p,
             ) = _C_flashattention3.fwd(
-                query, key, value, None, softmax_scale, None, None, None, is_causal
+                query,
+                key,
+                value,
+                None,
+                softmax_scale,
+                None,
+                None,
+                None,
+                is_causal,
+                win_left,
+                win_right,
+                use_gqa_packing,
             )
         else:
+            seqused_q = block_table = None
             out, q, k, v, out_padded, softmax_lse = _C_flashattention3.varlen_fwd(
                 query,
                 key,
@@ -90,11 +105,15 @@ def mha_fwd(
                 None,
                 cu_seqlens_q,
                 cu_seqlens_k,
+                seqused_q,
                 seqused_k,
+                block_table,
                 max_seqlen_q,
                 max_seqlen_k,
                 softmax_scale,
                 is_causal,
+                win_left,
+                win_right,
             )
         return out, softmax_lse

@@ -157,6 +176,8 @@ def mha_bwd(
         softmax_scale: float,
         is_causal: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        win_left = win_right = -1
+        seqused_q = seqused_k = None
         dq, dk, dv = _create_dq_dk_dv(grads_share_storage, query, key, value)
         is_deterministic = False
         if cu_seqlens_q is None:
@@ -173,6 +194,8 @@ def mha_bwd(
                 dv,
                 softmax_scale,
                 is_causal,
+                win_left,
+                win_right,
                 is_deterministic,
             )
         else:
@@ -188,10 +211,14 @@ def mha_bwd(
                 dv,
                 cu_seqlens_q,
                 cu_seqlens_k,
+                seqused_q,
+                seqused_k,
                 max_seqlen_q,
                 max_seqlen_k,
                 softmax_scale,
                 is_causal,
+                win_left,
+                win_right,
                 is_deterministic,
             )
         return dq, dk, dv
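
Usage note (not part of the patch): a minimal sketch of exercising the updated FlashAttention-3 bindings through xformers' public API. It assumes an SM90 (Hopper) GPU, that this build exposes flash3.FwOp/flash3.BwOp, and uses illustrative shapes and dtypes only.

    import torch
    import xformers.ops as xops
    from xformers.ops.fmha import flash3

    # [batch, seq_len, num_heads, head_dim] layout accepted by memory_efficient_attention
    q = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    k = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)

    # Route explicitly through the FA3 forward/backward ops whose bindings this diff updates
    out = xops.memory_efficient_attention(q, k, v, op=(flash3.FwOp, flash3.BwOp))
    out.sum().backward()  # also exercises the mha_bwd path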