From 6ff6501aa229871a21bbe9ed04673320126dcf34 Mon Sep 17 00:00:00 2001 From: kaixih Date: Thu, 1 Aug 2024 19:39:34 +0000 Subject: [PATCH 1/2] Init commit --- jax/_src/cudnn/fused_attention_stablehlo.py | 24 ++++++----- jax/_src/nn/functions.py | 46 ++++++++++++--------- tests/nn_test.py | 22 ++++++---- 3 files changed, 54 insertions(+), 38 deletions(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 262b8e2c140a..51a86fdcb978 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -618,11 +618,12 @@ def _dot_product_attention_fwd_batcher( *_, S, _, _ = key.shape B = math.prod(Bs) has_bias, _ = variadic_args + original_shape = query.shape # reshape to 4D shape query = jnp.reshape(query, (B,) + query.shape[-3:]) key = jnp.reshape(key, (B,) + key.shape[-3:]) value = jnp.reshape(value, (B,) + key.shape[-3:]) - if has_bias: + if has_bias and batch_dims[3] is not None: bias = jnp.reshape(bias, (B, N, T, S)) if has_padding(mask_type): q_seqlen = jnp.reshape(q_seqlen, (B, )) @@ -635,7 +636,7 @@ def _dot_product_attention_fwd_batcher( # reshape to original shape output = outputs[0] - output = jnp.reshape(output, query.shape) + output = jnp.reshape(output, original_shape) if is_training: activation = outputs[1] activation = jnp.reshape(activation, (*Bs, N, T)) @@ -660,11 +661,15 @@ def _dot_product_attention_bwd_batcher( *_, S, _, _ = key.shape B = math.prod(Bs) has_bias, has_dbias = variadic_args + original_query_shape = query.shape + original_key_shape = key.shape + original_value_shape = value.shape + original_bias_shape = bias.shape if has_bias else None # reshape to 4D shape query = jnp.reshape(query, (B,) + query.shape[-3:]) key = jnp.reshape(key, (B,) + key.shape[-3:]) value = jnp.reshape(value, (B,) + key.shape[-3:]) - if has_bias: + if has_bias and batch_dims[3] is not None: bias = jnp.reshape(bias, (B, N, T, S)) if has_padding(mask_type): q_seqlen = jnp.reshape(q_seqlen, (B, )) @@ -681,15 +686,14 @@ def _dot_product_attention_bwd_batcher( mask_type=mask_type, layout=layout, ) - grad_query, grad_key, grad_value = grads[:3] # reshape to original shape - grad_query = jnp.reshape(grad_query, query.shape) - grad_key = jnp.reshape(grad_key, key.shape) - grad_value = jnp.reshape(grad_value, value.shape) + grads[0] = jnp.reshape(grads[0], original_query_shape) + grads[1] = jnp.reshape(grads[1], original_key_shape) + grads[2] = jnp.reshape(grads[2], original_value_shape) if has_dbias: - grad_bias = grads[3] - grad_bias = jnp.reshape(grad_bias, bias.shape) - return grads + (grad_bias,), out_bdims + (query_bdim,) + assert has_bias + grads[3] = jnp.reshape(grads[3], original_bias_shape) + out_bdims += (batch_dims[3],) return grads, out_bdims # custom partitioning diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 32d543a27966..5d7c941615ea 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -853,9 +853,9 @@ def dot_product_attention( query: ArrayLike, key: ArrayLike, value: ArrayLike, - *, bias: ArrayLike | None = None, mask: ArrayLike | None = None, + *, scale: float | None = None, is_causal: bool = False, implementation: Literal['xla', 'cudnn'] | None = None) -> Array: @@ -882,20 +882,20 @@ def dot_product_attention( G = number of groups, which equals to N // K Args: - query: query array; shape :code:`(BTNH)` - key: key array: shape :code:`(BSKH)`. When `K` equals `N`, multi-headed - attention (MHA: https://arxiv.org/abs/1706.03762) is performed. 
Otherwise, - grouped query attention (GQA: https://arxiv.org/abs/2305.13245) is performed - if `N` is a multiple of `K`, and multi-query attention (MQA: - https://arxiv.org/abs/1911.02150) is performed if `K == 1` (a special case - of GQA). + query: query array; shape :code:`(BTNH|TNH)` + key: key array: shape :code:`(BSKH|SKH)`. When `K` equals `N`, multi-headed + attention (MHA https://arxiv.org/abs/1706.03762) is performed. Otherwise, + grouped query attention (GQA https://arxiv.org/abs/2305.13245) is + performed if `N` is a multiple of `K`, and multi-query attention (MQA + https://arxiv.org/abs/1911.02150) is performed if `K == 1` (a special case + of GQA). value: value array, should have the same shape as the `key` array. bias: optional, bias array to be added to logits; The shape must be 4D and - be broadcastable to :code:`(BNTS)`. + be broadcastable to :code:`(BNTS|NTS)`. mask: optional, mask array used to filter out logits. It is a boolean mask where `True` indicates the element should take part in attention. For an additive mask, users should pass it to `bias`. The shape must be 4D and be - broadcastable to :code:`(BNTS)`. + broadcastable to :code:`(BNTS|NTS)`. scale: scale for the logits. If None, the scale will be set to 1 divided by the square root of query's head dimension (i.e. H). is_causal: If true, causal attention will be applied. Note, some @@ -912,6 +912,18 @@ def dot_product_attention( Returns: An array of the attention output with the same shape as :code:`query`. """ + original_shape = jnp.asarray(query).shape + def _preprocess_array(t): + if t is None: + return t + t = jnp.asarray(t) + return t[None, ...] if t.ndim == 3 else t + query = _preprocess_array(query) + key = _preprocess_array(key) + value = _preprocess_array(value) + bias = _preprocess_array(bias) + mask = _preprocess_array(mask) + def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: if t.ndim != len(shape): raise ValueError(f"{name} ndim should be {len(shape)}, but got {t.ndim}") @@ -919,12 +931,6 @@ def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: if shape[i] != -1 and t.shape[i] != shape[i]: raise ValueError(f"{name} shape should be {shape}: but got {t.shape}") - query = jnp.asarray(query) - key = jnp.asarray(key) - value = jnp.asarray(value) - bias = bias if bias is None else jnp.asarray(bias) - mask = mask if mask is None else jnp.asarray(mask) - B, S, K, H = key.shape _check_has_shape(value, [B, S, K, H], 'value') _check_has_shape(query, [B, -1, -1, H], 'query') @@ -944,19 +950,21 @@ def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: match implementation: case 'xla': - return _dot_product_attention_xla( + out = _dot_product_attention_xla( query, key, value, bias, mask, is_causal=is_causal, scale=scale_val, ) case 'cudnn': mask_type = MaskType.CAUSAL if is_causal else MaskType.NO_MASK - return cudnn_dot_product_attention( + out = cudnn_dot_product_attention( query, key, value, bias, mask, scale=scale_val, mask_type=mask_type ) case None: # TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select # best backend. 
- return _dot_product_attention_xla( + out = _dot_product_attention_xla( query, key, value, bias, mask, is_causal=is_causal, scale=scale_val, ) case _: raise ValueError(f"Unsupported implementation option: {implementation}") + + return jnp.reshape(out, original_shape) diff --git a/tests/nn_test.py b/tests/nn_test.py index 455f04e5fd12..802ed1b2f1e2 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -57,10 +57,11 @@ class NNFunctionsTest(jtu.JaxTestCase): use_bias=[False, True], causal_mode=[None, 'is_causal', 'is_mask'], group_num=[1, 2, 4], + use_vmap=[False, True], impl=['xla', 'cudnn'], ) def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, - group_num, impl): + group_num, use_vmap, impl): if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") if impl == 'cudnn' and dtype == jnp.float32: @@ -84,15 +85,17 @@ def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl) if impl == 'cudnn': - lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias=bias, mask=causal_mask) + lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias, causal_mask) hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo')) self.assertIn('__cudnn$fmha', hlo) + if use_vmap: + sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V - out_ref = sdpa_ref(Q, K_ref, V_ref, bias=bias, mask=causal_mask) + out_ref = sdpa_ref(Q, K_ref, V_ref, bias, causal_mask) - out_ans = sdpa_ans(Q, K, V, bias=bias, mask=causal_mask) + out_ans = sdpa_ans(Q, K, V, bias, causal_mask) self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) @parameterized.product( @@ -100,10 +103,11 @@ def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, use_bias=[False, True], causal_mode=[None, 'is_causal', 'is_mask'], group_num=[1, 2, 4], + use_vmap=[False, True], impl=['xla', 'cudnn'], ) def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode, - group_num, impl): + group_num, use_vmap, impl): if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") if impl == 'cudnn' and dtype == jnp.float32: @@ -127,16 +131,16 @@ def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode, K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None) - fn_ref = lambda q, k, v, b, m: sdpa_ref(q, k, v, bias=b, mask=m) - _, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K_ref, V_ref, bias, causal_mask) + _, sdpa_vjp_ref = jax.vjp(sdpa_ref, Q, K_ref, V_ref, bias, causal_mask) dQ_ref, dK_ref, dV_ref, dbias_ref, _ = sdpa_vjp_ref(grad) if G != 1: dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3) dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3) sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl) - fn_ans = lambda q, k, v, b, m: sdpa_ans(q, k, v, bias=b, mask=m) - _, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask) + if use_vmap: + sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) + _, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, causal_mask) dQ_ans, dK_ans, dV_ans, dbias_ans, _ = sdpa_vjp_ans(grad) if impl == 'cudnn': From 9f9e3e6d4e7a2955bca3b8d98ecd3c863700179b Mon Sep 17 00:00:00 2001 From: kaixih Date: Fri, 2 Aug 2024 19:55:28 
+0000 Subject: [PATCH 2/2] Address comments --- jax/_src/nn/functions.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 5d7c941615ea..4aabf9521340 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -912,17 +912,19 @@ def dot_product_attention( Returns: An array of the attention output with the same shape as :code:`query`. """ - original_shape = jnp.asarray(query).shape - def _preprocess_array(t): - if t is None: - return t + output_shape = jnp.asarray(query).shape + def _ensure_4d(t): t = jnp.asarray(t) - return t[None, ...] if t.ndim == 3 else t - query = _preprocess_array(query) - key = _preprocess_array(key) - value = _preprocess_array(value) - bias = _preprocess_array(bias) - mask = _preprocess_array(mask) + dims_to_add = 4 - t.ndim + if dims_to_add > 0: + return jnp.expand_dims(t, axis=tuple(range(dims_to_add))) + return t + + query = _ensure_4d(query) + key = _ensure_4d(key) + value = _ensure_4d(value) + bias = _ensure_4d(bias) if bias is not None else None + mask = _ensure_4d(mask) if mask is not None else None def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: if t.ndim != len(shape): @@ -967,4 +969,4 @@ def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: case _: raise ValueError(f"Unsupported implementation option: {implementation}") - return jnp.reshape(out, original_shape) + return jnp.reshape(out, output_shape)
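
Notes (not part of the patches): taken together, the two patches let jax.nn.dot_product_attention accept unbatched rank-3 inputs (query (T, N, H), key/value (S, K, H), bias/mask broadcastable to (N, T, S)) by prepending the missing batch dimension internally and reshaping the output back to the query's original shape. Because bias and mask move in front of the keyword-only marker, they become positional-or-keyword, which is what allows the updated tests to wrap the function in jax.vmap with in_axes=(0, 0, 0, None, None). Below is a minimal usage sketch under the assumption of a JAX build that includes these changes; the shapes, dtypes, and variable names are illustrative and do not come from the patches.

    # Sketch only: assumes jax.nn.dot_product_attention with the signature
    # introduced by these patches (rank-3 inputs, positional bias/mask).
    import jax
    import jax.numpy as jnp
    from jax import nn

    B, T, S, N, H = 2, 8, 8, 4, 16  # batch, q/kv lengths, heads, head dim (illustrative)
    rng = jax.random.PRNGKey(0)
    kq, kk, kv = jax.random.split(rng, 3)
    Q = jax.random.normal(kq, (B, T, N, H), dtype=jnp.bfloat16)
    K = jax.random.normal(kk, (B, S, N, H), dtype=jnp.bfloat16)
    V = jax.random.normal(kv, (B, S, N, H), dtype=jnp.bfloat16)

    # 1) Batched call, unchanged behavior.
    out_batched = nn.dot_product_attention(Q, K, V)

    # 2) Unbatched (rank-3) call: the patch adds the batch dimension
    #    internally and reshapes the output back to (T, N, H).
    out_single = nn.dot_product_attention(Q[0], K[0], V[0])
    assert out_single.shape == (T, N, H)

    # 3) vmap over the batch axis; bias and mask (here None) are passed
    #    positionally with in_axes=None, mirroring the updated tests.
    vmapped = jax.vmap(nn.dot_product_attention,
                       in_axes=(0, 0, 0, None, None), out_axes=0)
    out_vmapped = vmapped(Q, K, V, None, None)
    assert jnp.allclose(out_batched, out_vmapped, atol=1e-2, rtol=1e-2)

Keeping bias and mask positional-or-keyword preserves existing keyword-style call sites while making the function directly mappable with per-argument in_axes, as exercised by the new use_vmap test cases.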