[Pallas] Support Flash Attention backward kernels #6870
Conversation
```python
xm.mark_step()

# TODO: I don't really know how to test the value. Let's do the shape check for now.
self.assertEqual(grad_q.shape, (3, 2, 128, 4))
```
If we run the forward pass and call res.backward(), then check the grad on q, shouldn't it match the grad from regular attention?
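A minimal sketch of that kind of check, assuming `flash_attention` in `torch_xla.experimental.custom_kernel` is the entry point this PR exposes, that it defaults to `sm_scale=1.0` (so the reference omits the `1/sqrt(d)` scale), and that loose tolerances are acceptable:

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention  # assumed entry point

def reference_attention(q, k, v):
    # Plain, un-tiled attention; assumes the kernel's default sm_scale of 1.0.
    attn = torch.softmax(q @ k.transpose(-2, -1), dim=-1)
    return attn @ v

device = xm.xla_device()
q = torch.randn(3, 2, 128, 4, device=device, requires_grad=True)
k = torch.randn(3, 2, 128, 4, device=device, requires_grad=True)
v = torch.randn(3, 2, 128, 4, device=device, requires_grad=True)
q_ref, k_ref, v_ref = (t.detach().clone().requires_grad_(True) for t in (q, k, v))

out = flash_attention(q, k, v)
out.sum().backward()
xm.mark_step()

ref = reference_attention(q_ref, k_ref, v_ref)
ref.sum().backward()
xm.mark_step()

# Loose tolerances: the tiled softmax accumulates rounding differently.
torch.testing.assert_close(q.grad.cpu(), q_ref.grad.cpu(), atol=1e-1, rtol=1e-1)
```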
The softmax is computed differently, so I don't think there are any guarantees.
The result should still be somewhat close, right? We can relax the tolerance. If the result returned by this kernel is dramatically different from the one computed with regular dot-product attention, that seems wrong.
https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html
Softmax needs all the elements of a row to produce its result, but flash attention chunks the data into blocks and uses a technique called tiling, rescaling the partial sums so the softmax stays numerically stable. Since there is no single full-row aggregation, I don't see how the tiled softmax could produce exactly the same results as the regular one.
In JAX, I had to use atol=1e-01, rtol=1e-01 to do the comparisons...
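For intuition, here is a small, self-contained sketch of the online (tiled) softmax rescaling. It is not the Pallas kernel code, just an illustration that a blockwise softmax matches the regular one only up to floating-point rounding:

```python
import torch

def tiled_softmax(x, block_size):
    # Online softmax along the last dimension: process block_size columns at a
    # time, keeping a running max `m` and running normalizer `d`, as flash
    # attention's tiling does.
    m = torch.full(x.shape[:-1], float("-inf"))
    d = torch.zeros(x.shape[:-1])
    for start in range(0, x.shape[-1], block_size):
        blk = x[..., start:start + block_size]
        m_new = torch.maximum(m, blk.max(dim=-1).values)
        # Rescale the old normalizer to the new max before adding this block.
        d = d * torch.exp(m - m_new) + torch.exp(blk - m_new.unsqueeze(-1)).sum(dim=-1)
        m = m_new
    return torch.exp(x - m.unsqueeze(-1)) / d.unsqueeze(-1)

x = torch.randn(4, 128)
# Mathematically identical, but the blockwise rescaling changes the rounding,
# so only approximate equality holds.
print(torch.allclose(tiled_softmax(x, 32), torch.softmax(x, dim=-1), atol=1e-6))
```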
Thanks, Jack!
Summary:
This change refactors custom_kernel.py to support all three new Pallas kernels involved in the Flash Attention backward computation (a rough sketch of the wiring pattern is included after the test plan).
The refactoring includes:
Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py
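As a hypothetical sketch of how forward and backward Pallas kernels are typically stitched together behind a `torch.autograd.Function` (the names `_flash_attention_fwd` and `_flash_attention_bwd` are placeholders, not the actual entry points in custom_kernel.py):

```python
import torch

class FlashAttention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v):
        # The forward Pallas kernel also returns the per-row softmax statistics
        # (running max and normalizer) that the backward kernels need.
        o, l, m = _flash_attention_fwd(q, k, v)  # placeholder call
        ctx.save_for_backward(q, k, v, o, l, m)
        return o

    @staticmethod
    def backward(ctx, grad_o):
        q, k, v, o, l, m = ctx.saved_tensors
        # The backward pass re-tiles the computation; separate Pallas kernels
        # produce dq and dk/dv instead of replaying a regular softmax.
        grad_q, grad_k, grad_v = _flash_attention_bwd(q, k, v, o, l, m, grad_o)  # placeholder
        return grad_q, grad_k, grad_v
```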