Skip to content

Commit

Permalink
[float8] keep model.output as nn.Linear (high precision, not fp8) (p…
Browse files Browse the repository at this point in the history
…ytorch#469)

**keep model.output as nn.Linear**: it's a common practice to NOT apply
fp8 on final output layer
* specify `skip_fqn_list` in swapping
* when applying TP to model.output, use plain `ColwiseParallel` instead
of `Float8ColwiseParallel`

credit to @awgu, we do not need tokentizer vacab size to be divisible by
16 pytorch#461

1D TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4`

1D TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4
--training.compile`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2 --training.compile`

1D TP + float8 all-gather trace: see float8 and all-gather in the trace
<img width="1611" alt="Screenshot 2024-07-19 at 1 16 59 PM"
src="https://github.com/user-attachments/assets/9a95dfd9-40e0-4133-b2bb-e22ddf5b8472">

2D + float8 all-gather trace: see float8 and FSDP collectives and TP
collectives
<img width="1038" alt="Screenshot 2024-07-19 at 1 29 59 PM"
src="https://github.com/user-attachments/assets/6a34bcaa-bcae-402b-9994-cc892554fec7">
  • Loading branch information
weifengpy authored Jul 19, 2024
1 parent 71b8eae commit 0c6f9a2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
4 changes: 3 additions & 1 deletion torchtitan/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def maybe_build_fp8_linear(
)
with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
swap_linear_with_float8_linear(
model, scaling_type_w=TensorScalingType.DYNAMIC
model,
scaling_type_w=TensorScalingType.DYNAMIC,
skip_fqn_list=["output"],
)
logger.info(
f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
Expand Down
19 changes: 10 additions & 9 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def selective_checkpointing_context_fn():
return module


def get_tp_parallel_strategy(
def get_tp_parallel_strategy_for_transformer_block(
job_config: JobConfig,
model: nn.Module,
) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
Expand Down Expand Up @@ -346,13 +346,6 @@ def apply_tp(
"""Apply tensor parallelism."""

tp_mesh = world_mesh["tp"]
# Parallel styles used for transformer block linear weights and their
# inputs may be different for float8 linears
(
rowwise_parallel_weight,
colwise_parallel_weight,
prepare_module_input,
) = get_tp_parallel_strategy(job_config, model)
loss_parallel = parallel_dims.loss_parallel_enabled

# 1. Parallelize the embedding and shard its outputs (which are the first
Expand All @@ -368,14 +361,22 @@ def apply_tp(
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": colwise_parallel_weight(
"output": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Shard(-1) if loss_parallel else Replicate(),
use_local_output=not loss_parallel,
),
},
)

# Parallel styles used for transformer block linear weights and their
# inputs may be different for float8 linears
(
rowwise_parallel_weight,
colwise_parallel_weight,
prepare_module_input,
) = get_tp_parallel_strategy_for_transformer_block(job_config, model)

# Apply tensor + sequence parallelism to every transformer block
# NOTE: At the cost of model code change, we can accelerate Sequence Parallel
# by folding (and unfolding) the batch dimension and the sequence dimension.
Expand Down

0 comments on commit 0c6f9a2

Please sign in to comment.