Skip to content

Commit

Permalink
Update on "[PP] add flexible interleaved 1f1b schedule"
Browse files Browse the repository at this point in the history
fixes #483


[ghstack-poisoned]
  • Loading branch information
H-Huang committed Jul 29, 2024
2 parents aac2d27 + 3840add commit 3b2c865
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 51 deletions.
28 changes: 17 additions & 11 deletions estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed import destroy_process_group
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.distributed.tensor.parallel import loss_parallel
from torch.testing._internal.distributed.fake_pg import FakeStore

from torchtitan.config_manager import JobConfig
from torchtitan.datasets import create_tokenizer
from torchtitan.float8_linear import build_fp8_linear
from torchtitan.float8_linear import (
maybe_build_fp8_linear,
maybe_precompute_fp8_dynamic_scale_for_fsdp,
)
from torchtitan.logging_utils import init_logger, logger
from torchtitan.lr_scheduling import get_lr_schedulers
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
from train import build_optimizers
from train import build_optimizers, get_train_context


def estimate_memory(job_config: JobConfig):
Expand Down Expand Up @@ -61,9 +63,10 @@ def estimate_memory(job_config: JobConfig):
logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
job_config.model.norm_type = "rmsnorm"

if job_config.training.compile:
if job_config.training.compile or job_config.experimental.enable_compiled_autograd:
logger.info("Compile mode is not supported yet. Switching to eager mode.")
job_config.training.compile = False
job_config.experimental.enable_compiled_autograd = False

parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
Expand Down Expand Up @@ -96,9 +99,9 @@ def estimate_memory(job_config: JobConfig):
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

# loss_parallel enables dispatching to efficient loss operators
loss_parallel_ctx = (
loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
train_context = get_train_context(
parallel_dims.loss_parallel_enabled,
job_config.experimental.enable_compiled_autograd,
)

# loss fn can be shared by pipeline-parallel or non-pp execution
Expand All @@ -124,9 +127,8 @@ def loss_fn(pred, labels):
with torch.device("meta"):
whole_model = model_cls.from_model_args(model_config)

# apply fp8 linear module swap
if job_config.training.enable_fp8_linear:
build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
# swap to Float8Linear base on fp8 config
maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)

# apply PT-D DP/TP parallelisms and activation checkpointing
model_parts = [whole_model]
Expand Down Expand Up @@ -171,7 +173,7 @@ def loss_fn(pred, labels):
for iter_idx in range(2):
input_ids, labels = batch
# train step
with loss_parallel_ctx():
with train_context():
pred = whole_model(input_ids)
loss = loss_fn(pred, labels)
del pred
Expand All @@ -185,6 +187,10 @@ def loss_fn(pred, labels):
# optimizer step
optimizers.step()
lr_schedulers.step()
# when fp8 config is on,
# calculate float8 dynamic amax/scale for all-parameter for FSDP2
# it issues a single all-reduce for all parameters at once for better performance
maybe_precompute_fp8_dynamic_scale_for_fsdp(whole_model, job_config)
optimizers.zero_grad()
print(f"Peak Memory at iter: {iter_idx}")
fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
Expand Down
2 changes: 2 additions & 0 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,8 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):

for override_arg in test_flavor.override_args:
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
if test_name == "fsdp2_mem_tracker":
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_memory_estimation.sh"
cmd += " " + dump_folder_arg
cmd += " " + model_flavor_arg
if override_arg:
Expand Down
40 changes: 15 additions & 25 deletions torchtitan/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

# Note: Performance
# Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
import contextlib
import functools
from typing import Optional

Expand All @@ -24,20 +23,6 @@
from torchtitan.logging_utils import logger


@contextlib.contextmanager
def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool):
import float8_experimental.config as config

prev = config.enable_fsdp_fp8_all_gather
torch.distributed.barrier()
config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
try:
yield
finally:
torch.distributed.barrier()
config.enable_fsdp_fp8_all_gather = prev


@functools.lru_cache(None)
def is_sm90_or_later():
# Float8 is only supported on H100+ GPUs
Expand All @@ -63,21 +48,26 @@ def maybe_build_fp8_linear(
)
return
try:
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
from float8_experimental import (
CastConfig,
convert_to_float8_training,
Float8LinearConfig,
ScalingType,
)

# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
enable_fsdp_float8_all_gather = (
job_config.training.enable_fsdp_float8_all_gather and dp_enabled
)
with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
swap_linear_with_float8_linear(
model,
scaling_type_w=TensorScalingType.DYNAMIC,
skip_fqn_list=["output"],
)
float8_config = Float8LinearConfig(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
)
convert_to_float8_training(
model,
config=float8_config,
module_filter_fn=lambda mod, fqn: fqn != "output",
)
logger.info(
f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
)
Expand All @@ -102,6 +92,6 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
"Skipped precomputing fp8 scales because SM90 or later is not available",
)
return
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
from float8_experimental import precompute_float8_dynamic_scale_for_fsdp

precompute_float8_dynamic_scale_for_fsdp(model)
30 changes: 15 additions & 15 deletions torchtitan/lr_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,43 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools

from torch.optim.lr_scheduler import LambdaLR
from torchtitan.config_manager import JobConfig

# global states for scheduling
# these are needed as LambdaLR does not support argument passing
_warmup_steps = 200
_decay_steps = 0


def linear_warmup_linear_decay(current_step: int) -> float:
def linear_warmup_linear_decay(
warmup_steps: int, decay_steps: int, current_step: int
) -> float:
"""Computes linear warmup followed by linear decay.
Per LambdaLR requirement, this is accomplished by returning
a multiplicative factor to adjust the learning rate to
create the desired schedule.
"""
if current_step < _warmup_steps:
if current_step < warmup_steps:
# linear warmup
# 0-indexed step, hence + 1 adjustments
current_step += 1
curr_adjustment = float(current_step / (_warmup_steps + 1))
curr_adjustment = float(current_step / (warmup_steps + 1))

else:
# linear decay
normalized_step = _decay_steps - (current_step - _warmup_steps)
curr_adjustment = 1 - (_decay_steps - normalized_step) / _decay_steps
normalized_step = decay_steps - (current_step - warmup_steps)
curr_adjustment = 1 - (decay_steps - normalized_step) / decay_steps

return curr_adjustment


def get_lr_schedulers(optimizers, job_config: JobConfig):
def _get_lr_scheduler(optimizer):
"""Build a linear warmup and linear decay scheduler"""
global _warmup_steps, _decay_steps
_warmup_steps = int(job_config.training.warmup_steps)
_decay_steps = float(max(1, job_config.training.steps - _warmup_steps))

warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
warmup_steps = int(job_config.training.warmup_steps)
decay_steps = float(max(1, job_config.training.steps - warmup_steps))
lr_lambda = functools.partial(
linear_warmup_linear_decay, warmup_steps, decay_steps
)
warmup_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
return warmup_scheduler

class SchedulersContainer:
Expand Down
11 changes: 11 additions & 0 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,17 @@ def apply_fsdp(
model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
)

if parallel_dims.pp_enabled:
# TODO
# This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since
# without strided sharding, DCP can not safely support resharding for 2D/3D. However, for PP to work, even
# without resharding, we load a seed-checkpoint and need to disable the safety mechanism. This hack should be
# removed after strided sharding is landed in DCP.
for module in model.modules():
assert len(module._load_state_dict_pre_hooks) <= 1
module._load_state_dict_pre_hooks.clear()
assert len(module._state_dict_pre_hooks) <= 1
module._state_dict_pre_hooks.clear()
logger.info("Applied FSDP to the model")
return model

Expand Down

0 comments on commit 3b2c865

Please sign in to comment.