Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flash attention as custom op so dynamo can trace it #6875

Merged
merged 2 commits into from
Apr 3, 2024

Conversation

JackCaoG
Copy link
Collaborator

@JackCaoG JackCaoG commented Apr 2, 2024

No description provided.

@JackCaoG JackCaoG added the pallas label Apr 2, 2024
@JackCaoG JackCaoG requested a review from alanwaketan April 2, 2024 23:02
def flash_attention_non_xla(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
causal: bool = False):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't add segment_ids as it is not implemented in flash_attention. Maybe I should add it here... @alanwaketan is flash_attention also a torch.tensor?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to support segment_ids. We can safely ignore it. I will drop the comment and the parameter later.

@JackCaoG
Copy link
Collaborator Author

JackCaoG commented Apr 3, 2024

@alanwaketan can I get a review for this one?

Copy link
Collaborator

@alanwaketan alanwaketan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@alanwaketan alanwaketan merged commit d938680 into master Apr 3, 2024
19 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants