-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NVIDIA] Support vmap usage of jax.nn.dot_product_attention
#22830
Conversation
@@ -853,9 +853,9 @@ def dot_product_attention( | |||
query: ArrayLike, | |||
key: ArrayLike, | |||
value: ArrayLike, | |||
*, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems you can only specify the batch dims for the positional args in jax.vmap. For the keyward arguments, vmap will always use leading dim for the batch.
@@ -618,11 +618,12 @@ def _dot_product_attention_fwd_batcher( | |||
*_, S, _, _ = key.shape | |||
B = math.prod(Bs) | |||
has_bias, _ = variadic_args | |||
original_shape = query.shape |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about output_shape
since you only use it for output
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
jax/_src/nn/functions.py
Outdated
if t is None: | ||
return t | ||
t = jnp.asarray(t) | ||
return t[None, ...] if t.ndim == 3 else t |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it make sense to assert that t.ndim
is 4 if it's not 3?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
jax/_src/nn/functions.py
Outdated
@@ -912,19 +912,25 @@ def dot_product_attention( | |||
Returns: | |||
An array of the attention output with the same shape as :code:`query`. | |||
""" | |||
original_shape = jnp.asarray(query).shape | |||
def _preprocess_array(t): | |||
if t is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I would personally move this out, since only bias
and mask
can be None.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
jax/_src/nn/functions.py
Outdated
@@ -912,19 +912,25 @@ def dot_product_attention( | |||
Returns: | |||
An array of the attention output with the same shape as :code:`query`. | |||
""" | |||
original_shape = jnp.asarray(query).shape | |||
def _preprocess_array(t): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: how about _ensure_4d
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
jax/_src/nn/functions.py
Outdated
query, key, value, bias, mask, is_causal=is_causal, scale=scale_val, | ||
) | ||
case _: | ||
raise ValueError(f"Unsupported implementation option: {implementation}") | ||
|
||
return jnp.reshape(out, original_shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess you really just squeeze here, so you can do that instead of reshaping?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I use reshape instead of squeeze, because I want to make sure if users want to do the batch aware call and pass (1, T, N, H) inputs and then they can still get the results of the same shape rather than squeezed shape of (T, N, H).
fn_ans = lambda q, k, v, b, m: sdpa_ans(q, k, v, bias=b, mask=m) | ||
_, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask) | ||
if use_vmap: | ||
sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think your current implementation will fail if vmapped more than once, since it requires either a 3D or a 4D array.
Wdyt about handling the N>4 case by collapsing any extra leading dimensions into B
? @sbodenstein does this make sense?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this doesn't work? I have tried this to mimic the 5D tensor and it works fine. Or do I miss your point?
Q = random.normal(keys[0], (B, B, T, N, H), dtype)
K = random.normal(keys[1], (B, B, S, N // G, H), dtype)
V = random.normal(keys[2], (B, B, S, N // G, H), dtype)
if use_bias:
bias = random.normal(keys[3], (1, N, T, S), dtype)
else:
bias = None
is_causal = causal_mode == 'is_causal'
causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None
sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
if use_vmap:
sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)
K_ref = (jnp.repeat(K, G, axis=2) if G != 1 else K).reshape(B*B, S, N, H)
V_ref = (jnp.repeat(V, G, axis=2) if G != 1 else V).reshape(B*B, S, N, H)
Q_ref = Q.reshape(B*B, T, N, H)
out_ref = sdpa_ref(Q_ref, K_ref, V_ref, bias, causal_mask).reshape(B,B,T,N,H)
out_ans = sdpa_ans(Q, K, V, bias, causal_mask)
self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't we assert that ndim
is 3 or 4 now? I would expect this to fail given a 5D input.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it will fail if we directly pass in the 5D tensor. The behavior now is to support (1) 4D tensor for those who want to use the batch-aware API (2) 3D tensor for those wants to use the API in the context of vmap, meaning if users have 5D tensor, they need to use vmap as shown above.
Gentle ping @superbobry |
To address this request: #22760, this PR support no-batch inputs.