From bac8718d624c86398bd900f877133661a6907a77 Mon Sep 17 00:00:00 2001
From: danthe3rd
Date: Thu, 19 Jan 2023 17:36:56 +0000
Subject: [PATCH] Refactor memeff tests

ghstack-source-id: f460edb5208212069da88afd47222057dd0838bf
Pull Request resolved: https://github.com/fairinternal/xformers/pull/437

__original_commit__ = fairinternal/xformers@0bf66c6ef480771a48a53856b8f895ee4927f1fb
---
 tests/test_mem_eff_attention.py | 355 ++++++++++++++++++++------------
 1 file changed, 223 insertions(+), 132 deletions(-)

diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index a6b0136a0a..b1057dd1c6 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -14,6 +14,7 @@
 
 import xformers.ops
 from xformers.ops import fmha
+from xformers.ops.fmha.common import AttentionOpBase
 
 from .utils import assert_allclose
 
@@ -128,57 +129,59 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
     return shapes
 
 
-def _generate_op_device_dtype_B_Mq_Mkv_H_K_Kv(ops_list, one_shape_per_op: bool = False):
+def _generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(
+    ops_list: Sequence[Type[fmha.AttentionOpBase]], max_shapes_per_op: int = 65000
+):
+    r = random.Random(0)
+    combination = []
+    ids = []
     for op in ops_list:
+        op_count = 0
         for shape in generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
             has_one = False
             for device in _devices:
                 if device not in op.SUPPORTED_DEVICES:
                     continue
                 for dtype in op.SUPPORTED_DTYPES:
-                    yield (op, device, dtype, *shape)
+                    bias_type = r.choice(list(op.SUPPORTED_ATTN_BIAS_TYPES))
+                    # Avoid using too much memory
+                    if bias_type not in [
+                        type(None),
+                        fmha.attn_bias.LowerTriangularMask,
+                    ]:
+                        B, Mq, Mkv, H, K, Kv = shape
+                        B = min(B, 12)
+                        shape = (B, Mq, Mkv, H, K, Kv)
+                    combination.append((op, device, dtype, bias_type, *shape))
+                    ids.append(
+                        f"{op.NAME}-{device}-{str(dtype)}-{bias_type.__name__}-{'-'.join([str(s) for s in shape])}"
+                    )
                     has_one = True
-            if has_one and one_shape_per_op:
+            if has_one:
+                op_count += 1
+            if op_count > max_shapes_per_op:
                 break
+    return {
+        "argvalues": combination,
+        "ids": ids,
+    }
 
 
-def _gen_ids(op_device_dtype_B_Mq_Mkv_H_K_Kv):
-    return [
-        f"{op.NAME}-{device}-{str(dtype)}-{batch_size}-{q_len}-{kv_len}-{h}-{k}-{kv}"
-        for (
-            op,
-            device,
-            dtype,
-            batch_size,
-            q_len,
-            kv_len,
-            h,
-            k,
-            kv,
-        ) in op_device_dtype_B_Mq_Mkv_H_K_Kv
-    ]
-
-
-_opFW_device_dtype_B_Mq_Mkv_H_K_Kv = list(
-    _generate_op_device_dtype_B_Mq_Mkv_H_K_Kv(ALL_FW_OPS)
-)
-_opFW_device_dtype_B_Mq_Mkv_H_K_Kv_ids = _gen_ids(_opFW_device_dtype_B_Mq_Mkv_H_K_Kv)
-_opBW_device_dtype_B_Mq_Mkv_H_K_Kv = list(
-    _generate_op_device_dtype_B_Mq_Mkv_H_K_Kv(ALL_BW_OPS)
+parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = pytest.mark.parametrize(
+    "opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv",
+    **_generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(ALL_FW_OPS),
 )
-_opBW_device_dtype_B_Mq_Mkv_H_K_Kv_ids = _gen_ids(_opBW_device_dtype_B_Mq_Mkv_H_K_Kv)
-
-_opFW_device_dtype_B_Mq_Mkv_H_K_Kv__xs = list(
-    _generate_op_device_dtype_B_Mq_Mkv_H_K_Kv(ALL_FW_OPS, one_shape_per_op=True)
+parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs = pytest.mark.parametrize(
+    "opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv",
+    **_generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(ALL_FW_OPS, max_shapes_per_op=1),
 )
-_opFW_device_dtype_B_Mq_Mkv_H_K_Kv__xs_ids = _gen_ids(
-    _opFW_device_dtype_B_Mq_Mkv_H_K_Kv__xs
+parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = pytest.mark.parametrize(
+    "opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv",
+    **_generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(ALL_BW_OPS),
 )
-_opBW_device_dtype_B_Mq_Mkv_H_K_Kv__xs = list(
-    _generate_op_device_dtype_B_Mq_Mkv_H_K_Kv(ALL_BW_OPS, one_shape_per_op=True)
-)
-_opBW_device_dtype_B_Mq_Mkv_H_K_Kv__xs_ids = _gen_ids(
-    _opBW_device_dtype_B_Mq_Mkv_H_K_Kv__xs
+parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs = pytest.mark.parametrize(
+    "opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv",
+    **_generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(ALL_BW_OPS, max_shapes_per_op=1),
 )
 
 
@@ -196,14 +199,20 @@ def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0, scale=None):
     attn = q @ k.transpose(-2, -1)
     if attn_bias is not None:
         if isinstance(attn_bias, xformers.ops.AttentionBias):
-            attn_bias = attn_bias.materialize(
-                (q.shape[0], q.shape[1], k.shape[1]),
+            # Always create in B,H,Mq,Mk format
+            attn_bias_tensor = attn_bias.materialize(
+                (q.shape[0], 1, q.shape[1], k.shape[1]),
                 device=q.device,
                 dtype=torch.float32,
             )
-        if attn_bias.shape[0] != attn.shape[0]:
-            attn_bias = bmk2bmhk(attn_bias, k.shape[2])
-        attn = attn + attn_bias.float()
+        else:
+            attn_bias_tensor = attn_bias
+        if attn_bias_tensor.ndim == 4:
+            assert q.shape[0] == attn_bias_tensor.shape[0] * attn_bias_tensor.shape[1]
+            attn_bias_tensor = attn_bias_tensor.reshape(
+                [-1, *attn_bias_tensor.shape[2:]]
+            )
+        attn = attn + attn_bias_tensor.float()
     attn = attn.softmax(-1)
     if drop_mask is not None:
         attn = attn * (drop_mask / (1 - p))
@@ -218,6 +227,12 @@ def T(t):
             [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
         )
 
+    if isinstance(attn_bias, xformers.ops.AttentionBias):
+        attn_bias = attn_bias.materialize(
+            (q.shape[0], q.shape[2], q.shape[1], k.shape[1]),
+            device=q.device,
+            dtype=torch.float32,
+        ).reshape([q.shape[0] * q.shape[2], q.shape[1], k.shape[1]])
     out = ref_attention(T(q), T(k), T(v), attn_bias, scale=scale)
     out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
     return out.permute((0, 2, 1, 3))
@@ -251,15 +266,19 @@ def create_attn_bias(
     device,
     dtype,
     requires_grad: bool,
+    fmt: str,
 ):
-    if bias_type is None:
+    if bias_type is None or isinstance(None, bias_type):
         return None
     if bias_type is torch.Tensor:
+        if fmt == "BMK":
+            batch_size *= num_heads
+            num_heads = 1
         attn_bias = (
-            torch.randn((batch_size * num_heads, 1, kv_len), device=device, dtype=dtype)
+            torch.randn((batch_size, num_heads, 1, kv_len), device=device, dtype=dtype)
             * 3
         )
-        attn_bias = attn_bias.expand(batch_size * num_heads, q_len, kv_len)
+        attn_bias = attn_bias.expand(batch_size, num_heads, q_len, kv_len)
         if requires_grad:
            attn_bias.requires_grad_(True)
         return attn_bias
@@ -279,8 +298,10 @@ def create_attn_bias(
         fmha.attn_bias.BlockDiagonalMask,
         fmha.attn_bias.BlockDiagonalCausalMask,
     ]:
+        # This bias is not supported in BMK format
+        assert fmt == "BMHK"
         block_diag = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
-            *_rand_seqlens(batch_size * num_heads, q_len, kv_len)
+            *_rand_seqlens(batch_size, q_len, kv_len)
         )
         if bias_type is fmha.attn_bias.BlockDiagonalCausalMask:
             block_diag = block_diag.make_causal()
@@ -303,9 +324,10 @@ def get_bias_grad(attn_bias, clear: bool = False) -> Optional[torch.Tensor]:
 
 
 def create_tensors(
-    op,
+    op: Type[AttentionOpBase],
     device,
     dtype,
+    attn_bias_type,
     B,
     q_len,
     kv_len,
@@ -313,7 +335,6 @@ def create_tensors(
     k,
     kv,
     *,
-    attn_bias_type=None,
     attn_bias_requires_grad: bool = False,
     fmt: str = "BMK",
 ):
@@ -329,6 +350,8 @@ def create_tensors(
     key = torch.randn((B, kv_len, h, k), device=device, dtype=dtype).mul_(scale)
     value = torch.randn((B, kv_len, h, kv), device=device, dtype=dtype).mul_(scale)
 
+    if fmt == "BMK" and not fmha.common._is_bias_type_supported_in_BMK(attn_bias_type):
+        attn_bias_type = None
     attn_bias = None
     if attn_bias_type is not None:
         attn_bias = create_attn_bias(
@@ -340,6 +363,7 @@ def create_tensors(
             dtype=dtype,
             device=device,
             requires_grad=attn_bias_requires_grad,
+            fmt=fmt,
         )
     if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalMask):
         query, key, value = [
@@ -347,8 +371,9 @@ def create_tensors(
         ]
 
     inputs = fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias)
-    if not op.supports(inputs):
-        err_msg = f"{op.NAME}: unsupported ({inputs})"
+    reasons = op.not_supported_reasons(inputs)
+    if reasons:
+        err_msg = f"{op.NAME}: unsupported ({'/'.join(reasons)})"
         # Ensure we free memory to avoid OOMs
         del query, key, value, attn_bias, inputs
         pytest.skip(err_msg)
@@ -371,17 +396,9 @@ def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
 
 @pytest.mark.parametrize("fmt", ["BMK", "BMHK"])
 @pytest.mark.parametrize("packed", [False, True])
-@pytest.mark.parametrize(
-    "attn_bias_type", [None, xformers.ops.LowerTriangularMask, torch.Tensor]
-)
-@pytest.mark.parametrize(
-    "op_device_dtype_B_Mq_Mkv_H_K_Kv",
-    _opFW_device_dtype_B_Mq_Mkv_H_K_Kv,
-    ids=_opFW_device_dtype_B_Mq_Mkv_H_K_Kv_ids,
-)
+@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
 def test_forward(
-    op_device_dtype_B_Mq_Mkv_H_K_Kv,
-    attn_bias_type,
+    opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
     packed,
     fmt,
 ):
@@ -389,19 +406,22 @@ def test_forward(
         op,
         device,
         dtype,
+        bias_type,
         batch_size,
         q_len,
         kv_len,
         h,
         k,
         kv,
-    ) = op_device_dtype_B_Mq_Mkv_H_K_Kv
+    ) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
     if packed and not (k == kv and q_len == kv_len):
         pytest.skip(
            f"packed incompatible with `k ({k}) != kv ({kv})` or `q_len ({q_len}) != kv_len ({kv_len})`"
         )
+    if fmt == "BMK" and not fmha.common._is_bias_type_supported_in_BMK(bias_type):
+        pytest.skip("BMK incompatible with this bias")
     query, key, value, attn_bias = create_tensors(
-        *op_device_dtype_B_Mq_Mkv_H_K_Kv, attn_bias_type=attn_bias_type, fmt="BMHK"
+        *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt="BMHK" if packed else fmt
     )
 
     if packed:
@@ -410,6 +430,18 @@ def test_forward(
             # bm3hk -> 3bhmk -> 3Bmk
             c = c.permute(2, 0, 3, 1, 4).view([3, -1, q_len, k])
             query, key, value = c[0], c[1], c[2]
+            # Re-create bias in the right format
+            attn_bias = create_attn_bias(
+                bias_type=bias_type,
+                batch_size=batch_size,
+                num_heads=h,
+                q_len=q_len,
+                kv_len=kv_len,
+                device=device,
+                dtype=dtype,
+                requires_grad=False,
+                fmt=fmt,
+            )
         else:
             # bm3hk -> 3 x bmhk
             query, key, value = xformers.ops.unbind(c, 2)
@@ -447,72 +479,88 @@ def test_key_query_all_ones(device, q_len, kv_len, batch_size, k_len):
     assert_allclose(out, ref, atol=1e-5)
 
 
-@pytest.mark.parametrize(
-    "op_device_dtype_B_Mq_Mkv_H_K_Kv",
-    _opFW_device_dtype_B_Mq_Mkv_H_K_Kv,
-    ids=_opFW_device_dtype_B_Mq_Mkv_H_K_Kv_ids,
-)
-def test_logsumexp(op_device_dtype_B_Mq_Mkv_H_K_Kv):
+def _block_diag_reshape_lse(
+    lse: torch.Tensor, q_seqinfo: fmha.attn_bias._SeqLenInfo
+) -> torch.Tensor:
+    """LSE can be padded, let's remove the padding"""
+    parts = []
+    for slice, start, end in zip(
+        lse.unbind(0), q_seqinfo.cu_seqlen_py, q_seqinfo.cu_seqlen_py[1:]
+    ):
+        parts.append(slice[:, : end - start])
+    return torch.cat(parts, dim=1).unsqueeze(1)
+
+
+@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
+def test_logsumexp(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv):
     (
         op,
         device,
         dtype,
+        bias_type,
         batch_size,
         q_len,
         kv_len,
         h,
         k,
         kv,
-    ) = op_device_dtype_B_Mq_Mkv_H_K_Kv
+    ) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
     query, key, value, attn_bias = create_tensors(
-        *op_device_dtype_B_Mq_Mkv_H_K_Kv, fmt="BMK"
+        *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt="BMK"
     )
 
     _out, lse = xformers.ops.memory_efficient_attention_forward_requires_grad(
-        query, key, value, op=op
+        query,
+        key,
+        value,
+        op=op,
+        attn_bias=attn_bias,
     )
-    ref_lse = ((query.float() / k**0.5) @ key.float().transpose(-2, -1)).logsumexp(-1)
-
+    attn = (query.float() / k**0.5) @ key.float().transpose(-2, -1)
+    if attn_bias is not None:
+        if isinstance(attn_bias, xformers.ops.AttentionBias):
+            tensor_bias = attn_bias.materialize(
+                (query.shape[0], 1, query.shape[1], key.shape[1]),
+                device=query.device,
+                dtype=torch.float32,
+            )
+        else:
+            assert isinstance(attn_bias, torch.Tensor)
+            tensor_bias = attn_bias
+        if tensor_bias.ndim == 4:
+            tensor_bias = tensor_bias.reshape([-1, *tensor_bias.shape[2:]])
+        attn = attn + tensor_bias.float()
+    ref_lse = attn.logsumexp(-1)
+    if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalMask):
+        lse = _block_diag_reshape_lse(lse, attn_bias.q_seqinfo)
     assert_allclose(lse[:, 0, : ref_lse.shape[1]], ref_lse, atol=2e-4)
 
 
 @pytest.mark.parametrize("fmt", ["BMK", "BMHK"])
-@pytest.mark.parametrize(
-    "attn_bias_cfg",  # (type(bias), bias.requires_grad)
-    [
-        (None, False),
-        (xformers.ops.LowerTriangularMask, False),
-        (torch.Tensor, True),
-        (torch.Tensor, False),
-    ],
-)
 @pytest.mark.parametrize("grad_out_contiguous", [False, True])
-@pytest.mark.parametrize(
-    "op_device_dtype_B_Mq_Mkv_H_K_Kv",
-    _opBW_device_dtype_B_Mq_Mkv_H_K_Kv,
-    ids=_opBW_device_dtype_B_Mq_Mkv_H_K_Kv_ids,
-)
+@parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
 def test_backward(
-    op_device_dtype_B_Mq_Mkv_H_K_Kv,
+    opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
     grad_out_contiguous,
-    attn_bias_cfg,
     fmt,
 ):
-    attn_bias_type, attn_bias_requires_grad = attn_bias_cfg
     (
         op_bw,
         device,
         dtype,
+        bias_type,
         batch_size,
         q_len,
         kv_len,
         h,
         k,
         kv,
-    ) = op_device_dtype_B_Mq_Mkv_H_K_Kv
+    ) = opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
+    attn_bias_requires_grad = (
+        random.Random(q_len + kv_len * batch_size).randint(0, 1) > 0
+    )
     query, key, value, attn_bias = create_tensors(
-        *op_device_dtype_B_Mq_Mkv_H_K_Kv,
-        attn_bias_type=attn_bias_type,
+        *opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
         attn_bias_requires_grad=attn_bias_requires_grad,
         fmt=fmt,
     )
@@ -555,7 +603,6 @@ def test_backward(
     ].expand_as(out)
 
     out.backward(grad_out)
-    del out
 
     if qkv is None and op_bw == fmha.cutlass.BwOp:
         assert query.stride() == query.grad.stride()
@@ -576,6 +623,16 @@ def test_backward(
 
     ref = ref_attention(query, key, value, attn_bias)
     ref.backward(grad_out)
+
+    assert_allclose(
+        out.float(),
+        ref.float(),
+        "fw pass",
+        atol=op_fw.ERROR_ATOL[dtype],
+        rtol=op_fw.ERROR_RTOL.get(dtype, 1e-5),
+    )
+
+    del out
     del grad_out
     del ref
 
@@ -849,35 +906,57 @@ def test_memory_efficient_attention_full_block_masked(
     assert_allclose(grad_v, value.grad, "grad_v", atol=atol)
 
 
-@pytest.mark.parametrize(
-    "op_device_dtype_B_Mq_Mkv_H_K_Kv",
-    _opFW_device_dtype_B_Mq_Mkv_H_K_Kv__xs,
-    ids=_opFW_device_dtype_B_Mq_Mkv_H_K_Kv__xs_ids,
-)
+@pytest.mark.parametrize("fmt", ["BMK", "BMHK"])
+@parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
+def test_lowlevel_api_shapes(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt):
+    query, key, value, attn_bias = create_tensors(
+        *opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt=fmt
+    )
+    grad_out = torch.ones_like(query)
+    query.requires_grad_(True)
+    key.requires_grad_(True)
+    value.requires_grad_(True)
+
+    out, lse = xformers.ops.memory_efficient_attention_forward_requires_grad(
attn_bias + ) + assert out.ndim == query.ndim + dq, dk, dv = xformers.ops.memory_efficient_attention_backward( + grad_out, out, lse, query, key, value, attn_bias + ) + assert dq.shape == query.shape + assert dk.shape == key.shape + assert dv.shape == value.shape + + +@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs def test_cuda_streams( - op_device_dtype_B_Mq_Mkv_H_K_Kv, + opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, ): ( op, device, dtype, + bias_type, batch_size, q_len, kv_len, h, k, kv, - ) = op_device_dtype_B_Mq_Mkv_H_K_Kv + ) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv if device != "cuda": pytest.skip("Not CUDA") # Needs to be big enough so kernels take some time # as we are trying to do a race-condition here q_len = 1024 kv_len = 1024 - op_device_dtype_B_Mq_Mkv_H_K_Kv = [ + bias_type = None + opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = [ op, device, dtype, + bias_type, batch_size, q_len, kv_len, @@ -889,7 +968,7 @@ def test_cuda_streams( s_lopri = torch.cuda.Stream(priority=0) with torch.cuda.stream(s_lopri): query, key, value, attn_bias = create_tensors( - *op_device_dtype_B_Mq_Mkv_H_K_Kv, attn_bias_type=None, fmt="BMHK" + *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt="BMHK" ) # Queue a lot of kernels for i in range(20): @@ -913,32 +992,29 @@ def test_cuda_streams( ) -@pytest.mark.parametrize( - "op_device_dtype_B_Mq_Mkv_H_K_Kv", - argvalues=_opBW_device_dtype_B_Mq_Mkv_H_K_Kv__xs, - ids=_opBW_device_dtype_B_Mq_Mkv_H_K_Kv__xs_ids, -) -def test_custom_scale(op_device_dtype_B_Mq_Mkv_H_K_Kv): - torch.manual_seed(42) +@parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs +def test_custom_scale(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv): p = 0.0 - scale = 1 + scale = 1.0 ( op_bw, device, dtype, _, + _, q_len, kv_len, _, k, _, - ) = op_device_dtype_B_Mq_Mkv_H_K_Kv + ) = opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv + torch.manual_seed(q_len + kv_len + k) if device != "cuda": pytest.skip("Not CUDA") query, key, value, attn_bias = create_tensors( - *op_device_dtype_B_Mq_Mkv_H_K_Kv, attn_bias_type=None, fmt="BMK" + *opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt="BMK" ) inputs = fmha.Inputs( query=query, key=key, value=value, attn_bias=attn_bias, scale=scale @@ -949,30 +1025,36 @@ def test_custom_scale(op_device_dtype_B_Mq_Mkv_H_K_Kv): key.requires_grad_(True) value.requires_grad_(True) - if not op_fw.supports(inputs): - pytest.skip(f"{op_fw.NAME}: unsupported ({inputs})") - if not op_bw.supports(inputs): - pytest.skip(f"{op_bw.NAME}: unsupported ({inputs})") + reasons = op_fw.not_supported_reasons(inputs) + if reasons: + pytest.skip(f"{op_fw.NAME}: unsupported ({'/'.join(reasons)})") + reasons = op_bw.not_supported_reasons(inputs) + if reasons: + pytest.skip(f"{op_bw.NAME}: unsupported ({'/'.join(reasons)})") + # NOTE: we still need to scale the inputs to not blowup + # the pre-softmax values (numerical stability) + s = k**-0.5 out = xformers.ops.memory_efficient_attention( - query, key, value, attn_bias, p, scale, op=(op_fw, op_bw) + query * s, key, value, attn_bias, p, scale, op=(op_fw, op_bw) ) out.backward(grad_out) grad_q, grad_k, grad_v = query.grad, key.grad, value.grad query.grad = key.grad = value.grad = None - ref = ref_attention(query, key, value, attn_bias, None, p, scale) + ref = ref_attention(query * s, key, value, attn_bias, None, p, scale) ref.backward(grad_out) ref_grad_q, ref_grad_k, ref_grad_v = query.grad, key.grad, value.grad query.grad = key.grad = value.grad = None atol = op_fw.ERROR_ATOL[dtype] - assert_allclose(out.float(), ref.float(), atol=atol) + rtol = 
+    rtol = op_fw.ERROR_RTOL[dtype]
+    assert_allclose(out.float(), ref.float(), "out", atol=atol, rtol=rtol)
     atol = op_bw.ERROR_ATOL[dtype]
     rtol = op_bw.ERROR_RTOL[dtype]
-    assert_allclose(grad_q, ref_grad_q, atol=atol, rtol=rtol)
-    assert_allclose(grad_k, ref_grad_k, atol=atol, rtol=rtol)
-    assert_allclose(grad_v, ref_grad_v, atol=atol, rtol=rtol)
+    assert_allclose(grad_q, ref_grad_q, "grad_q", atol=atol, rtol=rtol)
+    assert_allclose(grad_k, ref_grad_k, "grad_k", atol=atol, rtol=rtol)
+    assert_allclose(grad_v, ref_grad_v, "grad_v", atol=atol, rtol=rtol)
 
 
 def apply_attention(query, key, value, attn_bias, op_fw, proj):
@@ -984,13 +1066,9 @@ def apply_attention(query, key, value, attn_bias, op_fw, proj):
 
 
 @pytest.mark.parametrize("use_reentrant", [False, True])
-@pytest.mark.parametrize(
-    "op_device_dtype_B_Mq_Mkv_H_K_Kv",
-    _opFW_device_dtype_B_Mq_Mkv_H_K_Kv__xs,
-    ids=_opFW_device_dtype_B_Mq_Mkv_H_K_Kv__xs_ids,
-)
+@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
 def test_grad_checkpointing(
-    op_device_dtype_B_Mq_Mkv_H_K_Kv,
+    opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
     use_reentrant,
 ):
     fmt = "BMHK"
@@ -998,16 +1076,29 @@ def test_grad_checkpointing(
         op,
         device,
         dtype,
+        bias_type,
+        batch_size,
+        q_len,
+        kv_len,
+        h,
+        k,
+        kv,
+    ) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
+    bias_type = None
+    opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = (
+        op,
+        device,
+        dtype,
+        bias_type,
         batch_size,
         q_len,
         kv_len,
         h,
         k,
         kv,
-    ) = op_device_dtype_B_Mq_Mkv_H_K_Kv
+    )
     query, key, value, attn_bias = create_tensors(
-        *op_device_dtype_B_Mq_Mkv_H_K_Kv,
-        attn_bias_type=None,
+        *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
        fmt=fmt,
     )
     qkv = None