
[SLM] Fuse Add and RMSNorm #1627

Merged: 2 commits merged into mlc-ai:main from add_norm_fuse on Jan 25, 2024

Conversation

@jinhongyii (Member) commented on Jan 18, 2024:

This PR adds a fusion pass that fuses a binary add with the RMSNorm that follows it. It is a temporary workaround that lets us fuse the add into RMSNorm when it is not fused into the GEMM epilogue.
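
For reference, here is a minimal NumPy sketch of the fused pattern this pass targets; the function and names below are illustrative only, not the PR's actual code:

import numpy as np

def add_rmsnorm_reference(x, residual, weight, eps=1e-5):
    # First output: the elementwise add, which downstream layers also consume
    # as the new residual stream.
    added = x + residual
    # Second output: RMSNorm over the hidden (last) axis of the added tensor.
    variance = np.mean(added.astype(np.float32) ** 2, axis=-1, keepdims=True)
    normed = (added.astype(np.float32) / np.sqrt(variance + eps) * weight).astype(x.dtype)
    return added, normed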

@jinhongyii (Member, Author):

cc: @MasterJH5574



def get_add_rmsnorm_tir(hidden_size: int, is_decode=True):
    @T.prim_func(private=True)
Member:

I suppose the fused add-rmsnorm operator could be expressed in TE and scheduled by Dlight easily.

Member Author:

No, it's not that easy. The tricky point is that the add_rmsnorm function has two outputs, the results of both the add and the rms_norm, which makes compute_inline/compute_at/reverse_compute_at all fail in this case. I had to write already-scheduled TIR to work around it.
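
For illustration, a rough TE sketch of the two-output structure being discussed; this is an assumed TE formulation for comparison, not the scheduled TIR in this PR, and the names are hypothetical:

import tvm
from tvm import te

def add_rmsnorm_te(hidden_size: int, eps: float = 1e-5, dtype: str = "float16"):
    n = te.var("n")
    x = te.placeholder((n, hidden_size), dtype=dtype, name="x")
    residual = te.placeholder((n, hidden_size), dtype=dtype, name="residual")
    weight = te.placeholder((hidden_size,), dtype=dtype, name="weight")
    # First output: the elementwise add, consumed both by the norm and by later layers.
    added = te.compute((n, hidden_size), lambda i, j: x[i, j] + residual[i, j], name="added")
    # Sum of squares along the hidden dimension, accumulated in float32.
    k = te.reduce_axis((0, hidden_size), name="k")
    sq_sum = te.compute(
        (n,),
        lambda i: te.sum(added[i, k].astype("float32") * added[i, k].astype("float32"), axis=k),
        name="sq_sum",
    )
    # Second output: the RMS-normalized tensor.
    norm = te.compute(
        (n, hidden_size),
        lambda i, j: (
            added[i, j].astype("float32")
            / tvm.tir.sqrt(sq_sum[i] / hidden_size + eps)
            * weight[j].astype("float32")
        ).astype(dtype),
        name="norm",
    )
    # Having two outputs ("added" and "norm") is what breaks the usual inline/compute_at scheduling.
    return [x, residual, weight, added, norm]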

Member:

Got it. Thanks for the elaboration!

Member:

What's the implication on performance? If the manual schedule can't generalize to all cases, we can try supporting such a pattern in cublas fusion.

Member Author:

I heard from @MasterJH5574 that cublas fusion cannot successfully fuse matmul and divide_add right now, so I created this small pass to unblock our effort on mlc serve perf profiling. In the long term this can surely be replaced by better cublas fusion, but it doesn't hurt to keep it as a fallback or as a baseline to compare against.

@junrushao (Member):

Whenever we use TIR, I know these are auto-generated, but let's work to make sure they are human-readable so that they establish positive examples demonstrating that “TIR is actually great”.

@junrushao closed this Jan 21, 2024
@junrushao reopened this Jan 21, 2024
@jinhongyii (Member, Author):

Thanks @junrushao for pointing this out. I accidentally used black to format this TIR; I will regenerate it.

@junrushao force-pushed the add_norm_fuse branch 6 times, most recently from f365ae4 to c09b4a2, on January 25, 2024 07:35
@junrushao changed the title from [SLIM] Fuse Add and RMSNorm to [SLM] Fuse Add and RMSNorm on Jan 25, 2024
@junrushao (Member):

I think the PR is ready to merge in terms of code quality and correctness. I believe 1) @vinx13 has some further comments on its performance implications and generalizability, and 2) I have some concerns about the prefill kernel, which assumes batch_size = 1. Anyway, it's a good start.

@junrushao merged commit b01b06c into mlc-ai:main on Jan 25, 2024
1 check passed