[float8] keep model.output as nn.Linear (high precision, not fp8) (pytorch#469)
**keep model.output as nn.Linear**: it's a common practice to NOT apply fp8 to the final output layer

* specify `skip_fqn_list` in swapping (see the sketch at the end of this message)
* when applying TP to model.output, use plain `ColwiseParallel` instead of `Float8ColwiseParallel`

credit to @awgu, we do not need the tokenizer vocab size to be divisible by 16 (pytorch#461)

1D TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4`

1D TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4 --training.compile`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.enable_fsdp_float8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp --training.tensor_parallel_degree 2`

2D FSDP2 + TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.enable_fsdp_float8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp --training.tensor_parallel_degree 2 --training.compile`

1D TP + float8 all-gather trace: see float8 and all-gather in the trace
<img width="1611" alt="Screenshot 2024-07-19 at 1 16 59 PM" src="https://github.com/user-attachments/assets/9a95dfd9-40e0-4133-b2bb-e22ddf5b8472">

2D + float8 all-gather trace: see float8 and FSDP collectives and TP collectives
<img width="1038" alt="Screenshot 2024-07-19 at 1 29 59 PM" src="https://github.com/user-attachments/assets/6a34bcaa-bcae-402b-9994-cc892554fec7">
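A minimal sketch of the two changes above (not the exact torchtitan diff). It assumes the `float8_experimental` swap helper with its `skip_fqn_list` argument and the `Float8ColwiseParallel` TP style; exact import paths and signatures may differ across versions (e.g. after the move to `torchao.float8`), and `model` / `tp_mesh` are placeholders supplied by the caller:

```python
import torch.nn as nn
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# Assumed float8_experimental entry points; names/paths may differ by version.
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_tensor_parallel import Float8ColwiseParallel


def apply_float8_and_tp(model: nn.Module, tp_mesh) -> nn.Module:
    # 1) Swap nn.Linear -> Float8Linear everywhere EXCEPT the final projection,
    #    so model.output stays a plain high-precision nn.Linear.
    swap_linear_with_float8_linear(model, skip_fqn_list=["output"])

    # 2) When applying TP, model.output is no longer a Float8Linear, so it
    #    takes the plain ColwiseParallel style; the fp8 linears keep the
    #    Float8 styles so the float8 all-gather path is used for them.
    parallelize_module(
        model,
        tp_mesh,
        {
            # example fp8 linear (Llama-style fqn, illustrative only)
            "layers.0.attention.wq": Float8ColwiseParallel(),
            # high-precision output layer: plain TP style
            "output": ColwiseParallel(),
        },
    )
    return model
```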