diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py
index 50c971ae..770531d5 100644
--- a/torchtitan/float8_linear.py
+++ b/torchtitan/float8_linear.py
@@ -74,7 +74,9 @@ def maybe_build_fp8_linear(
         )
         with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
             swap_linear_with_float8_linear(
-                model, scaling_type_w=TensorScalingType.DYNAMIC
+                model,
+                scaling_type_w=TensorScalingType.DYNAMIC,
+                skip_fqn_list=["output"],
             )
         logger.info(
             f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 33b9d6d3..634c70a0 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -117,7 +117,7 @@ def selective_checkpointing_context_fn():
     return module
 
 
-def get_tp_parallel_strategy(
+def get_tp_parallel_strategy_for_transformer_block(
     job_config: JobConfig,
     model: nn.Module,
 ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
@@ -346,13 +346,6 @@ def apply_tp(
     """Apply tensor parallelism."""
     tp_mesh = world_mesh["tp"]
 
-    # Parallel styles used for transformer block linear weights and their
-    # inputs may be different for float8 linears
-    (
-        rowwise_parallel_weight,
-        colwise_parallel_weight,
-        prepare_module_input,
-    ) = get_tp_parallel_strategy(job_config, model)
     loss_parallel = parallel_dims.loss_parallel_enabled
 
     # 1. Parallelize the embedding and shard its outputs (which are the first
@@ -368,7 +361,7 @@ def apply_tp(
             output_layouts=Shard(1),
         ),
         "norm": SequenceParallel(),
-        "output": colwise_parallel_weight(
+        "output": ColwiseParallel(
             input_layouts=Shard(1),
             output_layouts=Shard(-1) if loss_parallel else Replicate(),
             use_local_output=not loss_parallel,
@@ -376,6 +369,14 @@ def apply_tp(
         },
     )
 
+    # Parallel styles used for transformer block linear weights and their
+    # inputs may be different for float8 linears
+    (
+        rowwise_parallel_weight,
+        colwise_parallel_weight,
+        prepare_module_input,
+    ) = get_tp_parallel_strategy_for_transformer_block(job_config, model)
+
     # Apply tensor + sequence parallelism to every transformer block
     # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
     # by folding (and unfolding) the batch dimension and the sequence dimension.
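
The net effect of the patch: the root "output" projection is excluded from the Float8Linear swap via skip_fqn_list, so the TP plan applies the vanilla ColwiseParallel to it, and the float8-aware parallel styles are fetched only for the transformer blocks (hence the rename to get_tp_parallel_strategy_for_transformer_block). Below is a minimal, hypothetical sketch of what a skip_fqn_list-style filter does during module swapping; swap_linears_except and WrappedLinear are made up for illustration and are not float8_experimental's swap_linear_with_float8_linear / Float8Linear.

# Hypothetical sketch (plain PyTorch): modules whose fully-qualified name is in
# skip_fqn_list are left untouched, so "output" stays a plain nn.Linear.
import torch.nn as nn


def swap_linears_except(model: nn.Module, wrapper_cls, skip_fqn_list=None):
    """Wrap every nn.Linear in `model` with `wrapper_cls`, skipping FQNs in `skip_fqn_list`."""
    skip = set(skip_fqn_list or [])
    for fqn, module in list(model.named_modules()):
        if fqn in skip or not isinstance(module, nn.Linear):
            continue
        parent_fqn, _, child_name = fqn.rpartition(".")
        parent = model.get_submodule(parent_fqn) if parent_fqn else model
        setattr(parent, child_name, wrapper_cls(module))


class WrappedLinear(nn.Module):
    """Stand-in for a swapped-in linear (e.g. Float8Linear); just wraps the original."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.inner = linear

    def forward(self, x):
        return self.inner(x)


model = nn.ModuleDict(
    {
        "layers": nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)),
        "output": nn.Linear(8, 4),
    }
)
swap_linears_except(model, WrappedLinear, skip_fqn_list=["output"])
# layers.0 and layers.1 are wrapped; "output" is skipped, mirroring the intent of
# passing skip_fqn_list=["output"] in the patch above.
print(type(model["layers"][0]).__name__, type(model["output"]).__name__)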