Why does swapping nn.Linear with FP8Linear decrease throughput for a 7B LLaMA2-like model? #1199
Comments
Hi @zigzagcai, from the trace above and from the code in InternLM/InternEvo@e6f5562#diff-ed183d67207df065a11e1289f19d34cc2abbc5448dea952683cfe9728c342b95R22 I would guess that you are running in eager mode. This is expected to be slow, since the float8 ops are not fused with the surrounding ops. We require torch.compile to achieve good performance; have you tried enabling that? From your commit, the change would be:
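A minimal sketch of that kind of change (the ToyFFN module below is illustrative, not code from the linked InternEvo commit): convert the nn.Linear layers with convert_to_float8_training, then wrap the model with torch.compile so the float8 casts and scaling get fused with the surrounding ops instead of running as separate eager kernels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchao.float8 import convert_to_float8_training

class ToyFFN(nn.Module):
    """Toy LLaMA-style feedforward block; stands in for the real model."""
    def __init__(self, dim: int = 4096, hidden: int = 11008):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

# Requires a GPU with float8 support (e.g. H100).
model = ToyFFN().cuda().to(torch.bfloat16)

# Swap eligible nn.Linear modules for Float8Linear (dynamic scaling by default).
convert_to_float8_training(model)

# Without torch.compile, the float8 scale/cast ops run as separate unfused
# eager kernels and the swap can be slower than plain bfloat16.
model = torch.compile(model)

x = torch.randn(4, 512, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
model(x).sum().backward()
```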
You can also try enabling torch.compile on smaller parts of the model if full-model compile is problematic. Here is an example of doing this from our torchtitan repository: https://github.com/pytorch/torchtitan/blob/dbb0520d51f3f92b5391b46b545959c7dcebe197/torchtitan/parallelisms/parallelize_llama.py#L304
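For reference, a sketch of that per-block approach, similar in spirit to the linked torchtitan code (it assumes the transformer blocks live in a `model.layers` container, which may not match InternEvo's module layout):

```python
import torch
import torch.nn as nn

def compile_transformer_blocks(model: nn.Module) -> None:
    """Compile each transformer block separately instead of the whole model.

    Keeping compilation per block contains graph breaks and compile time to a
    single block, which is usually easier to debug than full-model compile.
    """
    for layer_id, block in model.layers.named_children():
        compiled_block = torch.compile(block)
        model.layers.register_module(layer_id, compiled_block)
```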
Hi, thank you @vkuzo!
The FP8Linear throughput improved and now shows a performance gain compared with naive bfloat16.
Awesome! Feel free to let us know if you have more questions!
Hi developers,
Thanks for such a great project that provides an elegant interface for FP8 training!
For simplicity, I use dynamic scaling with FP8Linear. But when I tried to integrate the FP8 functionality into our training framework, I found that the throughput of the linear modules (such as the feedforward) decreased a lot.
Traces that use naive bfloat16 torch.nn.Linear: (profiler trace screenshot in the original issue)
Traces that swap torch.nn.Linear with FP8Linear via the convert_to_float8_training API: (profiler trace screenshot in the original issue)

Code to reproduce the performance results: https://github.com/InternLM/InternEvo/tree/fp8_enable
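For context, a minimal sketch of the kind of swap described above (the toy nn.Sequential and the "head" filter are illustrative, not the actual InternEvo model or filter):

```python
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

def module_filter_fn(module: nn.Module, fqn: str) -> bool:
    # Illustrative filter: keep any layer whose name contains "head" in bf16.
    return "head" not in fqn

model = nn.Sequential(
    nn.Linear(4096, 11008, bias=False),
    nn.SiLU(),
    nn.Linear(11008, 4096, bias=False),
)

# Dynamic scaling is the default configuration for convert_to_float8_training.
convert_to_float8_training(model, module_filter_fn=module_filter_fn)
print(model)  # eligible nn.Linear modules are now Float8Linear
```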
Test environment:
torchrun --nproc_per_node=2 --nnodes=1 --node_rank=0 train.py --config configs/7B_sft.py --launcher torch --seed 1024
Thanks in advance to anyone who can provide some insights!