
[Pallas] Support Flash Attention backward kernels #6870

Merged: 4 commits into master on Apr 2, 2024

Conversation

alanwaketan (Collaborator):

Summary:
This change refactors custom_kernel.py to support all three new Pallas kernels involved in the Flash Attention backward calculation.

The refactoring includes:

  1. Adds support for static_argnums, which lets JAX tracing ignore certain positional arguments (see the sketch after this list).
  2. Separates the JAX tracing logic out so that the tracing can be done on its own.
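
For background, here is a minimal, purely illustrative sketch of what static_argnums does for JAX tracing (the attention_like function and its causal flag are invented for this example and are not code from this PR): positional arguments marked as static are baked into the trace as compile-time constants instead of being traced as tensor inputs.

import jax
import jax.numpy as jnp

# Illustrative only: `causal` selects a code path, so it should not be traced
# as a tensor input.
def attention_like(q, k, v, causal):
    scores = jnp.einsum("...qd,...kd->...qk", q, k)
    if causal:
        mask = jnp.tril(jnp.ones(scores.shape[-2:], dtype=bool))
        scores = jnp.where(mask, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ v

q = k = v = jnp.ones((3, 2, 128, 4), dtype=jnp.float32)

# static_argnums=3 treats `causal` as a compile-time constant, so the lowered
# computation only has q, k, v as runtime inputs.
lowered = jax.jit(attention_like, static_argnums=3).lower(q, k, v, True)
print(lowered.as_text()[:200])

This mirrors item 1 above: arguments that only configure the kernel are excluded from tracing instead of being passed in as tensors.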

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py

alanwaketan requested review from lsy323 and JackCaoG on Apr 2, 2024 at 05:50
alanwaketan self-assigned this on Apr 2, 2024
xm.mark_step()

# TODO: I don't really know how to test the value. Let's do the shape check for now.
self.assertEqual(grad_q.shape, (3, 2, 128, 4))
Collaborator:

If we do the fwd and then res.backward(), and check the grad on q, they should match?

Collaborator Author:

The softmax is done differently. I don't think there are any guarantees.

Collaborator:

The result should still be somewhat close, right? We can tune down the precision. If the result returned by this is dramatically different from the one computed using dot attention, that seems wrong.

Collaborator Author:

https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html

Softmax needs all of the elements to produce its result, but flash attention chunks the data into blocks and uses a technique called tiling to keep the softmax numerically stable. Since there is no full aggregation, I don't know how the tiled softmax could produce the same results as the regular one.

In JAX, I have to use atol=1e-01, rtol=1e-01 to do the comparisons...
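
For illustration of the comparison being discussed, here is a sketch (not the test that landed in this PR); it assumes the torch_xla.experimental.custom_kernel.flash_attention entry point is differentiable end-to-end and uses the relaxed tolerances mentioned above.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

def reference_attention(q, k, v):
    # Plain dot-product attention used as the numerical reference.
    scores = q @ k.transpose(-2, -1)
    return torch.softmax(scores, dim=-1) @ v

device = xm.xla_device()
q = torch.randn(3, 2, 128, 4, requires_grad=True, device=device)
k = torch.randn(3, 2, 128, 4, requires_grad=True, device=device)
v = torch.randn(3, 2, 128, 4, requires_grad=True, device=device)

# Reference path: ordinary attention, ordinary autograd.
reference_attention(q, k, v).sum().backward()
ref_grad_q = q.grad.detach().clone()
q.grad = None

# Pallas path: flash attention forward plus the new backward kernels
# (assumed here to be wired into autograd).
flash_attention(q, k, v).sum().backward()
xm.mark_step()

# Loose tolerances: the tiled softmax is not expected to be bit-identical
# to the reference implementation.
assert torch.allclose(q.grad.cpu(), ref_grad_q.cpu(), atol=1e-1, rtol=1e-1)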

@alanwaketan (Collaborator Author):

Thanks, Jack!

alanwaketan merged commit c54367c into master on Apr 2, 2024
18 checks passed