-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Split reduction dim pass #975
Conversation
Now that I think more about it, it should be possible to largely simplify this logic using |
a8efc8a
to
a777210
Compare
Yup, that was really straight forward. Overall the pass behavior remains the same but now named ops are preserved after tiling. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, looks good to me.
Adds a pass to allow tiling contraction's innermost reduction dimension using serial loop with in-place accumulation.
Compared to other available transformations, this rewrite computes reduction sequentially with in-place accumulation which avoids temporary allocation and separate reduction operation. This tiling strategy is more friendly in terms of register and memory pressure more suitable for low-level GPU kernel generation. Similarly, restriction to the innermost dimension is there to simplify both usage and pass logic as the rewrite is geared toward progressive GEMM lowering.
Additionally, a GPU vectorization control flag is added to allow grouping of passes which target lowering through vector operations and might not be compatible with other existing lowering strategies.
Effectively, this pass perform separate K-dim split which is currently baked into linalg-to-xegpu lowering, and the two are not fully compatible.