-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SLM] Fuse Add and RMSNorm #1627
Conversation
cc: @MasterJH5574 |
|
||
|
||
def get_add_rmsnorm_tir(hidden_size: int, is_decode=True): | ||
@T.prim_func(private=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suppose the fused add-rmsnorm operator could be expressed by TE and scheduled by Dlight easily
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No it's not that easy. The tricky point is that add_rmsnorm function has 2 outputs: both results of add and rms_norm, which makes compute_inline/compute_at/reverse_compute_at all fail in such case. I have to write scheduled TIR to work around.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. Thanks for the elaboration!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the implication on performance? If the manual schedule can't generalize to all cases we can try supporting such pattern in cublas fusion
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I heard from @MasterJH5574 that cublas fusion cannot successfully fuse matmul and divide_add now, so I create this small pass to unblock our effort on mlc serve perf profiling. In the long term, surely this can be replaced by better cublas fusion, but it doesn't hurt to work as a fallback or as a target to compare.
Whenever we use TIR, I know they are auto generated ones, but lets work to make sure they are human readable so that they establish positive examples demonstrating “TIR is actually great” |
Thanks @junrushao for pointing out. I accidentally used black to format this TIR and I will regenerate this TIR. |
f365ae4
to
c09b4a2
Compare
I think the PR is ready to merge in terms of code quality and correctness. I believe 1) @vinx13 has some further comments on its performance implication and generalizability, and 2) I have some concern about the prefill kernel which assumes |
24e1527
to
e0e078c
Compare
56692f7
to
7274713
Compare
This PR adds a fusion pass that applies to "binary add" and "RMSNorm".
This is a temporary workaround that allows us to fuse "add" into RMSNorm
once it is not fused into GEMM epilogue.