Renamed parallel styles for transformer block weights #448

Merged: 7 commits, Jul 11, 2024
torchtitan/parallelisms/parallelize_llama.py: 15 additions & 13 deletions
@@ -298,10 +298,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """

     tp_mesh = world_mesh["tp"]
+    # Parallel styles for transformer block linear weights may be different for
+    # float8 linears
     (
-        row_parallel_strategy,
-        col_parallel_strategy,
-        prepare_module_input,
+        rowwise_parallel_weight,
+        colwise_parallel_weight,
+        prepare_weight_input,
     ) = get_tp_parallel_strategy(job_config)
     loss_parallel = parallel_dims.loss_parallel_enabled

Comment from the PR author on the added comment lines:

These are only used for transformer block weights, so we can rename this to differentiate from the ColwiseParallel etc. for the other parameters.
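For context, the tuple above is unpacked from get_tp_parallel_strategy, which selects the parallel styles applied to the transformer block weights. Below is a minimal sketch of what such a helper could look like, switching to float8-aware styles when float8 linears are enabled; the config flag, the float8 import path, and the float8 class names are assumptions for illustration, not code from this PR.

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
)


def get_tp_parallel_strategy(job_config):
    """Return (rowwise, colwise, prepare-input) styles for transformer block weights.

    Sketch only: the fp8 flag and the float8 style classes below are assumptions.
    """
    if getattr(job_config.training, "fp8_linear", False):  # assumed flag name
        # Assumed float8 tensor-parallel styles; the real import path and
        # class names depend on the float8 library in use.
        from float8_experimental.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
    return RowwiseParallel, ColwiseParallel, PrepareModuleInput

Either way, the PR is a pure rename of the local variables: the styles returned by the helper are unchanged, so the parallelization behavior stays the same.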
@@ -318,7 +320,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
                 output_layouts=Shard(1),
             ),
             "norm": SequenceParallel(),
-            "output": col_parallel_strategy(
+            "output": colwise_parallel_weight(
                 input_layouts=Shard(1),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,

@@ -333,22 +335,22 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
             "attention_norm": SequenceParallel(),
-            "attention": prepare_module_input(
+            "attention": prepare_weight_input(
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
-            "attention.wq": col_parallel_strategy(),
-            "attention.wk": col_parallel_strategy(),
-            "attention.wv": col_parallel_strategy(),
-            "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
+            "attention.wq": colwise_parallel_weight(),
+            "attention.wk": colwise_parallel_weight(),
+            "attention.wv": colwise_parallel_weight(),
+            "attention.wo": rowwise_parallel_weight(output_layouts=Shard(1)),
             "ffn_norm": SequenceParallel(),
-            "feed_forward": prepare_module_input(
+            "feed_forward": prepare_weight_input(
                 input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
             ),
-            "feed_forward.w1": col_parallel_strategy(),
-            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
-            "feed_forward.w3": col_parallel_strategy(),
+            "feed_forward.w1": colwise_parallel_weight(),
+            "feed_forward.w2": rowwise_parallel_weight(output_layouts=Shard(1)),
+            "feed_forward.w3": colwise_parallel_weight(),
         }

         # Adjust attention module to use the local number of heads
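Not shown in the collapsed portion of the diff is how each layer_plan is consumed: a plan dict like the one above is typically handed to parallelize_module for the corresponding block. The sketch below assumes the stock torch.distributed.tensor.parallel API and a model whose blocks live in model.layers, and it drops the SequenceParallel and prepare-input pieces of the real plan to stay self-contained.

import torch
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


def apply_block_tp(transformer_block: torch.nn.Module, tp_mesh) -> torch.nn.Module:
    # Stripped-down per-block plan: column-parallel projections feeding
    # row-parallel output projections, mirroring the structure of the plan
    # in the diff (norm / SequenceParallel handling omitted).
    layer_plan = {
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(),
        "feed_forward.w3": ColwiseParallel(),
    }
    return parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_plan,
    )


# Example usage under torchrun; the mesh size and model structure are assumptions:
# from torch.distributed.device_mesh import init_device_mesh
# tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))
# for _, block in model.layers.items():
#     apply_block_tp(block, tp_mesh)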