diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index 38014e53b..9067ae85a 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -5,6 +5,7 @@
 # llama model, i.e. activation checkpointing, etc.
 
 from collections import defaultdict
+from typing import Tuple
 
 import torch
 from torch.distributed._tensor import Replicate, Shard
@@ -116,6 +117,23 @@ def selective_checkpointing_context_fn():
         )
 
 
+def get_tp_parallel_strategy(
+    job_config: JobConfig,
+) -> Tuple[RowwiseParallel, ColwiseParallel]:
+    """Get the parallel strategy for the transformer model.
+
+    This function handles the special case of using float8 with tensor parallelism.
+    """
+    if job_config.training.fp8_linear == "dynamic":
+        from float8_experimental.float8_tensor_parallel import (
+            Float8ColwiseParallel,
+            Float8RowwiseParallel,
+        )
+
+        return Float8RowwiseParallel, Float8ColwiseParallel
+    return RowwiseParallel, ColwiseParallel
+
+
 def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     Apply parallelisms to the model, including PTD parallelisms, and AC.
@@ -132,6 +150,10 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         tp_mesh = world_mesh["sp"]
         sp_degree = job_config.training.sequence_parallel_degree
 
+        row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
+            job_config
+        )
+
         # First:
         # 1. parallelize the first embedding and the last linear proj layer
         # 2. shard the first layer of transformer block
@@ -142,7 +164,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 "embeddings.tok_embeddings": RowwiseParallel(
                     input_layouts=Replicate(),
                 ),
-                "output": ColwiseParallel(
+                "output": col_parallel_strategy(
                     input_layouts=Shard(0),
                     output_layouts=Shard(-1)
                     if parallel_dims.loss_parallel_enabled
@@ -165,18 +187,18 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                     input_layouts=(Shard(0), None),
                     desired_input_layouts=(Replicate(), None),
                 ),
-                "attention.wq": ColwiseParallel(),
-                "attention.wk": ColwiseParallel(),
-                "attention.wv": ColwiseParallel(),
-                "attention.wo": RowwiseParallel(output_layouts=Shard(0)),
+                "attention.wq": col_parallel_strategy(),
+                "attention.wk": col_parallel_strategy(),
+                "attention.wv": col_parallel_strategy(),
+                "attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
                 "attention_norm": SequenceParallel(sequence_dim=0),
                 "feed_forward": PrepareModuleInput(
                     input_layouts=(Shard(0),),
                     desired_input_layouts=(Replicate(),),
                 ),
-                "feed_forward.w1": ColwiseParallel(),
-                "feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
-                "feed_forward.w3": ColwiseParallel(),
+                "feed_forward.w1": col_parallel_strategy(),
+                "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(0)),
+                "feed_forward.w3": col_parallel_strategy(),
                 "ffn_norm": SequenceParallel(sequence_dim=0),
             }
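
A minimal, self-contained sketch (not part of the patch) of how the selected strategy classes plug into the TP plan. `pick_tp_strategy` and the `SimpleNamespace` config are stand-ins for `get_tp_parallel_strategy` and the real `JobConfig`; the float8 branch assumes `float8_experimental` is installed.

```python
from types import SimpleNamespace

from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel


def pick_tp_strategy(fp8_linear: str):
    # Mirrors get_tp_parallel_strategy: swap in float8-aware styles when
    # dynamic-scaling fp8 linears are enabled, otherwise use the stock ones.
    if fp8_linear == "dynamic":
        from float8_experimental.float8_tensor_parallel import (  # optional dependency
            Float8ColwiseParallel,
            Float8RowwiseParallel,
        )

        return Float8RowwiseParallel, Float8ColwiseParallel
    return RowwiseParallel, ColwiseParallel


# Hypothetical stand-in for JobConfig; the fp8 path is left disabled here.
job_config = SimpleNamespace(training=SimpleNamespace(fp8_linear=""))
row_parallel_strategy, col_parallel_strategy = pick_tp_strategy(
    job_config.training.fp8_linear
)

# The returned classes are instantiated exactly like ColwiseParallel/RowwiseParallel
# when building the per-layer TP plan, e.g. a fragment of layer_plan:
layer_plan_fragment = {
    "attention.wq": col_parallel_strategy(),
    "attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
    "feed_forward.w1": col_parallel_strategy(),
    "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(0)),
}
print({name: type(style).__name__ for name, style in layer_plan_fragment.items()})
```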