@@ -117,7 +117,7 @@ def selective_checkpointing_context_fn():
         return module


-def get_tp_parallel_strategy(
+def get_tp_parallel_strategy_for_transformer_block(
     job_config: JobConfig,
     model: nn.Module,
 ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
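The renamed helper's body is not part of this diff. For orientation, here is a minimal sketch of a selection helper with this signature; the `enable_float8_linear` flag and the float8 branch are illustrative assumptions, not code from this commit:

```python
from typing import Tuple

import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
)


def get_tp_parallel_strategy_for_transformer_block(
    job_config,  # torchtitan JobConfig; left untyped in this sketch
    model: nn.Module,
) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
    """Return the parallel style classes used for transformer block linears.

    The classes themselves (not instances) are returned so the caller can
    instantiate them per layer. A float8-enabled config would return
    float8-aware subclasses of these styles instead; that branch is elided here.
    """
    if getattr(getattr(job_config, "training", None), "enable_float8_linear", False):
        raise NotImplementedError("float8 parallel styles are omitted from this sketch")
    return RowwiseParallel, ColwiseParallel, PrepareModuleInput
```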
@@ -346,13 +346,6 @@ def apply_tp(
     """Apply tensor parallelism."""

     tp_mesh = world_mesh["tp"]
-    # Parallel styles used for transformer block linear weights and their
-    # inputs may be different for float8 linears
-    (
-        rowwise_parallel_weight,
-        colwise_parallel_weight,
-        prepare_module_input,
-    ) = get_tp_parallel_strategy(job_config, model)
     loss_parallel = parallel_dims.loss_parallel_enabled

     # 1. Parallelize the embedding and shard its outputs (which are the first
@@ -368,14 +361,22 @@ def apply_tp(
                 output_layouts=Shard(1),
             ),
             "norm": SequenceParallel(),
-            "output": colwise_parallel_weight(
+            "output": ColwiseParallel(
                 input_layouts=Shard(1),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
             ),
         },
     )

+    # Parallel styles used for transformer block linear weights and their
+    # inputs may be different for float8 linears
+    (
+        rowwise_parallel_weight,
+        colwise_parallel_weight,
+        prepare_module_input,
+    ) = get_tp_parallel_strategy_for_transformer_block(job_config, model)
+
     # Apply tensor + sequence parallelism to every transformer block
     # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
     # by folding (and unfolding) the batch dimension and the sequence dimension.
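Further down in `apply_tp` (outside this diff), the three returned style classes are instantiated inside each transformer block's layer plan, which is why the assignment was moved below the root-level `parallelize_module` call that now uses plain `ColwiseParallel` for the output projection. A rough sketch of that per-block usage, assuming Llama-style submodule names (`attention.wq/wk/wv/wo`, `feed_forward.w1/w2/w3`) and a `ModuleDict` of layers; this is an illustration, not the file's exact code:

```python
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import SequenceParallel, parallelize_module


def apply_tp_to_transformer_blocks(
    model, tp_mesh, rowwise_parallel_weight, colwise_parallel_weight, prepare_module_input
):
    # Sketch: instantiate the selected parallel styles for each block's linears.
    for transformer_block in model.layers.values():
        layer_plan = {
            "attention_norm": SequenceParallel(),
            # Assumes attention takes a second, non-sharded input (e.g. rotary
            # frequencies); only the hidden states need to be redistributed.
            "attention": prepare_module_input(
                input_layouts=(Shard(1), None),
                desired_input_layouts=(Replicate(), None),
            ),
            "attention.wq": colwise_parallel_weight(),
            "attention.wk": colwise_parallel_weight(),
            "attention.wv": colwise_parallel_weight(),
            "attention.wo": rowwise_parallel_weight(output_layouts=Shard(1)),
            "ffn_norm": SequenceParallel(),
            "feed_forward": prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "feed_forward.w1": colwise_parallel_weight(),
            "feed_forward.w2": rowwise_parallel_weight(output_layouts=Shard(1)),
            "feed_forward.w3": colwise_parallel_weight(),
        }
        parallelize_module(transformer_block, tp_mesh, layer_plan)
```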