From 6165d3d7e5e4fcd35ae7592580726b3322fdd1ae Mon Sep 17 00:00:00 2001
From: Andrew Gu
Date: Wed, 10 Jul 2024 07:53:25 -0700
Subject: [PATCH 1/2] Reordered TP parallel plan to follow execution order

[ghstack-poisoned]
---
 torchtitan/parallelisms/parallelize_llama.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 7becb731..c07d4c33 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -332,7 +332,6 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     Apply tensor parallelism.
     """
-
     tp_mesh = world_mesh["tp"]
     (
         row_parallel_strategy,
@@ -341,9 +340,10 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     ) = get_tp_parallel_strategy(job_config)
     loss_parallel = parallel_dims.loss_parallel_enabled

-    # 1. Parallelize the first embedding and the last linear proj layer
+    # 1. Parallelize the embedding and shard its outputs (which are the first
+    # transformer block's inputs)
     # 2. Parallelize the root norm layer over the sequence dim
-    # 3. Shard the first transformer block's inputs
+    # 3. Parallelize the final linear output layer
     model = parallelize_module(
         model,
         tp_mesh,
@@ -352,12 +352,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
                 input_layouts=Replicate(),
                 output_layouts=Shard(1),
             ),
+            "norm": SequenceParallel(),
             "output": col_parallel_strategy(
                 input_layouts=Shard(1),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
             ),
-            "norm": SequenceParallel(),
         },
     )

@@ -367,6 +367,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
+            "attention_norm": SequenceParallel(),
             "attention": prepare_module_input(
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
@@ -375,7 +376,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
             "attention.wk": col_parallel_strategy(),
             "attention.wv": col_parallel_strategy(),
             "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
-            "attention_norm": SequenceParallel(),
+            "ffn_norm": SequenceParallel(),
             "feed_forward": prepare_module_input(
                 input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
@@ -383,7 +384,6 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
             "feed_forward.w1": col_parallel_strategy(),
             "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
             "feed_forward.w3": col_parallel_strategy(),
-            "ffn_norm": SequenceParallel(),
         }

         # Adjust attention module to use the local number of heads

From 74304ba949fe260f9160380762c3b8f5d4fc6e0a Mon Sep 17 00:00:00 2001
From: Andrew Gu
Date: Wed, 10 Jul 2024 08:05:38 -0700
Subject: [PATCH 2/2] Update on "Reordered TP parallel plan to follow execution order"

- Llama uses pre-norm (norm before attention and before FFN), so we can move these up.
- The root norm is before output, so we can swap this order too.
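For reference, a minimal sketch of the pre-norm execution order that the reordered plan mirrors. This is not torchtitan code: the norms are plain `LayerNorm` stand-ins for RMSNorm, the attention/FFN internals (`wq`/`wk`/`wv`/`wo`, `w1`/`w2`/`w3`) are collapsed into single `nn.Linear` placeholders, and the embedding attribute name `tok_embeddings` is assumed (that key is outside the hunk context above); the remaining names mirror the plan keys (`attention_norm`, `attention`, `ffn_norm`, `feed_forward`, `norm`, `output`).

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Pre-norm block: each norm runs right before the module it feeds."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.attention_norm = nn.LayerNorm(dim)  # runs first
        self.attention = nn.Linear(dim, dim)     # placeholder for attention
        self.ffn_norm = nn.LayerNorm(dim)        # runs before the FFN
        self.feed_forward = nn.Linear(dim, dim)  # placeholder for the FFN

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x + self.attention(self.attention_norm(x))
        return h + self.feed_forward(self.ffn_norm(h))


class Model(nn.Module):
    """Root order: embedding -> blocks -> root norm -> output projection."""

    def __init__(self, vocab_size: int, dim: int, n_layers: int = 2) -> None:
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)  # outputs feed block 0
        self.layers = nn.ModuleDict({str(i): Block(dim) for i in range(n_layers)})
        self.norm = nn.LayerNorm(dim)             # root norm, right before output
        self.output = nn.Linear(dim, vocab_size)  # final linear projection

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens)
        for _, layer in self.layers.items():
            h = layer(h)
        return self.output(self.norm(h))
```

Reading the plan dicts in `apply_tp` top to bottom now matches this forward order, which is the point of the reorder.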
[ghstack-poisoned]
---
 torchtitan/parallelisms/parallelize_llama.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index c07d4c33..32fbcc63 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -332,6 +332,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     Apply tensor parallelism.
     """
+
     tp_mesh = world_mesh["tp"]
     (
         row_parallel_strategy,