[Pallas] Support segment ids in flash attention #6943

Merged: 11 commits merged into master from alanwaketan/fa_segment_ids on May 1, 2024

Conversation

@alanwaketan (Collaborator) commented Apr 19, 2024

Summary:
This PR adds segment ids to the flash attention wrapper. Segment ids are a way to create an attention mask in which each token can only attend to other tokens within the same segment; the mask is therefore a block-diagonal matrix.

To support this, we further split the flash attention forward pass into tracing and execution parts, and implement all the shape operations needed to make the inputs compatible with the kernel.
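For illustration (this is not the kernel code in this PR), a minimal sketch of how segment ids induce that mask: a query position may only attend to key/value positions carrying the same segment id.

import torch

# Hypothetical segment ids of shape [batch, seq_len]; two segments per sequence.
q_segment_ids = torch.tensor([[0, 0, 1, 1]])
kv_segment_ids = torch.tensor([[0, 0, 1, 1]])

# A token attends to another token only when their segment ids match,
# which yields a block-diagonal attention mask of shape [batch, q_len, kv_len].
mask = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :]
print(mask.int())
# tensor([[[1, 1, 0, 0],
#          [1, 1, 0, 0],
#          [0, 0, 1, 1],
#          [0, 0, 1, 1]]], dtype=torch.int32)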

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py

@alanwaketan alanwaketan self-assigned this Apr 19, 2024
@alanwaketan alanwaketan force-pushed the alanwaketan/fa_segment_ids branch from b6a8ed8 to b9cfc67 Compare April 26, 2024 01:52
@JackCaoG (Collaborator): Is this ready for review?

if not save_residuals:
o = o[0]
@JackCaoG (Collaborator): What's this for?

@alanwaketan (Collaborator, Author): _xla_tpu_custom_call always returns an array.
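A minimal, self-contained illustration of that point (the stand-in function below is hypothetical, not the real binding): the custom call hands back a list of outputs even when the kernel produces only the attention output, so the caller unwraps it.

def tpu_custom_call_stub(save_residuals: bool):
    # Stand-in for the TPU custom call, which always returns a list of outputs.
    out = "attention_output"   # placeholder attention result
    residuals = ["l", "m"]     # placeholder softmax statistics
    return [out] + residuals if save_residuals else [out]

save_residuals = False
o = tpu_custom_call_stub(save_residuals)
if not save_residuals:
    o = o[0]  # still a one-element list, so unwrap the single output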

@alanwaketan (Collaborator, Author), replying to "Is this ready for review?": I still need to add SPMD and Dynamo support, so not yet.

@alanwaketan alanwaketan marked this pull request as ready for review April 30, 2024 22:46
@alanwaketan (Collaborator, Author): @JackCaoG Do you think we can do the SPMD and Dynamo parts later, since the customer is not using either of them now?

@JackCaoG (Collaborator): Yeah, don't worry about SPMD and Dynamo for this PR; let's do that in a separate PR.


@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_flash_attention_wrapper_segment_ids_2(self):
@JackCaoG (Collaborator): So you have two tests, one comparing against native torch and one against JAX?

@alanwaketan (Collaborator, Author): Yeah, the JAX test was written before I figured out how to build the non-kernel mask.
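For context, a sketch of what a non-kernel reference can look like in plain PyTorch (the function name and the 1/sqrt(d) scaling are illustrative, not the actual test code): build the segment mask, mask the scores before the softmax, and compare the result against the kernel output within a numerical tolerance.

import math
import torch
import torch.nn.functional as F

def reference_attention(q, k, v, q_segment_ids=None, kv_segment_ids=None):
    # q, k, v: [batch, heads, seq_len, head_dim]; segment ids: [batch, seq_len].
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if q_segment_ids is not None:
        # Broadcast to [batch, 1, q_len, kv_len] so every head shares the mask.
        mask = q_segment_ids[:, None, :, None] == kv_segment_ids[:, None, None, :]
        scores = scores.masked_fill(~mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ v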

Comment on lines +695 to +703
torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
q_segment_ids = torch.zeros(4, 128).to("xla")
kv_segment_ids = torch.zeros(4, 128).to("xla")
q.retain_grad()
k.retain_grad()
v.retain_grad()
@JackCaoG (Collaborator): Can we refactor this part out into a helper function in this test?

@JackCaoG (Collaborator): Actually, just refactor this part out and use it in all the tests; it is the same for all of them.

@alanwaketan (Collaborator, Author): You mean the tensor initializations? Those are kind of expected paperwork. I don't think it's necessary to improve...

@JackCaoG (Collaborator): I will leave that to you. When I see two large chunks of code that look similar, I usually try to find out how they differ. It confused me a bit when I realized it is the same code repeated over and over.

@alanwaketan (Collaborator, Author): Yeah, for testing it's sometimes hard to avoid... haha
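For reference, a sketch of the kind of helper being suggested (the name and return convention are hypothetical); it factors the repeated setup out of the individual tests:

import torch

def _make_attention_inputs(seed=42):
    # Repeated test setup from the tests above; requires torch_xla so the "xla" device exists.
    torch.manual_seed(seed)
    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    q_segment_ids = torch.zeros(4, 128).to("xla")
    kv_segment_ids = torch.zeros(4, 128).to("xla")
    for t in (q, k, v):
        t.retain_grad()
    return q, k, v, q_segment_ids, kv_segment_ids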

@@ -357,18 +418,22 @@ def backward(ctx, grad_output):
grad_v = xs.disable_manual_sharding(
    grad_v, partition_spec, full_shape, mesh=mesh).global_tensor

- return grad_q, grad_k, grad_v, None, None, None
+ return grad_q, grad_k, grad_v, None, None, None, None, None
@JackCaoG (Collaborator): Why do we need to return these Nones?

@alanwaketan (Collaborator, Author): It's a rule of autograd.Function: every input passed to forward needs a corresponding gradient returned from backward. For inputs that we don't differentiate with respect to, we return None.
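A minimal, self-contained illustration of that rule (unrelated to the kernel itself): backward returns exactly one value per forward input, with None for inputs that are not differentiated.

import torch

class ScaleByConstant(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, scale, some_flag):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input: a gradient for x,
        # None for the non-differentiated `scale` and `some_flag`.
        return grad_output * ctx.scale, None, None

x = torch.randn(3, requires_grad=True)
y = ScaleByConstant.apply(x, 2.0, True)
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2.])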

@JackCaoG JackCaoG added the tpuci label Apr 30, 2024
@alanwaketan (Collaborator, Author): Thanks, Jack.

@alanwaketan alanwaketan merged commit 400bd0c into master May 1, 2024
21 checks passed
@alanwaketan alanwaketan deleted the alanwaketan/fa_segment_ids branch May 1, 2024 18:38