Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add float8 specific parallel strategies #153

Merged
merged 1 commit into from
Mar 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions torchtrain/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# llama model, i.e. activation checkpointing, etc.

from collections import defaultdict
from typing import Tuple

import torch
from torch.distributed._tensor import Replicate, Shard
Expand Down Expand Up @@ -116,6 +117,23 @@ def selective_checkpointing_context_fn():
)


def get_tp_parallel_strategy(
job_config: JobConfig,
) -> Tuple[RowwiseParallel, ColwiseParallel]:
"""Get the parallel strategy for the transformer model.

This function handles the special case of using float8 with tensor parallelism.
"""
if job_config.training.fp8_linear == "dynamic":
from float8_experimental.float8_tensor_parallel import (
Float8ColwiseParallel,
Float8RowwiseParallel,
)

return Float8RowwiseParallel, Float8ColwiseParallel
return RowwiseParallel, ColwiseParallel


def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply parallelisms to the model, including PTD parallelisms, and AC.
Expand All @@ -132,6 +150,10 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
tp_mesh = world_mesh["sp"]
sp_degree = job_config.training.sequence_parallel_degree

row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
job_config
)

# First:
# 1. parallelize the first embedding and the last linear proj layer
# 2. shard the first layer of transformer block
Expand All @@ -142,7 +164,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"embeddings.tok_embeddings": RowwiseParallel(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Notable leave this untouched due to:

rank0]:[rank0]:     raise ValueError(
[rank0]:[rank0]: ValueError: Expecting module to be Float8DynamicLinear but found <class 'torch.nn.modules.sparse.Embedding'>

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, that make sense because ColwiseParallel and RowwiseParallel in TP not only works for nn.Linear but also nn.Embedding

input_layouts=Replicate(),
),
"output": ColwiseParallel(
"output": col_parallel_strategy(
input_layouts=Shard(0),
output_layouts=Shard(-1)
if parallel_dims.loss_parallel_enabled
Expand All @@ -165,18 +187,18 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
input_layouts=(Shard(0), None),
desired_input_layouts=(Replicate(), None),
),
"attention.wq": ColwiseParallel(),
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(output_layouts=Shard(0)),
"attention.wq": col_parallel_strategy(),
"attention.wk": col_parallel_strategy(),
"attention.wv": col_parallel_strategy(),
"attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
"attention_norm": SequenceParallel(sequence_dim=0),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(0),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
"feed_forward.w3": ColwiseParallel(),
"feed_forward.w1": col_parallel_strategy(),
"feed_forward.w2": row_parallel_strategy(output_layouts=Shard(0)),
"feed_forward.w3": col_parallel_strategy(),
"ffn_norm": SequenceParallel(sequence_dim=0),
}

Expand Down
Loading