
Why does swapping nn.Linear with FP8Linear decrease throughput for a 7B LLAMA2-like model? #1199

Closed
season0528 opened this issue Oct 30, 2024 · 3 comments

season0528 commented Oct 30, 2024

Hi developers,

Thanks for such a great project that provides elegant interface for FP8 training!

For simplicity, I use dynamic scaling with FP8Linear. But when I tried to integrate the FP8 functionality into our training framework, I found that the throughput of the linear modules (such as the feedforward) decreases significantly.

  • Traces using the baseline bfloat16 torch.nn.Linear:

    • CPU trace (screenshot)
    • GPU trace (screenshot)
  • Traces after swapping torch.nn.Linear with FP8Linear via the convert_to_float8_training API (a minimal conversion sketch follows this list):

    • CPU trace (screenshot)
  • Code to reproduce the performance results: https://github.com/InternLM/InternEvo/tree/fp8_enable

  • Test environment:

    • 1x node, with 2x H100 GPUs
    • torch 2.5.0
    • train script: torchrun --nproc_per_node=2 --nnodes=1 --node_rank=0 train.py --config configs/7B_sft.py --launcher torch --seed 1024
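
For context, the conversion itself is roughly the following (a minimal sketch assuming torchao's convert_to_float8_training API with its default dynamic-scaling recipe; the actual InternEvo integration wraps this in a Float8Handler):

    import torch
    import torch.nn as nn

    # assumed import path; may differ between torchao versions
    from torchao.float8 import convert_to_float8_training

    # toy stand-in for a LLaMA-style feedforward block
    model = nn.Sequential(
        nn.Linear(4096, 11008, bias=False),
        nn.SiLU(),
        nn.Linear(11008, 4096, bias=False),
    ).to(torch.bfloat16).cuda()

    # swap every nn.Linear for Float8Linear; dynamic scaling is the default
    convert_to_float8_training(model)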

Thanks in advance to anyone who can provide some insights!

vkuzo (Contributor) commented Oct 30, 2024

hi @zigzagcai, from the trace above and from the code in InternLM/InternEvo@e6f5562#diff-ed183d67207df065a11e1289f19d34cc2abbc5448dea952683cfe9728c342b95R22, I would guess that you are running in eager mode. This is expected to be slow because the float8 ops are not fused with the surrounding ops. We require torch.compile to achieve good performance; have you tried enabling it? From your commit, the change would be:

    float8_handler = Float8Handler()
    float8_handler.convert_to_float8_training(model)
    # enable torch.compile for improved performance
    model = torch.compile(model)

You can also try enabling torch.compile on smaller parts of the model if full-model compile is problematic. Here is an example of doing this from our torchtitan repository: https://github.com/pytorch/torchtitan/blob/dbb0520d51f3f92b5391b46b545959c7dcebe197/torchtitan/parallelisms/parallelize_llama.py#L304
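
For instance, here is a rough sketch of compiling each transformer block separately rather than the whole model (this assumes the blocks live in an nn.ModuleList attribute named layers; adapt the name to your model):

    # compile each transformer block individually instead of the full model;
    # `model.layers` is an assumed attribute name for the block container
    for i, block in enumerate(model.layers):
        model.layers[i] = torch.compile(block)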

season0528 (Author) commented Oct 31, 2024


Hi,

Thank you @vkuzo!

model = torch.compile(model) just works!

The FP8Linear throughput improved and now shows a performance gain compared with the baseline bfloat16 nn.Linear.

vkuzo (Contributor) commented Oct 31, 2024

Awesome! Feel free to reach out if you have more questions!

vkuzo closed this as completed Oct 31, 2024