use SequenceParallel style in tp/sp #133

Merged · 2 commits · Mar 13, 2024
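This PR removes the hand-rolled distribute_rmsnorm helper and instead tags the RMSNorm layers with the stock SequenceParallel style from torch.distributed.tensor.parallel, which replicates the norm parameters and keeps the norm's input and output sharded on the sequence dimension. The following is a minimal sketch of that style on a toy block, not code from this PR; ToyBlock, its sizes, and the launch setup are hypothetical, and sequence_dim=0 is used only to match the plan in the diff below.

# Minimal sketch of the SequenceParallel style adopted in this PR (hypothetical
# toy model, not torchtrain code). Launch with: torchrun --nproc-per-node=<N>
import os

import torch
import torch.nn as nn
from torch.distributed._tensor import Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
    SequenceParallel,
)


class ToyBlock(nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # stand-in for RMSNorm
        self.w1 = nn.Linear(dim, 4 * dim, bias=False)
        self.w2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(self.norm(x))))


torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tp_mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
block = parallelize_module(
    ToyBlock().cuda(),
    tp_mesh,
    {
        # norm params stay replicated; its input/output stay sharded on
        # sequence dim 0, matching SequenceParallel(sequence_dim=0) in the diff
        "norm": SequenceParallel(sequence_dim=0),
        # redistribute the Shard(0) norm output to Replicate before the
        # colwise matmul (the PR does this with PrepareModuleInput instead)
        "w1": ColwiseParallel(input_layouts=Shard(0)),
        "w2": RowwiseParallel(output_layouts=Shard(0)),
    },
)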
48 changes: 6 additions & 42 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -7,13 +7,8 @@
 from collections import defaultdict
 
 import torch
-from torch.distributed._tensor import (
-    distribute_module,
-    distribute_tensor,
-    DTensor,
-    Replicate,
-    Shard,
-)
+from torch.distributed._tensor import Replicate, Shard
+
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
     CheckpointImpl,
@@ -30,40 +25,13 @@
     parallelize_module,
     PrepareModuleInput,
     RowwiseParallel,
+    SequenceParallel,
 )
 from torchtrain.config_manager import JobConfig
 from torchtrain.logging_utils import logger
 from torchtrain.meta_init import meta_to_real_init_fn
 
 
-def distribute_rmsnorm(module, device_mesh):
-    # temp sharding API until PTD API is added
-    def prepare_input_fn(mod, inputs, device_mesh):
-        if isinstance(inputs[0], DTensor):
-            return inputs
-        elif isinstance(inputs[0], torch.Tensor):
-            shard_tensor = DTensor.from_local(
-                inputs[0], device_mesh, [Shard(0)], run_check=False
-            )
-            return shard_tensor
-        else:
-            raise NotImplementedError("!!")
-
-    def partition_fn(name, module, device_mesh):
-        for name, param in module.named_parameters():
-            dist_param = torch.nn.Parameter(
-                distribute_tensor(param, device_mesh, [Replicate()])
-            )
-            module.register_parameter(name, dist_param)
-
-    return distribute_module(
-        module,
-        device_mesh,
-        partition_fn,
-        input_fn=prepare_input_fn,
-    )
-
-
 # for selective AC
 no_recompute_list = {
     torch.ops.aten.mm.default,
@@ -145,6 +113,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 else Replicate(),
                 use_local_output=not parallel_dims.loss_parallel_enabled,
             ),
+            "norm": SequenceParallel(sequence_dim=0),
             "layers.0": PrepareModuleInput(
                 input_layouts=(Replicate(), None),
                 desired_input_layouts=(Shard(0), None),
@@ -153,9 +122,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         },
     )
 
-    # shard the RMSNorm layer before last linear proj layer
-    distribute_rmsnorm(model.norm, tp_mesh)
-
     # apply sequence parallelism to every transformer block
     for layer_id, transformer_block in enumerate(model.layers):
         layer_plan = {
@@ -167,24 +133,22 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(output_layouts=Shard(0)),
"attention_norm": SequenceParallel(sequence_dim=0),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(0),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
"feed_forward.w3": ColwiseParallel(),
"ffn_norm": SequenceParallel(sequence_dim=0),
}

# adjust num_heads in attention layer to local heads
attn_layer = transformer_block.attention
attn_layer.n_heads = attn_layer.n_heads // sp_degree
attn_layer.n_kv_heads = attn_layer.n_kv_heads // sp_degree

# shard RMSNorm layers
distribute_rmsnorm(transformer_block.attention_norm, tp_mesh)
distribute_rmsnorm(transformer_block.ffn_norm, tp_mesh)

parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
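A hedged continuation of the sketch above (same hypothetical ToyBlock and tp_mesh), checking that the SequenceParallel style leaves the norm weights as replicated DTensors, which is what the deleted distribute_rmsnorm helper previously did by hand with distribute_tensor:

# Continuation of the sketch above (same hypothetical ToyBlock and tp_mesh).
import torch
from torch.distributed._tensor import DTensor, Replicate

# SequenceParallel replicates the norm parameters, mirroring what the removed
# distribute_rmsnorm partition_fn did with distribute_tensor(..., [Replicate()]).
assert isinstance(block.norm.weight, DTensor)
assert block.norm.weight.placements == (Replicate(),)

# Each rank feeds its local chunk of an input sharded on the sequence dim (0).
seq_per_rank, dim = 8, 256
x = torch.randn(seq_per_rank, dim, device="cuda")
out = block(x)
# RowwiseParallel(output_layouts=Shard(0)) returns a local tensor by default,
# so each rank gets back its own sequence shard.
assert out.shape == (seq_per_rank, dim)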