Bump FA to 2.7.2 (fairinternal/xformers#1269)
* Bump FA to 2.7.2

* Update flash.py

__original_commit__ = fairinternal/xformers@6d989fa
danthe3rd authored and xFormers Bot committed Dec 11, 2024
1 parent 291ddf6 commit 839c4ec
Showing 3 changed files with 30 additions and 3 deletions.
third_party/flash-attention (2 changes: 1 addition & 1 deletion)
Submodule flash-attention updated 130 files
xformers/ops/fmha/flash.py (2 changes: 1 addition & 1 deletion)
@@ -63,7 +63,7 @@

FLASH_VERSION = flash_attn.__version__
FLASH_VER_MIN = (2, 6, 3)
-FLASH_VER_LAST = (2, 6, 3) # last supported, inclusive
+FLASH_VER_LAST = (2, 7, 2) # last supported, inclusive
flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
if (
flash_ver_parsed < FLASH_VER_MIN or flash_ver_parsed > FLASH_VER_LAST
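
Note on the flash.py change: the version gate parses the installed flash_attn version into a tuple of its first three integer fields and rejects anything outside [FLASH_VER_MIN, FLASH_VER_LAST], so this bump is what allows FA 2.7.x wheels to load. A minimal, self-contained sketch of the same check follows; the hard-coded version string and the error message are illustrative stand-ins, not taken from flash.py:

FLASH_VERSION = "2.7.2.post1"  # stand-in for flash_attn.__version__
FLASH_VER_MIN = (2, 6, 3)
FLASH_VER_LAST = (2, 7, 2)  # last supported, inclusive

# "2.7.2.post1" -> (2, 7, 2): only the first three dot-separated fields count,
# so post/dev suffixes of a supported release still pass the gate.
flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])

if flash_ver_parsed < FLASH_VER_MIN or flash_ver_parsed > FLASH_VER_LAST:
    raise ImportError(
        f"flash_attn {FLASH_VERSION} is outside the supported range "
        f"{FLASH_VER_MIN}..{FLASH_VER_LAST}"
    )
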
xformers/ops/fmha/flash3.py (29 changes: 28 additions & 1 deletion)
@@ -50,6 +50,7 @@
# We end up here if arch is not 90a
_C_flashattention3 = None


if _C_flashattention3 is not None:
# returns: out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p
@torch.library.custom_op(
@@ -68,7 +69,9 @@ def mha_fwd(
softmax_scale: float,
is_causal: bool,
) -> Tuple[torch.Tensor, torch.Tensor,]:
win_left = win_right = -1
if cu_seqlens_q is None:
use_gqa_packing = False
assert cu_seqlens_k is None
assert seqused_k is None
(
@@ -80,21 +83,37 @@
softmax_lse,
p,
) = _C_flashattention3.fwd(
query, key, value, None, softmax_scale, None, None, None, is_causal
query,
key,
value,
None,
softmax_scale,
None,
None,
None,
is_causal,
win_left,
win_right,
use_gqa_packing,
)
else:
seqused_q = block_table = None
out, q, k, v, out_padded, softmax_lse = _C_flashattention3.varlen_fwd(
query,
key,
value,
None,
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
block_table,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
is_causal,
win_left,
win_right,
)
return out, softmax_lse
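
The forward hunks above capture the interface change that motivates this bump: with the 2.7.2 submodule, the FA3 fwd entry point takes explicit sliding-window bounds and a GQA-packing flag as trailing arguments, and the patch pins them to their neutral values (-1/-1 and False). A hedged sketch of the dense call shape is below; fa3_fwd_dense and its ext parameter (standing in for the already-loaded _C_flashattention3 extension) are illustrative names, not part of xformers:

import torch
from typing import Tuple

def fa3_fwd_dense(
    ext,  # the loaded _C_flashattention3 extension module (loading elided)
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    softmax_scale: float,
    is_causal: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Trailing arguments added by the updated FA3 interface; -1/-1 is
    # FlashAttention's convention for an unbounded (disabled) sliding
    # window, and GQA packing stays off in this code path.
    win_left = win_right = -1
    use_gqa_packing = False
    # fwd returns: out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p
    out, _q, _k, _v, _out_padded, softmax_lse, _p = ext.fwd(
        query, key, value, None, softmax_scale, None, None, None,
        is_causal, win_left, win_right, use_gqa_packing,
    )
    return out, softmax_lse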

@@ -157,6 +176,8 @@ def mha_bwd(
softmax_scale: float,
is_causal: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
win_left = win_right = -1
seqused_q = seqused_k = None
dq, dk, dv = _create_dq_dk_dv(grads_share_storage, query, key, value)
is_deterministic = False
if cu_seqlens_q is None:
@@ -173,6 +194,8 @@
dv,
softmax_scale,
is_causal,
win_left,
win_right,
is_deterministic,
)
else:
@@ -188,10 +211,14 @@
dv,
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
is_causal,
win_left,
win_right,
is_deterministic,
)
return dq, dk, dv
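
The backward hunks follow the same pattern: win_left/win_right (and, for varlen_bwd, seqused_q/seqused_k) are threaded through with neutral defaults, so behaviour is unchanged for existing callers. End users reach these kernels through xformers' public attention entry point; the usage sketch below assumes the flash3 backend exposes FwOp/BwOp under those names like the other fmha backends, plus an FA3 build on an sm90a GPU per the arch check above. Omitting op= lets xformers dispatch automatically.

import torch
import xformers.ops as xops
from xformers.ops import fmha

# [batch, seqlen, heads, head_dim] layout expected by memory_efficient_attention.
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)

out = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=fmha.attn_bias.LowerTriangularMask(),  # causal masking
    op=(fmha.flash3.FwOp, fmha.flash3.BwOp),  # assumed names; see note above
)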
