Skip to content

Commit

Permalink
Only test backward if there's no softcapping
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Jul 10, 2024
1 parent 908511b commit 3d41db3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,7 @@ def test_flash_attn_output(

g = torch.randn_like(out)
do_o = (g.float() * out.float()).sum(-1)
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
if kvpacked:
(
dq,
Expand Down Expand Up @@ -1107,7 +1107,7 @@ def test_flash_attn_output(
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
Expand Down Expand Up @@ -1365,7 +1365,7 @@ def test_flash_attn_varlen_output(
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

g = torch.randn_like(out)
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
if kvpacked:
(
dq_unpad,
Expand Down Expand Up @@ -1424,7 +1424,7 @@ def test_flash_attn_varlen_output(
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
Expand Down

0 comments on commit 3d41db3

Please sign in to comment.