diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index 6c72b046..8d31ddd4 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -7,13 +7,8 @@
 from collections import defaultdict
 
 import torch
-from torch.distributed._tensor import (
-    distribute_module,
-    distribute_tensor,
-    DTensor,
-    Replicate,
-    Shard,
-)
+from torch.distributed._tensor import Replicate, Shard
+
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
     CheckpointImpl,
@@ -30,40 +25,13 @@
     parallelize_module,
     PrepareModuleInput,
     RowwiseParallel,
+    SequenceParallel,
 )
 
 from torchtrain.config_manager import JobConfig
 from torchtrain.logging_utils import logger
 from torchtrain.meta_init import meta_to_real_init_fn
 
-def distribute_rmsnorm(module, device_mesh):
-    # temp sharding API until PTD API is added
-    def prepare_input_fn(mod, inputs, device_mesh):
-        if isinstance(inputs[0], DTensor):
-            return inputs
-        elif isinstance(inputs[0], torch.Tensor):
-            shard_tensor = DTensor.from_local(
-                inputs[0], device_mesh, [Shard(0)], run_check=False
-            )
-            return shard_tensor
-        else:
-            raise NotImplementedError("!!")
-
-    def partition_fn(name, module, device_mesh):
-        for name, param in module.named_parameters():
-            dist_param = torch.nn.Parameter(
-                distribute_tensor(param, device_mesh, [Replicate()])
-            )
-            module.register_parameter(name, dist_param)
-
-    return distribute_module(
-        module,
-        device_mesh,
-        partition_fn,
-        input_fn=prepare_input_fn,
-    )
-
-
 # for selective AC
 no_recompute_list = {
     torch.ops.aten.mm.default,
@@ -145,6 +113,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                     else Replicate(),
                     use_local_output=not parallel_dims.loss_parallel_enabled,
                 ),
+                "norm": SequenceParallel(sequence_dim=0),
                 "layers.0": PrepareModuleInput(
                     input_layouts=(Replicate(), None),
                     desired_input_layouts=(Shard(0), None),
@@ -153,9 +122,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             },
         )
 
-        # shard the RMSNorm layer before last linear proj layer
-        distribute_rmsnorm(model.norm, tp_mesh)
-
         # apply sequence parallelism to every transformer block
         for layer_id, transformer_block in enumerate(model.layers):
             layer_plan = {
@@ -167,6 +133,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 "attention.wk": ColwiseParallel(),
                 "attention.wv": ColwiseParallel(),
                 "attention.wo": RowwiseParallel(output_layouts=Shard(0)),
+                "attention_norm": SequenceParallel(sequence_dim=0),
                 "feed_forward": PrepareModuleInput(
                     input_layouts=(Shard(0),),
                     desired_input_layouts=(Replicate(),),
@@ -174,6 +141,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 "feed_forward.w1": ColwiseParallel(),
                 "feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
                 "feed_forward.w3": ColwiseParallel(),
+                "ffn_norm": SequenceParallel(sequence_dim=0),
             }
 
             # adjust num_heads in attention layer to local heads
@@ -181,10 +149,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             attn_layer.n_heads = attn_layer.n_heads // sp_degree
             attn_layer.n_kv_heads = attn_layer.n_kv_heads // sp_degree
 
-            # shard RMSNorm layers
-            distribute_rmsnorm(transformer_block.attention_norm, tp_mesh)
-            distribute_rmsnorm(transformer_block.ffn_norm, tp_mesh)
-
             parallelize_module(
                 module=transformer_block,
                 device_mesh=tp_mesh,
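For context, below is a minimal standalone sketch of the SequenceParallel style this patch switches to. It is not part of the patch: the ToyRMSNorm and ToyBlock modules, the 2-GPU mesh, and the tensor shapes are illustrative assumptions; only the plan entry "norm": SequenceParallel(sequence_dim=0) mirrors the code above. SequenceParallel replicates the norm's parameters across the tensor-parallel mesh and treats the module's input and output as sharded on the sequence dimension (dim 0 here, matching the sequence-first layout in the patch), which is what the removed distribute_rmsnorm helper implemented by hand with distribute_module and DTensor.from_local.

# Illustrative sketch only -- not torchtrain code. Run with e.g.:
#   torchrun --nproc_per_node=2 sp_norm_sketch.py
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel


class ToyRMSNorm(nn.Module):
    # Stand-in for the model's RMSNorm: ones-initialized weight, normalizes the last dim.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class ToyBlock(nn.Module):
    # Hypothetical container with a "norm" child, so the plan key matches the diff above.
    def __init__(self, dim: int):
        super().__init__()
        self.norm = ToyRMSNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x)


if __name__ == "__main__":
    # init_device_mesh sets up the process group; assumes 2 GPUs under torchrun.
    tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
    block = ToyBlock(dim=4096).cuda()

    # SequenceParallel replicates the norm's weight across tp_mesh and treats its
    # input/output as sharded on dim 0 (the sequence dim in torchtrain's layout),
    # replacing the hand-written distribute_rmsnorm wrapper removed in the patch.
    parallelize_module(
        module=block,
        device_mesh=tp_mesh,
        parallelize_plan={"norm": SequenceParallel(sequence_dim=0)},
    )

    # Each rank feeds its local sequence shard; the output comes back as a DTensor
    # sharded the same way (Shard(0)) unless use_local_output=True is requested.
    local_seq_shard = torch.randn(64, 1, 4096, device="cuda")
    out = block(local_seq_shard)
    print(type(out), out.shape)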