Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reland "Support fusing broadcast transposes with attention" #19962

Merged
merged 2 commits into from
Feb 12, 2025

Conversation

IanWood1
Copy link
Contributor

Reland the changes to fold attention ops with broadcasts with a small tweak to AttentionOpDetail so that the batch dimensions are properly computed when an operand is broadcasted.

Original PR #19828
Revert PR #19835
Issue causing revert #19833

@@ -37,12 +37,10 @@ void AttentionOpDetail::inferFromIndexingMaps(AffineMap qMap, AffineMap kMap,
llvm::SmallDenseSet<int64_t> vSet = findPermutationsIndexingOperand(vMap);
llvm::SmallDenseSet<int64_t> oSet = findPermutationsIndexingOperand(oMap);

// B = (Q & K & O) U (K & V & O)
// B = (Q & V) U (K & O)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. Is this true. cc @Groverkss

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. We can actually make it just (Q & V) or just (K & O) also because they are the same.

K := batch x k1 x k2
O := batch x m x n

so their intersection has to be the batch dimensions.

Similarily,

Q := batch x m x k1
V := batch x k2 x n

so their intersection has to be the batch dimension

Interestingly, this actually made me realize we are doing GQA wrong. GQA is a broadcast on the M dimension, which is why it's only on Q.

…ree-org#19835)"

This reverts commit 9870a6d.

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
@IanWood1
Copy link
Contributor Author

I'll merge after double checking that this doesn't break 405b

@IanWood1 IanWood1 merged commit 5767be3 into iree-org:main Feb 12, 2025
43 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants