[Pallas] Support segment ids in flash attention #6943
Conversation
Commits b6a8ed8 to b9cfc67
Is this ready for review?
if not save_residuals:
  o = o[0]
what's this for?
_xla_tpu_custom_call always returns an array.
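For context, a minimal sketch of that pattern with a hypothetical stand-in for the custom call (names invented for illustration, not the PR's actual code):

```python
import torch

def _fake_tpu_custom_call(q, k, v, save_residuals):
  # Stand-in for _xla_tpu_custom_call: assume it always hands back a list
  # of output tensors, even when there is only one logical output.
  o = torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ v
  lse = torch.logsumexp(q @ k.transpose(-2, -1), dim=-1)
  return [o, lse] if save_residuals else [o]

q = k = v = torch.randn(2, 4, 8)
save_residuals = False
o = _fake_tpu_custom_call(q, k, v, save_residuals)
if not save_residuals:
  o = o[0]  # unwrap the single output, mirroring `o = o[0]` above
```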
I still need to add SPMD and Dynamo support, so not yet.
@JackCaoG Do you think we can do the SPMD and Dynamo parts later, since the customer is not using either of them right now?
Yeah, don't worry about SPMD and Dynamo for this PR; let's do that in a separate PR.
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                 "This test only works on TPUv3+.")
def test_flash_attention_wrapper_segment_ids_2(self):
So you have two tests, one comparing to native torch and one comparing to JAX?
Yeah, the JAX test was written before I figured out how to do the non-kernel mask.
torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
q_segment_ids = torch.zeros(4, 128).to("xla")
kv_segment_ids = torch.zeros(4, 128).to("xla")
q.retain_grad()
k.retain_grad()
v.retain_grad()
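Presumably the test then calls the wrapper with these segment ids and checks the result against a reference computed with the non-kernel mask mentioned above. A rough sketch continuing from the setup, assuming the wrapper is `flash_attention` from `torch_xla.experimental.custom_kernel` and that it accepts `q_segment_ids`/`kv_segment_ids` keyword arguments as the PR description suggests (softmax scaling is ignored for simplicity):

```python
from torch_xla.experimental.custom_kernel import flash_attention

# Kernel path (keyword names are assumptions based on the PR description).
o = flash_attention(q, k, v, q_segment_ids=q_segment_ids, kv_segment_ids=kv_segment_ids)

# Non-kernel reference: a query token may only attend to keys in the same
# segment, so the boolean mask is block diagonal for contiguous segments.
mask = q_segment_ids[:, None, :, None] == kv_segment_ids[:, None, None, :]
logits = (q @ k.transpose(-2, -1)).masked_fill(~mask, float('-inf'))
expected = torch.softmax(logits, dim=-1) @ v
torch.testing.assert_close(o.cpu(), expected.cpu(), atol=1e-2, rtol=1e-2)
```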
Can we refactor this part out into a helper function in this test?
Actually, just refactor this part out and use it in all the tests; it is the same for all of them.
You mean the tensor initializations? That's kind of expected boilerplate. I don't think it's necessary to improve it...
I will leave that to you. When I see two large chunks of code that look similar, I usually try to work out how they differ. It confused me a bit when I realized it was the same code repeated over and over.
Yeah, for testing it's sometimes hard to avoid... haha
@@ -357,18 +418,22 @@ def backward(ctx, grad_output):
    grad_v = xs.disable_manual_sharding(
        grad_v, partition_spec, full_shape, mesh=mesh).global_tensor

-   return grad_q, grad_k, grad_v, None, None, None
+   return grad_q, grad_k, grad_v, None, None, None, None, None
why do we need to return these Nones?
It's a rule of autograd.Function: every input passed to forward needs a corresponding gradient returned from backward. For inputs that we don't differentiate with respect to, we return None.
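For readers unfamiliar with that rule, a minimal self-contained example (unrelated to this PR's code) of an autograd.Function with a non-differentiable input:

```python
import torch

class ScaleByConstant(torch.autograd.Function):

  @staticmethod
  def forward(ctx, x, scale):
    # `x` is differentiable; the plain Python float `scale` is not.
    ctx.scale = scale
    return x * scale

  @staticmethod
  def backward(ctx, grad_output):
    # backward must return one value per forward input, in order;
    # None marks the gradient of the non-differentiable `scale`.
    return grad_output * ctx.scale, None

x = torch.randn(3, requires_grad=True)
ScaleByConstant.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```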
Thanks, Jack.
Summary:
This PR adds segment ids to the flash attention wrapper. Segment ids define an attention mask in which each token can only attend to other tokens within the same segment; the mask is therefore a block-diagonal matrix.
To support this, we further split the flash attention forward pass into a tracing part and an execution part, and implement all the shape operations needed to make the inputs compatible with the kernel.
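Concretely, for batch index $b$, query position $i$, and key position $j$, the mask described above is

$$M_{b,i,j} = \mathbb{1}\big[\texttt{q\_segment\_ids}_{b,i} = \texttt{kv\_segment\_ids}_{b,j}\big],$$

and when segment ids are contiguous along the sequence dimension, $M$ is block diagonal.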
Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py