use SequenceParallel style in tp/sp (pytorch#133)
simplify things given we already have SequenceParallel style landed in
main
wanchaol authored Mar 13, 2024
1 parent 7cd2725 commit 3161ffb
Showing 1 changed file with 6 additions and 42 deletions.
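
For context before the diff (not part of the commit): a minimal, hypothetical sketch of the SequenceParallel style this change adopts in place of the hand-rolled distribute_rmsnorm helper it removes. The ToyBlock module, the file name sp_sketch.py, the 2-way CPU mesh, and the LayerNorm stand-in for the repo's RMSNorm are illustrative assumptions; the real plan in parallelize_llama.py targets the model's own norm modules on a tensor-parallel GPU mesh.

# Minimal sketch of the SequenceParallel style (illustrative; not from this repo).
# Assumed usage: torchrun --nproc_per_node=2 sp_sketch.py
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import SequenceParallel, parallelize_module


class ToyBlock(nn.Module):
    """Stand-in for a transformer block; LayerNorm substitutes for RMSNorm."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.attention_norm = nn.LayerNorm(dim)
        self.ffn_norm = nn.LayerNorm(dim)

    def forward(self, x):
        # A real block interleaves attention and feed-forward between the norms;
        # only the norms matter for this sketch.
        return self.ffn_norm(self.attention_norm(x))


def main():
    # init_device_mesh sets up the default process group if it isn't already.
    tp_mesh = init_device_mesh("cpu", (2,))
    block = ToyBlock()

    # SequenceParallel replicates each norm's parameters and shards its
    # input/output on the sequence dimension (dim 0, matching the commit's
    # sequence-first layout), so no manual distribute_module wiring is needed.
    parallelize_module(
        module=block,
        device_mesh=tp_mesh,
        parallelize_plan={
            "attention_norm": SequenceParallel(sequence_dim=0),
            "ffn_norm": SequenceParallel(sequence_dim=0),
        },
    )

    # Each rank feeds its local shard of the sequence: seq_len 8 split across
    # 2 ranks -> 4 per rank, laid out as (seq, batch, dim).
    out = block(torch.randn(4, 2, 64))
    print(out.shape)  # DTensor global shape: (8, 2, 64)


if __name__ == "__main__":
    main()

Run under torchrun with two processes; SequenceParallel replicates the norm's weights and shards activations along the sequence dimension, which is what the removed distribute_rmsnorm did by hand via distribute_module and DTensor.from_local.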
48 changes: 6 additions & 42 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -7,13 +7,8 @@
 from collections import defaultdict

 import torch
-from torch.distributed._tensor import (
-    distribute_module,
-    distribute_tensor,
-    DTensor,
-    Replicate,
-    Shard,
-)
+from torch.distributed._tensor import Replicate, Shard
+
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
     CheckpointImpl,
@@ -30,40 +25,13 @@
     parallelize_module,
     PrepareModuleInput,
     RowwiseParallel,
+    SequenceParallel,
 )
 from torchtrain.config_manager import JobConfig
 from torchtrain.logging_utils import logger
 from torchtrain.meta_init import meta_to_real_init_fn


-def distribute_rmsnorm(module, device_mesh):
-    # temp sharding API until PTD API is added
-    def prepare_input_fn(mod, inputs, device_mesh):
-        if isinstance(inputs[0], DTensor):
-            return inputs
-        elif isinstance(inputs[0], torch.Tensor):
-            shard_tensor = DTensor.from_local(
-                inputs[0], device_mesh, [Shard(0)], run_check=False
-            )
-            return shard_tensor
-        else:
-            raise NotImplementedError("!!")
-
-    def partition_fn(name, module, device_mesh):
-        for name, param in module.named_parameters():
-            dist_param = torch.nn.Parameter(
-                distribute_tensor(param, device_mesh, [Replicate()])
-            )
-            module.register_parameter(name, dist_param)
-
-    return distribute_module(
-        module,
-        device_mesh,
-        partition_fn,
-        input_fn=prepare_input_fn,
-    )
-
-
 # for selective AC
 no_recompute_list = {
     torch.ops.aten.mm.default,
@@ -145,6 +113,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                     else Replicate(),
                     use_local_output=not parallel_dims.loss_parallel_enabled,
                 ),
+                "norm": SequenceParallel(sequence_dim=0),
                 "layers.0": PrepareModuleInput(
                     input_layouts=(Replicate(), None),
                     desired_input_layouts=(Shard(0), None),
@@ -153,9 +122,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             },
         )

-        # shard the RMSNorm layer before last linear proj layer
-        distribute_rmsnorm(model.norm, tp_mesh)
-
         # apply sequence parallelism to every transformer block
         for layer_id, transformer_block in enumerate(model.layers):
             layer_plan = {
@@ -167,24 +133,22 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 "attention.wk": ColwiseParallel(),
                 "attention.wv": ColwiseParallel(),
                 "attention.wo": RowwiseParallel(output_layouts=Shard(0)),
+                "attention_norm": SequenceParallel(sequence_dim=0),
                 "feed_forward": PrepareModuleInput(
                     input_layouts=(Shard(0),),
                     desired_input_layouts=(Replicate(),),
                 ),
                 "feed_forward.w1": ColwiseParallel(),
                 "feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
                 "feed_forward.w3": ColwiseParallel(),
+                "ffn_norm": SequenceParallel(sequence_dim=0),
             }

             # adjust num_heads in attention layer to local heads
             attn_layer = transformer_block.attention
             attn_layer.n_heads = attn_layer.n_heads // sp_degree
             attn_layer.n_kv_heads = attn_layer.n_kv_heads // sp_degree

-            # shard RMSNorm layers
-            distribute_rmsnorm(transformer_block.attention_norm, tp_mesh)
-            distribute_rmsnorm(transformer_block.ffn_norm, tp_mesh)
-
             parallelize_module(
                 module=transformer_block,
                 device_mesh=tp_mesh,
