Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dtype, fix RMS norm for FP16 #8641

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

Add dtype, fix RMS norm for FP16 #8641

wants to merge 2 commits into from

Conversation

metascroy
Copy link
Contributor

@metascroy metascroy commented Feb 23, 2025

Llama1B quality in CoreML is bad due to FP16 arithmetic. Here is a sample of generated text:

"Once upon a time,we had our way,as we navigated,through the vast expanse of our understanding,as we journeyed,through the treacherous terrain of our experiences,as we faced our fears,as we stood tall,as we confronted our demons,as we emerged victorious,as we transcend our limits,as we ascend,as we unite,as we become,as we lose ourselves,as we search,as we remember,as we come back,as we return,as we find myself,as I am,as I am not,as I stand tall,as I hear my voice,as I remember,as I forgive,as I am, as I am, as I become, as I lose myself, as I find myself, as I remember, as I come back, as I return, as I find myself, as I am, as I am not, as I stand tall, as I Hear My Voice, as I Remember, as I Find Myself, as I am, as I Become, as I Lose Myself, as I Find Myself, asных, as I Become, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose My"

The corresponding FP16 eager mode model has much better generated text:

"Once upon a time, in a small village nestled between two great mountains, there lived a young girl named Akira. She was a curious and adventurous soul, with a heart full of wonder and a mind full of questions. Akira lived with her grandmother, a wise and kind woman named Kana, who taught her the ways of the world and the secrets of the universe.

One day, while exploring the village, Akira stumbled upon a mysterious shop tucked away in a quiet alley. The sign above the door read "Moonlit Curios and Antiques," and the windows were filled with a dazzling array of strange and exotic objects. Akira felt an inexplicable pull towards the shop, as if the very fabric of the universe was calling to her."

The discrepancy is that the eager mode model actually computes the RMSNorm in FP32 due to a cast operation (which CoreML appears to ignore):

self._norm(x.float()).type_as(x)

Moreover, the norm computation appears unstable in FP16 and gives bad results. We can improve the numeric quality of the norm in FP16 by first dividing x by its maximum absolute value. Here is the generated text from CoreML in FP16 after this change:

"Once upon a time, in a small village nestled in the rolling hills of Provence, there lived a young girl named Colette. Colette was a curious and adventurous soul, with a heart full of wonder and a mind full of questions. She spent most of her days exploring the village, visiting the local market, and listening to the tales of the old villagers.

One day, while wandering through the village, Colette stumbled upon a small, mysterious shop tucked away on a quiet street. The sign above the door read "Curios and Wonders," and the windows were filled with a dazzling array of strange and exotic objects. Colette's curiosity was piqued, and she pushed open the door to reveal a dimly lit interior filled with the scent of old books and dust."

Note, for 4-bit channelwise quantization, the results do not look good even after this change. The ideal solution is to do QAT for llama1B with 4-bit channelwise quantization + FP16 arithmetic.

Copy link

pytorch-bot bot commented Feb 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/8641

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit e87347d with merge base 366d87e (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 23, 2025
Copy link

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@metascroy
Copy link
Contributor Author

@YifanShenSZ It is probably a bug for CoreML to ignore the cast. It was presumably added because FP16 arithmetic was not sufficient enough.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants