Skip to content

Commit

Permalink
Add support of DDP and CompiledAutograd.
Browse files Browse the repository at this point in the history
ghstack-source-id: f4b9c10f8dc61f5640176f25213bbcd0fbe6ce97
Pull Request resolved: #319
  • Loading branch information
fegin committed May 9, 2024
1 parent f5a3ad7 commit aa08c80
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 36 deletions.
27 changes: 25 additions & 2 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,20 @@ def __init__(self):
"--training.data_parallel_degree",
type=int,
default=-1,
help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
help="Data Parallelism degree (FSDP). -1 means leftover ranks will be used (After SP/PP/replicate). 1 means disabled.",
)
self.parser.add_argument(
"--training.data_parallel_replicate_degree",
type=int,
default=1,
help="""
Data Parallelism with parameters being replicated degree. 1 means disabled.
If data_parallel_degree is > 1 and data_parallel_replicate_degree > 1,
the parallelism is HSDP. HSDP is not yet enabled but will be supported soon.
When data_parallel_degree is -1 and data_parallel_replicate_degree > 1,
the parallelism is DDP. DDP should only be used for small model as
DDP + TP is not yet supported.
"""
)
self.parser.add_argument(
"--training.tensor_parallel_degree",
Expand All @@ -210,7 +223,17 @@ def __init__(self):
self.parser.add_argument(
"--training.compile",
action="store_true",
help="Whether to compile the model",
help="Whether to compile the model.",
)
self.parser.add_argument(
"--training.compiled_autograd",
action="store_true",
help=
"""
Whether to use CompiledAutograd to trace the backward.
This is an experimental feature and should not be used
unless you are familiar with CompiledAutograd.
"""
)
self.parser.add_argument(
"--training.fp8_linear",
Expand Down
21 changes: 16 additions & 5 deletions torchtitan/parallelisms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
@dataclass
class ParallelDims:
dp: int
dp_replicate: int
tp: int
pp: int
world_size: int
Expand All @@ -29,21 +30,27 @@ def __post_init__(self):
self._validate()

def _validate(self):
dp, tp, pp = self.dp, self.tp, self.pp
dp, dp_replicate, tp, pp = self.dp, self.dp_replicate, self.tp, self.pp
if dp == -1:
self.dp = dp = self.world_size // (tp * pp)
self.dp = dp = self.world_size // (dp_replicate * tp * pp)
assert dp >= 1, dp
assert dp_replicate >= 1, dp_replicate
assert tp >= 1, tp
assert pp >= 1, pp
assert (
dp * tp * pp == self.world_size
), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
dp * dp_replicate * tp * pp == self.world_size
), (
f"Invalid parallel dims: dp({dp}) * dp_replicate({dp_replicate}) * "
f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})."
)

def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
[self.pp, self.dp_replicate, self.dp, self.tp],
["pp", "dp_replicate", "dp", "tp"],
strict=True
):
if d > 1:
dims.append(d)
Expand All @@ -56,6 +63,10 @@ def build_mesh(self, device_type):
def dp_enabled(self):
return self.dp > 1

@property
def dp_replicate_enabled(self):
return self.dp_replicate > 1

@property
def tp_enabled(self):
return self.tp > 1
Expand Down
86 changes: 61 additions & 25 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from typing import Tuple

import torch
import torch.nn as nn

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
Expand Down Expand Up @@ -129,7 +131,56 @@ def get_tp_parallel_strategy(
return RowwiseParallel, ColwiseParallel


def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
def maybe_enable_activation_checkpoint(
    model: nn.Module, job_config: JobConfig
) -> nn.Module:
    """Wrap each entry of ``model.layers`` with activation checkpointing.

    The wrapping only happens when ``job_config.activation_checkpoint.mode``
    is ``"full"`` or ``"selective"``; for any other mode the model is
    returned untouched.

    Args:
        model: Module exposing an indexable ``layers`` container.
        job_config: Job configuration; only its ``activation_checkpoint``
            section is read here.

    Returns:
        The same ``model`` object, with its layers replaced in place when
        checkpointing is enabled.
    """
    ac_settings = job_config.activation_checkpoint
    mode = ac_settings.mode

    # Guard clause: nothing to do unless a checkpointing mode is selected.
    if mode not in ("full", "selective"):
        return model

    for idx, block in enumerate(model.layers):
        model.layers[idx] = checkpoint_wrapper(block, ac_settings)
    logger.info(f"Applied {mode} activation checkpointing to the model")

    return model


def enable_fsdp(model: nn.Module, dp_mesh, job_config: JobConfig) -> nn.Module:
    """Apply ``fully_shard`` to every entry of ``model.layers`` and to the model.

    Args:
        model: Module exposing an indexable ``layers`` container.
        dp_mesh: Device mesh handed to ``fully_shard`` as ``mesh``.
        job_config: Job configuration (currently not read by this function).

    Returns:
        The result of calling ``fully_shard`` on the whole model.
    """
    # TODO: Expose `reduce_dtype` as a config option.
    shard_kwargs = {
        "mesh": dp_mesh,
        "mp_policy": MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
        ),
    }

    last_idx = len(model.layers) - 1
    for idx, block in enumerate(model.layers):
        # As an optimization, do not reshard after forward for the last
        # transformer block since FSDP would prefetch it immediately.
        fully_shard(
            block,
            **shard_kwargs,
            reshard_after_forward=idx < last_idx,
        )
        model.layers[idx] = block

    # Finally shard the root module itself.
    model = fully_shard(model, **shard_kwargs)
    logger.info("Applied FSDP to the model")

    return model


def enable_ddp(model: nn.Module, dp_mesh, job_config: JobConfig) -> nn.Module:
    """Apply ``replicate`` (DDP) to the model over ``dp_mesh``.

    When ``job_config.training.compile`` is set, this first sets
    ``torch._dynamo.config.optimize_ddp`` to ``"python_reducer"`` if
    ``compiled_autograd`` is also set, otherwise to ``"ddp_optimizer"``.
    Without ``compile`` the Dynamo setting is left as is.

    Args:
        model: Module to replicate.
        dp_mesh: Device mesh handed to ``replicate`` as ``device_mesh``.
        job_config: Job configuration; only its ``training`` section is read.

    Returns:
        The result of calling ``replicate`` on the model.
    """
    training_cfg = job_config.training
    if training_cfg.compile:
        # NOTE: this mutates a process-wide Dynamo config value.
        torch._dynamo.config.optimize_ddp = (
            "python_reducer" if training_cfg.compiled_autograd else "ddp_optimizer"
        )

    model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
    logger.info("Applied DDP to the model")

    return model


def parallelize_llama(
model: nn.Module, world_mesh, parallel_dims, job_config: JobConfig
) -> nn.Module:
"""
Apply parallelisms and activation checkpointing to the model.
Expand All @@ -144,6 +195,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
raise NotImplementedError(
"fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm."
)
if parallel_dims.dp_replicate_enabled:
raise NotImplementedError("DDP/HSDP + TP are not supported yet.")

tp_mesh = world_mesh["tp"]
row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
Expand Down Expand Up @@ -206,32 +259,15 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

logger.info("Applied Tensor Parallelism to the model")

model = maybe_enable_activation_checkpoint(model, job_config)
if parallel_dims.dp_enabled:
if parallel_dims.dp_replicate_enabled:
raise NotImplementedError("HSDP is not supported yet.")
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
# TODO: Expose `reduce_dtype` as a config option.
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32
)
ac_mode = job_config.activation_checkpoint.mode
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id, transformer_block in enumerate(model.layers):
if job_config.activation_checkpoint.mode in ("full", "selective"):
transformer_block = checkpoint_wrapper(
transformer_block, job_config.activation_checkpoint
)
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = layer_id < len(model.layers) - 1
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.layers[layer_id] = transformer_block
model = fully_shard(model, **fsdp_config)
if ac_mode in ("full", "selective"):
logger.info(f"Applied {ac_mode} activation checkpointing to the model")
logger.info("Applied FSDP to the model")
model = enable_fsdp(model, dp_mesh, job_config)
elif parallel_dims.dp_replicate_enabled:
dp_mesh = world_mesh["dp_replicate"] if world_mesh.ndim > 1 else world_mesh
model = enable_ddp(model, dp_mesh, job_config)

return model
12 changes: 8 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def main(job_config: JobConfig):
world_size = int(os.environ["WORLD_SIZE"])
parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.training.pipeline_parallel_degree,
world_size=world_size,
Expand Down Expand Up @@ -303,10 +304,13 @@ def loss_fn(pred, labels):
optimizer.zero_grad()

# forward / backward
with loss_parallel_ctx():
pred = model(input_ids)
loss = loss_fn(pred, labels)
loss.backward()
with torch._dynamo.utils.maybe_enable_compiled_autograd(
job_config.training.compiled_autograd
):
with loss_parallel_ctx():
pred = model(input_ids)
loss = loss_fn(pred, labels)
loss.backward()

# clip gradients
torch.nn.utils.clip_grad_norm_(
Expand Down
40 changes: 40 additions & 0 deletions train_configs/llama_1b.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# TorchTrain Config.toml
[job]
dump_folder = "./outputs"
description = "LLaMA 1B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama2"
flavor = "1B"
norm_type = "fused_rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 1.5e-4

[training]
batch_size = 8
seq_len = 1024
warmup_steps = 200 # lr scheduler warm up
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4"

[activation_checkpoint]
mode = "none" # ['none', 'full', 'selective']

0 comments on commit aa08c80

Please sign in to comment.