add float8 specific parallel strategies
drisspg committed Mar 20, 2024
1 parent 4fa7bbf commit 85b9e97
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -5,6 +5,7 @@
# llama model, i.e. activation checkpointing, etc.

from collections import defaultdict
+from typing import Tuple

import torch
from torch.distributed._tensor import Replicate, Shard
@@ -116,6 +117,23 @@ def selective_checkpointing_context_fn():
)


+def get_tp_parallel_strategy(
+    job_config: JobConfig,
+) -> Tuple[RowwiseParallel, ColwiseParallel]:
+    """Get the parallel strategy for the transformer model.
+    This function handles the special case of using float8 with tensor parallelism.
+    """
+    if job_config.training.fp8_linear == "dynamic":
+        from float8_experimental.float8_tensor_parallel import (
+            Float8ColwiseParallel,
+            Float8RowwiseParallel,
+        )
+
+        return Float8RowwiseParallel, Float8ColwiseParallel
+    return RowwiseParallel, ColwiseParallel


def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply parallelisms to the model, including PTD parallelisms, and AC.
@@ -132,6 +150,10 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
tp_mesh = world_mesh["sp"]
sp_degree = job_config.training.sequence_parallel_degree

+row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
+    job_config
+)

# First:
# 1. parallelize the first embedding and the last linear proj layer
# 2. shard the first layer of transformer block
@@ -142,7 +164,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"embeddings.tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
),
-"output": ColwiseParallel(
+"output": col_parallel_strategy(
input_layouts=Shard(0),
output_layouts=Shard(-1)
if parallel_dims.loss_parallel_enabled
@@ -165,18 +187,18 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
input_layouts=(Shard(0), None),
desired_input_layouts=(Replicate(), None),
),
-"attention.wq": ColwiseParallel(),
-"attention.wk": ColwiseParallel(),
-"attention.wv": ColwiseParallel(),
-"attention.wo": RowwiseParallel(output_layouts=Shard(0)),
+"attention.wq": col_parallel_strategy(),
+"attention.wk": col_parallel_strategy(),
+"attention.wv": col_parallel_strategy(),
+"attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
"attention_norm": SequenceParallel(sequence_dim=0),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(0),),
desired_input_layouts=(Replicate(),),
),
-"feed_forward.w1": ColwiseParallel(),
-"feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
-"feed_forward.w3": ColwiseParallel(),
+"feed_forward.w1": col_parallel_strategy(),
+"feed_forward.w2": row_parallel_strategy(output_layouts=Shard(0)),
+"feed_forward.w3": col_parallel_strategy(),
"ffn_norm": SequenceParallel(sequence_dim=0),
}
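
For readers without the surrounding repository, the sketch below mirrors the selection logic added in this commit in a self-contained form. It is a minimal sketch, not the project's actual module: the real JobConfig comes from torchtrain's configuration system and is replaced here by a SimpleNamespace stand-in that models only the training.fp8_linear field, and the ColwiseParallel/RowwiseParallel import path shown is the standard torch.distributed.tensor.parallel one, which may differ slightly from the file's own imports; float8_experimental is treated as an optional dependency.

# Minimal, self-contained sketch of the strategy selection introduced above.
# Assumption: a SimpleNamespace stands in for torchtrain's JobConfig.
from types import SimpleNamespace
from typing import Tuple

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel


def get_tp_parallel_strategy(job_config) -> Tuple[type, type]:
    """Return (rowwise, colwise) parallel styles; float8 variants when fp8_linear == "dynamic"."""
    if job_config.training.fp8_linear == "dynamic":
        # Imported lazily, as in the diff above, so float8_experimental stays optional.
        from float8_experimental.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
        )

        return Float8RowwiseParallel, Float8ColwiseParallel
    return RowwiseParallel, ColwiseParallel


if __name__ == "__main__":
    # Stand-in config: with fp8_linear unset, the plain DTensor styles are returned.
    cfg = SimpleNamespace(training=SimpleNamespace(fp8_linear=""))
    row_strategy, col_strategy = get_tp_parallel_strategy(cfg)
    print(row_strategy.__name__, col_strategy.__name__)  # RowwiseParallel ColwiseParallel

Note that the helper returns classes, not instances; parallelize_llama instantiates them per module (for example col_parallel_strategy() for attention.wq), which is why the float8 variants can be swapped in without touching the rest of the tensor-parallel plan.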

