From 0c6f9a24de0bf1dd9548e3691de7b228179bbdd8 Mon Sep 17 00:00:00 2001
From: "Wei (Will) Feng" <134637289+weifengpy@users.noreply.github.com>
Date: Fri, 19 Jul 2024 16:07:12 -0700
Subject: [PATCH] [float8] keep model.output as `nn.Linear` (high precision,
not fp8) (#469)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

**keep model.output as nn.Linear**: it's common practice to NOT apply
fp8 to the final output layer

* specify `skip_fqn_list` in the swap
* when applying TP to model.output, use plain `ColwiseParallel` instead
  of `Float8ColwiseParallel`

credit to @awgu: we do not need the tokenizer vocab size to be divisible by
16 (https://github.com/pytorch/torchtitan/issues/461)
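
For reference, here is a minimal sketch of the two changes together. The
`swap_linear_with_float8_linear` call and the `ColwiseParallel` plan mirror the
diff below; the wrapper name `swap_to_float8_keep_output` and the
`model` / `tp_mesh` / `loss_parallel` arguments are illustrative placeholders,
not torchtitan API:

```python
import torch.nn as nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear


def swap_to_float8_keep_output(model: nn.Module, tp_mesh, loss_parallel: bool):
    # 1) Swap nn.Linear -> Float8Linear everywhere except the final projection,
    #    so model.output stays a plain high-precision nn.Linear.
    swap_linear_with_float8_linear(
        model,
        scaling_type_w=TensorScalingType.DYNAMIC,
        skip_fqn_list=["output"],
    )
    assert type(model.output) is nn.Linear  # not Float8Linear

    # 2) Because model.output is not a Float8Linear, shard it with the stock
    #    ColwiseParallel instead of Float8ColwiseParallel.
    parallelize_module(
        model,
        tp_mesh,
        {
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )
```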

1D TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4`

1D TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4
--training.compile`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2`

2D FSDP2 + TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2 --training.compile`

1D TP + float8 all-gather trace: see float8 and all-gather in the trace

2D + float8 all-gather trace: see float8, FSDP collectives, and TP
collectives in the trace
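
For context, the renamed `get_tp_parallel_strategy_for_transformer_block` still
decides which parallel styles the transformer-block linears use; only
model.output is pinned to the plain `ColwiseParallel` above. A rough sketch of
that selection logic, assuming the float8 TP styles come from
`float8_experimental.float8_tensor_parallel` (the body below is illustrative,
not the exact torchtitan code):

```python
from typing import Tuple

import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
)


def get_tp_parallel_strategy_for_transformer_block(
    job_config, model: nn.Module
) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
    """Return float8-aware parallel styles for transformer-block linears when
    float8 training is enabled, otherwise the stock DTensor styles."""
    if job_config.training.enable_float8_linear:
        from float8_experimental.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
    return RowwiseParallel, ColwiseParallel, PrepareModuleInput
```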
---
 torchtitan/float8_linear.py                  |  4 +++-
 torchtitan/parallelisms/parallelize_llama.py | 19 ++++++++++---------
 2 files changed, 13 insertions(+), 10 deletions(-)
diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py
index 50c971ae..770531d5 100644
--- a/torchtitan/float8_linear.py
+++ b/torchtitan/float8_linear.py
@@ -74,7 +74,9 @@ def maybe_build_fp8_linear(
         )
         with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
             swap_linear_with_float8_linear(
-                model, scaling_type_w=TensorScalingType.DYNAMIC
+                model,
+                scaling_type_w=TensorScalingType.DYNAMIC,
+                skip_fqn_list=["output"],
             )
         logger.info(
             f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 33b9d6d3..634c70a0 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -117,7 +117,7 @@ def selective_checkpointing_context_fn():
             return module
 
 
-def get_tp_parallel_strategy(
+def get_tp_parallel_strategy_for_transformer_block(
     job_config: JobConfig,
     model: nn.Module,
 ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
@@ -346,13 +346,6 @@ def apply_tp(
     """Apply tensor parallelism."""
     tp_mesh = world_mesh["tp"]
 
-    # Parallel styles used for transformer block linear weights and their
-    # inputs may be different for float8 linears
-    (
-        rowwise_parallel_weight,
-        colwise_parallel_weight,
-        prepare_module_input,
-    ) = get_tp_parallel_strategy(job_config, model)
     loss_parallel = parallel_dims.loss_parallel_enabled
 
     # 1. Parallelize the embedding and shard its outputs (which are the first
@@ -368,7 +361,7 @@ def apply_tp(
                 output_layouts=Shard(1),
             ),
             "norm": SequenceParallel(),
-            "output": colwise_parallel_weight(
+            "output": ColwiseParallel(
                 input_layouts=Shard(1),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
@@ -376,6 +369,14 @@ def apply_tp(
         },
     )
 
+    # Parallel styles used for transformer block linear weights and their
+    # inputs may be different for float8 linears
+    (
+        rowwise_parallel_weight,
+        colwise_parallel_weight,
+        prepare_module_input,
+    ) = get_tp_parallel_strategy_for_transformer_block(job_config, model)
+
     # Apply tensor + sequence parallelism to every transformer block
     # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
     # by folding (and unfolding) the batch dimension and the sequence dimension.