Add dtype, fix RMS norm for FP16 #8641
Open · +199 −65
Llama1B generation quality in CoreML is bad due to FP16 arithmetic. Here is a sample of generated text:
The corresponding FP16 eager mode model has much better generated text:
The discrepancy arises because the eager mode model actually computes the RMSNorm in FP32, due to a cast operation (which CoreML appears to ignore):
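For reference, here is a minimal sketch of the common Llama-style RMSNorm pattern with the FP32 upcast. The class and variable names are illustrative and may not match the module in this repo:

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Sketch of a Llama-style RMSNorm that upcasts to FP32 for the reduction."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # This is the cast that keeps the eager-mode math in FP32
        # (and that CoreML appears to ignore).
        x = x.to(torch.float32)
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * x.to(input_dtype)
```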
Moreover, the norm computation appears to be numerically unstable in FP16 and gives bad results. We can improve the numeric quality of the norm in FP16 by first dividing x by its maximum absolute value, which keeps the squared values well within FP16 range (see the sketch at the end of this description). Here is the generated text from CoreML in FP16 after this change:
Note that for 4-bit channelwise quantization, the results do not look good even after this change. The ideal solution is to do QAT for Llama1B with 4-bit channelwise quantization + FP16 arithmetic.
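For illustration, here is a minimal sketch of the max-abs rescaling idea described above, written as a standalone function; the function name is hypothetical and this is not the exact diff in this PR:

```python
import torch


def rms_norm_fp16_safe(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """FP16-friendly RMSNorm sketch: rescale by max(|x|) before the reduction."""
    # Dividing by max(|x|) bounds |x| by 1, so x.pow(2) cannot overflow FP16
    # and the mean/rsqrt reduction stays well conditioned.
    max_abs = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    x = x / max_abs
    return weight * x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
```

Because RMSNorm is scale-invariant apart from the eps term, the rescaling leaves the result essentially unchanged in exact arithmetic; it only changes where rounding happens in FP16.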