Update on "some compile-related improvements"
[ghstack-poisoned]
tianyu-l committed Jul 10, 2024
2 parents dc41936 + 7838b6f commit 96d467f
Showing 11 changed files with 85 additions and 84 deletions.
6 changes: 5 additions & 1 deletion estimation.py
@@ -57,8 +57,12 @@ def estimate_memory(job_config: JobConfig):
)
job_config.model.norm_type = "rmsnorm"

if job_config.model.norm_type == "compiled_rmsnorm":
logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
job_config.model.norm_type = "rmsnorm"

if job_config.training.compile:
logger.info("Compile mode is not supported yet. " "Switching to Eager mode.")
logger.info("Compile mode is not supported yet. Switching to eager mode.")
job_config.training.compile = False

parallel_dims = ParallelDims(
6 changes: 4 additions & 2 deletions test_runner.py
@@ -156,7 +156,9 @@ def build_test_list():
OverrideDefinitions(
[
[
"--training.compile --model.norm_type=rmsnorm --activation_checkpoint.selective_ac_option=op",
"--training.compile",
"--activation_checkpoint.mode selective",
"--activation_checkpoint.selective_ac_option op",
],
],
"1D compile with selective op AC",
Expand Down Expand Up @@ -275,7 +277,7 @@ def build_test_list():
OverrideDefinitions(
[
[
"--memory_estimation.enabled",
"--memory_estimation.enabled --model.norm_type rmsnorm",
]
],
"FSDP2 Memory Tracking and Estimation",
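For reference, each override list in build_test_list() is joined onto a single launch command by the test runner. As a rough sketch only (the base command below is an assumption based on the usual torchtitan launch pattern, not something shown in this diff), the updated "1D compile with selective op AC" entry expands to roughly:

base = 'CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh'
overrides = [
    "--training.compile",
    "--activation_checkpoint.mode selective",
    "--activation_checkpoint.selective_ac_option op",
]
print(" ".join([base] + overrides))
# CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.compile ...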
9 changes: 5 additions & 4 deletions torchtitan/config_manager.py
@@ -165,7 +165,7 @@ def __init__(self):
"--model.norm_type",
type=str,
default="rmsnorm",
help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]",
help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, compiled_rmsnorm, fused_rmsnorm]",
)
self.parser.add_argument(
"--model.tokenizer_path",
Expand Down Expand Up @@ -532,10 +532,11 @@ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
args_dict[first_level_key][second_level_key] = v
return args_dict

def _validate_config(self) -> bool:
def _validate_config(self) -> None:
# TODO: Add more mandatory validations
assert self.model.name and self.model.flavor and self.model.tokenizer_path
return True
assert self.model.name
assert self.model.flavor
assert self.model.tokenizer_path

def parse_args_from_command_line(
self, args_list
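To make the config plumbing concrete, here is a small standalone illustration of how a dotted CLI flag such as --model.norm_type ends up as job_config.model.norm_type, mirroring the _args_to_two_level_dict pattern visible in the hunk above. The dictionary of parsed values and the asserts are illustrative assumptions, not code from the repository.

from collections import defaultdict

# Parsed argparse values keyed by their dotted flag names (illustrative values).
cli_values = {"model.norm_type": "compiled_rmsnorm", "training.compile": True}

# Split each dotted key into a two-level dict, as _args_to_two_level_dict does.
two_level = defaultdict(dict)
for key, value in cli_values.items():
    section, option = key.split(".", maxsplit=1)
    two_level[section][option] = value

assert two_level["model"]["norm_type"] == "compiled_rmsnorm"
assert two_level["training"]["compile"] is True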
21 changes: 16 additions & 5 deletions torchtitan/models/norms.py
@@ -42,6 +42,8 @@ def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
elif norm_type == "rmsnorm":
return RMSNorm(dim, eps=eps)
elif norm_type == "compiled_rmsnorm":
return RMSNorm(dim, eps=eps, compile=True)
elif norm_type == "fused_rmsnorm":
return FusedRMSNorm(dim, eps=eps)
else:
@@ -87,17 +89,26 @@ class RMSNorm(nn.Module):
"""

def __init__(self, dim: int, eps: float = 1e-6):
def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.rmsnorm_fn = (
torch.compile(self.compute_rmsnorm, fullgraph=True)
if compile
else self.compute_rmsnorm
)

@staticmethod
def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float):
def _norm(x, eps):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

def _norm(self, x: torch.Tensor):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
output = _norm(x.float(), eps).type_as(x)
return output * weight

def forward(self, x: torch.Tensor):
output = self._norm(x.float()).type_as(x)
return output * self.weight
return self.rmsnorm_fn(x, self.weight, self.eps)

def reset_parameters(self):
torch.nn.init.ones_(self.weight) # type: ignore
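Pieced together from the hunks above, the RMSNorm after this commit looks roughly like the following self-contained sketch; the construction and usage lines at the end are illustrative and not part of the diff.

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        # compile=True routes the math through torch.compile; otherwise it stays eager.
        self.rmsnorm_fn = (
            torch.compile(self.compute_rmsnorm, fullgraph=True)
            if compile
            else self.compute_rmsnorm
        )

    @staticmethod
    def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float):
        def _norm(x, eps):
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

        return _norm(x.float(), eps).type_as(x) * weight

    def forward(self, x: torch.Tensor):
        return self.rmsnorm_fn(x, self.weight, self.eps)


# Illustrative usage (requires a PyTorch build with torch.compile support).
norm = RMSNorm(dim=256, compile=True)
out = norm(torch.randn(2, 8, 256))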
115 changes: 49 additions & 66 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -4,8 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# this file applies the PTD parallelisms and various training techniques to the
# llama model, i.e. activation checkpointing, etc.
# This file applies the PT-D parallelisms and various training techniques (e.g.
# activation checkpointing and compile) to the Llama model.

import copy
from collections import defaultdict
@@ -17,7 +17,6 @@
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint
from torch.distributed.tensor.parallel import (
@@ -28,8 +27,6 @@
SequenceParallel,
)

from torch.utils.checkpoint import checkpoint

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import logger
from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank
@@ -43,10 +40,25 @@
}


# Uses PTD FSDP AC wrapper
# currently selective per op and per layer checkpointing are supported
def checkpoint_wrapper(module, config):
if config.mode == "selective" and config.selective_ac_option == "op":
def checkpoint_wrapper(module: torch.nn.Module, ac_config):
valid_ac_modes = ("full", "selective")
if ac_config.mode not in valid_ac_modes:
raise ValueError(
f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
)

if ac_config.mode == "full":
return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

assert ac_config.mode == "selective", f"{ac_config.mode}"
use_op_sac = ac_config.selective_ac_option == "op"
use_layer_sac = ac_config.selective_ac_option.isdigit()
if not use_op_sac and not use_layer_sac:
raise ValueError(
f"Invalid selective AC option: {ac_config.selective_ac_option}. "
f"Valid options: 'op' or a positive int representing layer frequency"
)
if use_op_sac:
from torch.utils.checkpoint import (
CheckpointPolicy,
create_selective_checkpoint_contexts,
@@ -76,53 +88,23 @@ def selective_checkpointing_context_fn():

return ptd_checkpoint_wrapper(
module,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
checkpoint_fn=checkpoint,
context_fn=selective_checkpointing_context_fn,
use_reentrant=False,
preserve_rng_state=False,
)
elif config.mode == "full":
return ptd_checkpoint_wrapper(
module,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
checkpoint_fn=checkpoint,
use_reentrant=False,
preserve_rng_state=False,
)

elif config.mode == "selective" and config.selective_ac_option.isdigit():
"""enables selective checkpointing of candidate layers.
Usage:
'selective_ac_option' with a positive 'int' value in config controls which layers to checkpoint.
1 == checkpointing every one (all).
2 == checkpoint every 2nd one
"""
ac_freq = int(config.selective_ac_option)
assert (
ac_freq >= 0
), f"selective layer AC policy (ac_freq) expects a positive integer, received {ac_freq}"

checkpoint_wrapper.__dict__.setdefault("_count", 0)

checkpoint_wrapper._count += 1
if not ac_freq or checkpoint_wrapper._count % ac_freq == 0:
return ptd_checkpoint_wrapper(
module,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
checkpoint_fn=checkpoint,
use_reentrant=False,
preserve_rng_state=False,
elif use_layer_sac:
# Checkpoint every `ac_freq` of the modules passed to this function
ac_freq = int(ac_config.selective_ac_option)
if ac_freq <= 0:
raise ValueError(
f"Selective layer AC expects a positive int as selective_ac_option but got {ac_freq}"
)
# skip activation checkpointing and store activations for this layer
ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
ptd_checkpoint_wrapper._count += 1
if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
else:
return module

else:
raise NotImplementedError(
"Unknown AC type or AC config. Only selective op and selective layer ac implemented currently."
)
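To make the per-layer branch concrete: the counter is stashed on ptd_checkpoint_wrapper and incremented on every call, so when checkpoint_wrapper is applied once per TransformerBlock (as the comment above implies), selective_ac_option = "2" checkpoints every second block and stores activations for the rest. A minimal standalone illustration of that counting rule (toy code, not from the repository):

def should_checkpoint(call_index: int, ac_freq: int) -> bool:
    # Mirrors `_count % ac_freq == 0` above; calls are counted starting at 1.
    return call_index % ac_freq == 0

ac_freq = 2  # i.e. selective_ac_option = "2"
for i in range(1, 5):
    action = "checkpoint" if should_checkpoint(i, ac_freq) else "store activations"
    print(f"layers.{i - 1}: {action}")
# layers.0: store activations
# layers.1: checkpoint
# layers.2: store activations
# layers.3: checkpoint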


def get_tp_parallel_strategy(
job_config: JobConfig,
@@ -341,9 +323,10 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
) = get_tp_parallel_strategy(job_config)
loss_parallel = parallel_dims.loss_parallel_enabled

# 1. Parallelize the first embedding and the last linear proj layer
# 1. Parallelize the embedding and shard its outputs (which are the first
# transformer block's inputs)
# 2. Parallelize the root norm layer over the sequence dim
# 3. Shard the first transformer block's inputs
# 3. Parallelize the final linear output layer
model = parallelize_module(
model,
tp_mesh,
@@ -352,12 +335,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": col_parallel_strategy(
input_layouts=Shard(1),
output_layouts=Shard(-1) if loss_parallel else Replicate(),
use_local_output=not loss_parallel,
),
"norm": SequenceParallel(),
},
)

@@ -367,6 +350,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
# Examples can be found at https://github.com/pytorch/torchtitan/pull/437
for layer_id, transformer_block in model.layers.items():
layer_plan = {
"attention_norm": SequenceParallel(),
"attention": prepare_module_input(
input_layouts=(Shard(1), None),
desired_input_layouts=(Replicate(), None),
@@ -375,15 +359,14 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
"attention.wk": col_parallel_strategy(),
"attention.wv": col_parallel_strategy(),
"attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
"attention_norm": SequenceParallel(),
"ffn_norm": SequenceParallel(),
"feed_forward": prepare_module_input(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": col_parallel_strategy(),
"feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
"feed_forward.w3": col_parallel_strategy(),
"ffn_norm": SequenceParallel(),
}

# Adjust attention module to use the local number of heads
@@ -432,7 +415,8 @@ def apply_compile(model, job_config: JobConfig):
"fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
)

# NOTE(anijain): enable the following flag to accelarate compilation
# TODO(anijain): the following flag is on to accelarate compilation
# remove it after it's enabled in pytorch by default
torch._dynamo.config.inline_inbuilt_nn_modules = True

for layer_id, transformer_block in model.layers.named_children():
@@ -446,7 +430,7 @@

def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply data parallelism to the model. FSDP2 is used here.
Apply data parallelism (FSDP2) to the model.
"""

dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
Expand All @@ -459,21 +443,20 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

for layer_id, transformer_block in model.layers.items():
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately.
# When using Pipeline Parallelism, generally zero-2 is best so as to avoid repeated reshardings
# per microbatch.
reshard_after_forward = (
int(layer_id) < len(model.layers) - 1 and not parallel_dims.pp_enabled
)
if parallel_dims.pp_enabled:
# For PP, do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
reshard_after_forward = False
else:
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = int(layer_id) < len(model.layers) - 1
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.layers[layer_id] = transformer_block

model = fully_shard(
fully_shard(
model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
)

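The body of the apply_compile loop is collapsed in the view above. As a hedged sketch only (the toy module and the replacement mechanism are assumptions, not the hidden repository code), compiling each TransformerBlock individually and swapping it back into the parent container generally looks like this:

import torch
import torch.nn as nn

# Toy stand-in for a model with a ModuleDict of transformer blocks.
model = nn.Module()
model.layers = nn.ModuleDict({str(i): nn.Linear(16, 16) for i in range(2)})

for layer_id, transformer_block in model.layers.named_children():
    compiled_block = torch.compile(transformer_block, fullgraph=True)
    # register_module replaces the child module under the same name.
    model.layers.register_module(layer_id, compiled_block)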
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -21,7 +21,7 @@ save_tb_folder = "tb"
[model]
name = "llama3"
flavor = "debugmodel"
norm_type = "fused_rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
norm_type = "compiled_rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
# test tokenizer.model, for debug purpose only
tokenizer_path = "./test/assets/test_tiktoken.model"

2 changes: 1 addition & 1 deletion train_configs/llama2_13b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
[model]
name = "llama2"
flavor = "13B"
norm_type = "fused_rmsnorm" # [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"

[optimizer]
2 changes: 1 addition & 1 deletion train_configs/llama2_70b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
[model]
name = "llama2"
flavor = "70B"
norm_type = "rmsnorm" # [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"

[optimizer]
2 changes: 1 addition & 1 deletion train_configs/llama2_7b.toml
@@ -17,7 +17,7 @@ save_tb_folder = "tb"
[model]
name = "llama2"
flavor = "7B"
norm_type = "fused_rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"

[optimizer]
2 changes: 1 addition & 1 deletion train_configs/llama3_70b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
[model]
name = "llama3"
flavor = "70B"
norm_type = "rmsnorm" # [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
2 changes: 1 addition & 1 deletion train_configs/llama3_8b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
