Add FP8 support for MoE models #2928
Conversation
Nice! Some small comments.
@@ -63,7 +63,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    if weight.dtype == torch.float8_e4m3fn:
+    if weight.dtype == torch.float8_e4m3fn and SYSTEM == "rocm":
The function would now not normalize on SYSTEM != "rocm" even if the data type is float8_e4m3fn. I think either the function should be renamed to normalize_e4m3fn_to_native_float8, or this condition should not be there (and the conversion should happen regardless of SYSTEM).
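For readers following the thread: a minimal sketch of what the renamed helper could look like, assuming the usual e4m3fn → e4m3fnuz conversion (reinterpret the raw bits and double the scales, since e4m3fnuz's maximum representable value is half of e4m3fn's). The actual body in the repository may differ; SYSTEM is assumed to be the platform string used elsewhere in this diff.

import torch
from typing import Optional, Tuple

def normalize_e4m3fn_to_native_float8(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    # Only ROCm needs a conversion: its native float8 format is e4m3fnuz.
    if weight.dtype == torch.float8_e4m3fn and SYSTEM == "rocm":
        # The bit pattern 0x80 (-128 as int8) is -0.0 in e4m3fn but NaN in
        # e4m3fnuz, so map it to +0.0 before reinterpreting the raw bits.
        weight_as_int8 = weight.view(torch.int8)
        weight_as_int8[weight_as_int8 == -128] = 0
        weight = weight_as_int8.view(torch.float8_e4m3fnuz)
        # e4m3fnuz's dynamic range is half of e4m3fn's, so double the
        # scales to keep the dequantized values unchanged.
        weight_scale = weight_scale * 2.0
        if input_scale is not None:
            input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale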
Done, renamed the function.
     if (
         isinstance(weights.loader, DefaultWeightsLoader)
         and isinstance(weights.loader.weight_class, UnquantizedWeight)
     ) or isinstance(weights.loader, HybridFP8UnquantLoader):
-        cls = UnquantizedSparseMoELayer
+        if (
+            isinstance(weights.loader, HybridFP8UnquantLoader)
+            and weights.loader.to_fp8
+        ):
+            cls = FP8SparseMoELayer
+        else:
+            cls = UnquantizedSparseMoELayer
I think it would be better to flatten this now. Something like:
if isinstance(weights.loader, DefaultWeightsLoader) and isinstance(weights.loader.weight_class, UnquantizedWeight):
    cls = UnquantizedSparseMoELayer
elif isinstance(weights.loader, HybridFP8UnquantLoader):
    cls = FP8SparseMoELayer
elif # ...
I flattened it, but the condition remains the same, because we always use the HybridFP8UnquantLoader to load the weights.
https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/utils/quantization.py#L202
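Putting the suggestion and this reply together, the flattened dispatch presumably reads something like the following. This is a sketch assembled from the snippets in this thread, not necessarily the exact merged code.

if isinstance(weights.loader, DefaultWeightsLoader) and isinstance(
    weights.loader.weight_class, UnquantizedWeight
):
    cls = UnquantizedSparseMoELayer
elif isinstance(weights.loader, HybridFP8UnquantLoader):
    # HybridFP8UnquantLoader loads both fp8 and unquantized checkpoints,
    # so the to_fp8 flag picks the MoE layer implementation.
    cls = FP8SparseMoELayer if weights.loader.to_fp8 else UnquantizedSparseMoELayer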
if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
    all_weight[i], all_weight_scales[i], current_input_scale = (
        normalize_e4m3fn_to_e4m3fnuz(
Ok, I see why the condition was added; more and more in favor of renaming this to normalize_e4m3fn_to_native_float8.
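After the rename, the truncated call site above would presumably become something like this. The keyword names come from the function signature in the first hunk; the weight attributes are assumptions for illustration.

if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
    all_weight[i], all_weight_scales[i], current_input_scale = (
        normalize_e4m3fn_to_native_float8(
            weight=weight.weight,
            weight_scale=weight.weight_scale,  # assumed attribute
            input_scale=weight.input_scale,  # assumed attribute
        )
    )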
Renamed
What does this PR do?
As per title!