Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pallas] Support FA sm_scale #7035

Merged
merged 2 commits into from
May 8, 2024
Merged

[Pallas] Support FA sm_scale #7035

merged 2 commits into from
May 8, 2024

Conversation

alanwaketan
Copy link
Collaborator

@alanwaketan alanwaketan commented May 8, 2024

Summary:
This pull request is to support the scaling factor for flash attention's attention weight.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py

@alanwaketan alanwaketan requested a review from JackCaoG May 8, 2024 00:17
@alanwaketan alanwaketan self-assigned this May 8, 2024
@alanwaketan alanwaketan requested a review from wonjoolee95 May 8, 2024 00:18
@@ -428,12 +423,13 @@ def flash_attention(
causal=False,
q_segment_ids=None,
kv_segment_ids=None,
sm_scale=1.0,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it this a training only thing so we don't need to add it for the dynamo?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can support it in dynamo too. I will have a follow up to add some missing dynamo parameters.

Copy link
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

approve to unblock, have 1 comment that can be address in a follow up pr if needed.

@alanwaketan
Copy link
Collaborator Author

Thanks Jack!

@JackCaoG JackCaoG merged commit 1c31cde into master May 8, 2024
20 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants