remove compiled_rmsnorm #535

Merged 1 commit on Aug 20, 2024

tianyu-l (Contributor) commented Aug 20, 2024

Stack from ghstack (oldest at bottom):

It seems this feature (#442) has been very unreliable:

  • with TP, its numerics were off even though it passed the unit test (add unit test for compiled_rmsnorm #506)
  • with PP, it does not work as expected (it still hits the CUDA context error that fused_rmsnorm ran into)
  • in 1D, it used to work but seems to have no effect as of today (08/19); the cause is unclear

Let's remove it for now. To get a performant RMSNorm, the options are:

  • apply (block-level) torch.compile,
  • use fused_rmsnorm in 1D or 2D (which also doesn't support PP well), or
  • wait for a dedicated CUDA kernel
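
For reference, the first option can be sketched roughly as follows. This is a minimal illustration and not the code in this repo: `RMSNorm` is a plain eager implementation of y = x / sqrt(mean(x^2) + eps) * weight, while `Block` and `compile_blocks` are hypothetical names introduced here. It assumes the model keeps its transformer blocks in an `nn.ModuleList`.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Plain eager RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for stability, then cast back to the input dtype.
        x_fp32 = x.float()
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x_fp32 * inv_rms).type_as(x) * self.weight


class Block(nn.Module):
    """Hypothetical stand-in for a transformer block (norm + MLP + residual)."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = RMSNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mlp(self.norm(x))


def compile_blocks(layers: nn.ModuleList) -> None:
    """Replace each block with its torch.compile-wrapped version, in place.

    Compiling per block (rather than a separate compiled norm) lets the
    compiler fuse the norm with neighboring ops inside the block.
    """
    for i in range(len(layers)):
        layers[i] = torch.compile(layers[i])
```

With a model that exposes its blocks as `model.layers` (an assumption), this would be called as `compile_blocks(model.layers)` before training. Compilation is lazy, so the cost is paid on the first forward pass of each block.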

This PR also adjusts the integration tests, since rmsnorm is now the default.

tianyu-l added a commit that referenced this pull request Aug 20, 2024
ghstack-source-id: ceb4fa54121be241633daf06a0ca2eb407667274
Pull Request resolved: #535
@facebook-github-bot added the CLA Signed label Aug 20, 2024
@tianyu-l tianyu-l merged commit 66e4e02 into gh/tianyu-l/21/base Aug 20, 2024
5 checks passed
@tianyu-l tianyu-l deleted the gh/tianyu-l/21/head branch August 20, 2024 19:27
Development

Successfully merging this pull request may close these issues.

compiled rms_norm not numerically accurate (fails to produce good loss curves) when run under tp