diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 82163f970..53f8f444f 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -1051,7 +1051,7 @@ def test_flash_attn_output(
 
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         if kvpacked:
             (
                 dq,
@@ -1107,7 +1107,7 @@ def test_flash_attn_output(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@@ -1365,7 +1365,7 @@ def test_flash_attn_varlen_output(
     print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
 
     g = torch.randn_like(out)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         if kvpacked:
             (
                 dq_unpad,
@@ -1424,7 +1424,7 @@ def test_flash_attn_varlen_output(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()