
Commit ec48219

[float8] keep model.output as nn.Linear (high precision, not fp8) (pytorch#469)
**Keep model.output as nn.Linear**: it is common practice NOT to apply fp8 to the final output layer.

* specify `skip_fqn_list` in the swap (see the sketch below)
* when applying TP to model.output, use plain `ColwiseParallel` instead of `Float8ColwiseParallel`

Credit to @awgu: we do not need the tokenizer vocab size to be divisible by 16 (pytorch#461).

1D TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4`

1D TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4 --training.compile`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.enable_fsdp_float8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp --training.tensor_parallel_degree 2`

2D FSDP2 + TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.enable_fsdp_float8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp --training.tensor_parallel_degree 2 --training.compile`

1D TP + float8 all-gather trace: float8 ops and the all-gather are visible in the trace.
![Screenshot 2024-07-19 at 1 16 59 PM](https://github.com/user-attachments/assets/9a95dfd9-40e0-4133-b2bb-e22ddf5b8472)

2D FSDP2 + TP + float8 all-gather trace: float8 ops, FSDP collectives, and TP collectives are visible in the trace.
![Screenshot 2024-07-19 at 1 29 59 PM](https://github.com/user-attachments/assets/6a34bcaa-bcae-402b-9994-cc892554fec7)
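A minimal sketch (not part of this PR) of what the `skip_fqn_list` change does: after the swap, eligible linears become `Float8Linear` while the module named `output` stays a plain `nn.Linear`. The `ToyModel` below is a hypothetical stand-in; the swap call mirrors the one in `maybe_build_fp8_linear` in the diff, and the `float8_experimental` import paths are assumed from the contemporaneous API.

```python
import torch.nn as nn

# float8_experimental was the library torchtitan used at the time of this PR
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear


class ToyModel(nn.Module):
    """Hypothetical stand-in: one inner linear plus a final output projection."""

    def __init__(self, dim: int = 256, vocab_size: int = 1024):
        super().__init__()
        self.w1 = nn.Linear(dim, dim)             # eligible for the fp8 swap
        self.output = nn.Linear(dim, vocab_size)  # to be kept in high precision


model = ToyModel()
swap_linear_with_float8_linear(
    model,
    scaling_type_w=TensorScalingType.DYNAMIC,
    skip_fqn_list=["output"],  # same FQN this PR skips
)

assert isinstance(model.w1, Float8Linear)          # swapped to fp8
assert not isinstance(model.output, Float8Linear)  # still a plain nn.Linear
```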
Parent: 5d1b99f

File tree: 2 files changed (+13, -10 lines)


**torchtitan/float8_linear.py** (+3, -1)
```diff
@@ -74,7 +74,9 @@ def maybe_build_fp8_linear(
         )
         with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
             swap_linear_with_float8_linear(
-                model, scaling_type_w=TensorScalingType.DYNAMIC
+                model,
+                scaling_type_w=TensorScalingType.DYNAMIC,
+                skip_fqn_list=["output"],
             )
         logger.info(
             f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
```

**torchtitan/parallelisms/parallelize_llama.py** (+10, -9)
```diff
@@ -117,7 +117,7 @@ def selective_checkpointing_context_fn():
     return module
 
 
-def get_tp_parallel_strategy(
+def get_tp_parallel_strategy_for_transformer_block(
     job_config: JobConfig,
     model: nn.Module,
 ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
@@ -346,13 +346,6 @@ def apply_tp(
     """Apply tensor parallelism."""
 
     tp_mesh = world_mesh["tp"]
-    # Parallel styles used for transformer block linear weights and their
-    # inputs may be different for float8 linears
-    (
-        rowwise_parallel_weight,
-        colwise_parallel_weight,
-        prepare_module_input,
-    ) = get_tp_parallel_strategy(job_config, model)
     loss_parallel = parallel_dims.loss_parallel_enabled
 
     # 1. Parallelize the embedding and shard its outputs (which are the first
@@ -368,14 +361,22 @@ def apply_tp(
                 output_layouts=Shard(1),
             ),
             "norm": SequenceParallel(),
-            "output": colwise_parallel_weight(
+            "output": ColwiseParallel(
                 input_layouts=Shard(1),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
             ),
         },
     )
 
+    # Parallel styles used for transformer block linear weights and their
+    # inputs may be different for float8 linears
+    (
+        rowwise_parallel_weight,
+        colwise_parallel_weight,
+        prepare_module_input,
+    ) = get_tp_parallel_strategy_for_transformer_block(job_config, model)
+
     # Apply tensor + sequence parallelism to every transformer block
     # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
     # by folding (and unfolding) the batch dimension and the sequence dimension.
```
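A minimal, hedged sketch of the TP side of the change: the final projection is parallelized with the stock `ColwiseParallel` style rather than `Float8ColwiseParallel`. The `ToyHead` module, the dimensions, and the CPU/gloo mesh are hypothetical stand-ins so the sketch runs without GPUs; real training uses a CUDA mesh inside `apply_tp` with `Shard(1)` inputs and loss-parallel outputs. Launch with e.g. `torchrun --nproc_per_node=2 sketch.py`.

```python
import torch
import torch.nn as nn
from torch.distributed._tensor import Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module


class ToyHead(nn.Module):
    """Hypothetical stand-in for the model's final norm + output projection."""

    def __init__(self, dim: int = 64, vocab_size: int = 128):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab_size)  # plain nn.Linear, no fp8

    def forward(self, x):
        return self.output(self.norm(x))


# Same seed on every rank so the pre-shard weights match across ranks.
torch.manual_seed(0)

# 1-D tensor-parallel mesh; "cpu"/gloo only so the sketch runs anywhere.
tp_mesh = init_device_mesh("cpu", (2,))

model = parallelize_module(
    ToyHead(),
    tp_mesh,
    {
        # Plain ColwiseParallel for the output layer, as in this PR.
        "output": ColwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Replicate(),
        ),
    },
)

x = torch.randn(4, 64)
print(model(x).shape)  # torch.Size([4, 128]) on every rank
```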
