[BE] Combine OptimizerWrapper and OptimizerContainer (#738)
Combine `state_dict` and `load_state_dict` from `OptimizerWrapper` into
`OptimizersContainer` so that we only have one optimizer-related class.
Also, add `get_lr_scheduler_state` to `SchedulersContainer` so that the
`lr_scheduler` entries are populated in `self.states`.
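
A minimal sketch of the resulting wiring (illustrative only: `model_parts`, `job_config`, and `dataloader` are assumed inputs, and passing `optimizers.optimizers` into `build_lr_schedulers` follows the existing call pattern rather than anything introduced by this commit):

```python
from torchtitan.checkpoint import ModelWrapper
from torchtitan.optimizer import build_lr_schedulers, build_optimizers

optimizers = build_optimizers(model_parts, job_config)
lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)

# OptimizersContainer now implements Stateful itself, so it is registered
# directly as checkpointable state; the old OptimizerWrapper indirection is gone.
states = {
    "model": ModelWrapper(model_parts),
    "optimizer": optimizers,
    "dataloader": dataloader,
}
states.update(lr_schedulers.get_lr_scheduler_state())
```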
mori360 authored Dec 20, 2024
1 parent d67f7f9 commit ba24697
Showing 2 changed files with 81 additions and 118 deletions.
80 changes: 4 additions & 76 deletions torchtitan/checkpoint.py
@@ -21,21 +21,14 @@
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import init_logger, logger
from torchtitan.optimizer import (
OptimizersContainer,
OptimizersInBackwardContainer,
SchedulersContainer,
SchedulersInBackwardContainer,
)
from torchtitan.optimizer import OptimizersContainer, SchedulersContainer


class IntervalType(enum.Enum):
@@ -104,43 +97,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
list(map(func, self.model))


class OptimizerWrapper(Stateful):
def __init__(
self,
model: Union[nn.Module, List[nn.Module]],
optim: OptimizersContainer,
) -> None:
self.model = [model] if isinstance(model, nn.Module) else model
if isinstance(optim, OptimizersInBackwardContainer):
self.optim = [
sub_optim
for optim_group in optim.optimizers
for sub_optim in optim_group
]
else:
optimizers = optim.optimizers
self.optim = (
[optimizers]
if isinstance(optimizers, torch.optim.Optimizer)
else optimizers
)

def state_dict(self) -> Dict[str, Any]:
func = functools.partial(
get_optimizer_state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
return {k: v for sd in map(func, self.model, self.optim) for k, v in sd.items()}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
func = functools.partial(
set_optimizer_state_dict,
optim_state_dict=state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
list(map(func, self.model, self.optim))


class Terminate:
pass

@@ -204,7 +160,7 @@ def __init__(
restore its optimizer states, others will error.
The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerWrapper.
by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer.
2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
requiring us to reason about multiple 'optim' objects locally.
@@ -220,44 +176,16 @@ def __init__(
TODO: This is currently unsolved and needs a fix.
"""
assert len(model_parts) == len(
optimizers.optimizers
), "Must pass one optimizer per model part"
assert len(model_parts) == len(
lr_schedulers.schedulers
), "Must pass one lr_scheduler per model part"

self.states = states

self.states.update(
{
"model": ModelWrapper(model_parts),
"optimizer": OptimizerWrapper(
model_parts,
optimizers,
),
"optimizer": optimizers,
"dataloader": dataloader,
}
)
# SchedulersInBackwardContainer has a different structure than SchedulersContainer, List[List[Scheduler]] rather
# than List[Scheduler], but the schedulers are the same within each inner list, so here just store the first one.
# TODO: Restructure SchedulersInBackwardContainer to be consistent with SchedulersContainer.
if isinstance(lr_schedulers, SchedulersInBackwardContainer):
if len(lr_schedulers.schedulers) == 1:
self.states["lr_scheduler"] = lr_schedulers.schedulers[0][0]
else:
# For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
# It should only support saving and loading a distributed checkpoint with the same number of pp ranks
for idx, lr_scheduler in enumerate(lr_schedulers.schedulers):
self.states[f"lr_scheduler_{idx}"] = lr_scheduler[0]
else:
if len(lr_schedulers.schedulers) == 1:
self.states["lr_scheduler"] = lr_schedulers.schedulers[0]
else:
# For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
# It should only support saving and loading a distributed checkpoint with the same number of pp ranks
for idx, lr_scheduler in enumerate(lr_schedulers.schedulers):
self.states[f"lr_scheduler_{idx}"] = lr_scheduler
self.states.update(lr_schedulers.get_lr_scheduler_state())

self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.interval_type = (
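
For context on why subclassing `Stateful` is sufficient: `torch.distributed.checkpoint` (DCP) calls `state_dict()` / `load_state_dict()` on any `Stateful` value it finds in the dictionary passed to it. A hedged sketch, with `states` being the dict assembled by `CheckpointManager` above and the checkpoint path a placeholder:

```python
import torch.distributed.checkpoint as dcp

# Save: DCP invokes .state_dict() on every Stateful value in `states`,
# including the OptimizersContainer registered under "optimizer".
dcp.save(states, checkpoint_id="outputs/checkpoint/step-100")

# Load: DCP invokes .load_state_dict() on the Stateful values in place,
# restoring optimizer and scheduler state into the live objects.
dcp.load(states, checkpoint_id="outputs/checkpoint/step-100")
```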
119 changes: 77 additions & 42 deletions torchtitan/optimizer.py
@@ -5,18 +5,31 @@
# LICENSE file in the root directory of this source tree.

import functools
from typing import Any, Dict, List

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
get_optimizer_state_dict,
set_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim.lr_scheduler import LambdaLR
from torchtitan.config_manager import JobConfig


class OptimizersContainer:
"""Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages"""
class OptimizersContainer(Stateful):
"""Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages
and saving/loading optimizer state_dict at checkpoint.
"""

def __init__(self, model_parts, optimizer_kwargs, name):
def __init__(
self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
) -> None:
self.optimizers = []
for model in model_parts:
self.model_parts = model_parts
for model in self.model_parts:
if name == "Adam":
# TODO: make the optimizer options configurable by toml/cmd args
optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
@@ -25,22 +38,50 @@ def __init__(self, model_parts, optimizer_kwargs, name):
else:
raise NotImplementedError(f"Optimizer {name} not added.")
self.optimizers.append(optimizer)
self._validate_length(len(self.model_parts))

def _validate_length(self, expected_length) -> None:
assert expected_length == len(
self.optimizers
), "Must pass one optimizer per model part or per param if using OptimizersInBackwardContainer"

def step(self):
def step(self) -> None:
for optimizer in self.optimizers:
optimizer.step()

def zero_grad(self):
def zero_grad(self) -> None:
for optimizer in self.optimizers:
optimizer.zero_grad()

def state_dict(self) -> Dict[str, Any]:
func = functools.partial(
get_optimizer_state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
return {
k: v
for sd in map(func, self.model_parts, self.optimizers)
for k, v in sd.items()
}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
func = functools.partial(
set_optimizer_state_dict,
optim_state_dict=state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
list(map(func, self.model_parts, self.optimizers))


class OptimizersInBackwardContainer(OptimizersContainer):
"""Optimiers in backward to skip .step() and .zero_grad()"""

def __init__(self, model_parts, optimizer_kwargs, name):
def __init__(
self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
) -> None:
self.optimizers = []
for model in model_parts:
self.model_parts = model_parts
for model in self.model_parts:
if name == "Adam":
# TODO: make the optimizer options configurable by toml/cmd args
optim_dict = {
@@ -63,17 +104,25 @@ def optim_hook(param) -> None:
if param.requires_grad:
param.register_post_accumulate_grad_hook(optim_hook)

self.optimizers.append([optim_dict[param] for param in model.parameters()])
self.optimizers.extend([optim_dict[param] for param in model.parameters()])
self._validate_length(
sum(
len([param for param in model.parameters()])
for model in self.model_parts
)
)

def step(self):
def step(self) -> None:
pass

def zero_grad(self):
def zero_grad(self) -> None:
pass
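
The in-backward container relies on PyTorch's per-parameter gradient hooks. A self-contained sketch of the pattern, with a toy `nn.Linear` standing in for a real model part (the hook-registration code above is partially elided in this diff):

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)

# One optimizer per parameter, as in the container above.
optim_dict = {
    param: torch.optim.Adam([param], lr=1e-3) for param in model.parameters()
}

def optim_hook(param) -> None:
    # Step and clear the gradient as soon as it has been accumulated,
    # so the training loop's .step()/.zero_grad() calls can be no-ops.
    optim_dict[param].step()
    optim_dict[param].zero_grad()

for param in model.parameters():
    if param.requires_grad:
        param.register_post_accumulate_grad_hook(optim_hook)

loss = model(torch.randn(2, 16)).sum()
loss.backward()  # parameters are updated during this backward pass
```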


# consider split between PP and non-PP
def build_optimizers(model_parts, job_config: JobConfig):
def build_optimizers(
model_parts: List[nn.Module], job_config: JobConfig
) -> OptimizersContainer:
"""Wrap one optimizer per model part in an OptimizersContainer which provides a single
step() and zero_grad() method for all the child optimizers.
"""
@@ -121,44 +170,30 @@ def linear_warmup_linear_decay(
class SchedulersContainer:
"""Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages"""

def __init__(self, optimizers, lr_lambda):
def __init__(self, optimizers, lr_lambda) -> None:
self.schedulers = []
for optimizer in optimizers:
self.schedulers.append(LambdaLR(optimizer, lr_lambda=lr_lambda))

def step(self):
for schedulers in self.schedulers:
schedulers.step()
def step(self) -> None:
for scheduler in self.schedulers:
scheduler.step()

def get_lr_scheduler_state(self) -> Dict[str, Any]:
state_dict = {}
if len(self.schedulers) == 1:
state_dict["lr_scheduler"] = self.schedulers[0]
else:
# For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
# It should only support saving and loading a distributed checkpoint with the same number of pp ranks
for idx, lr_scheduler in enumerate(self.schedulers):
state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
return state_dict

class SchedulersInBackwardContainer(SchedulersContainer):
"""Util for calling step on multiple learning rate schedulers when optimizers are in backward"""

def __init__(self, optimizers, lr_lambda):
# all the schedulers for each optimizer group are the same, here we only store the first one
# to self.schedulers follow the same structure as SchedulersContainer, but store all of them
# to self.all_schedulers for container.step() to call
self.schedulers = []
for optim_group in optimizers:
scheduler_group = []
for sub_optim in optim_group:
scheduler_group.append(LambdaLR(sub_optim, lr_lambda=lr_lambda))
self.schedulers.append(scheduler_group)

def step(self):
for scheduler_group in self.schedulers:
for scheduler in scheduler_group:
scheduler.step()


def build_lr_schedulers(optimizers, job_config: JobConfig):
optim_in_bwd = job_config.optimizer.early_step_in_backward
def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer:
warmup_steps = int(job_config.training.warmup_steps)
decay_steps = float(max(1, job_config.training.steps - warmup_steps))
lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps)

return (
SchedulersContainer(optimizers, lr_lambda)
if not optim_in_bwd
else SchedulersInBackwardContainer(optimizers, lr_lambda)
)
return SchedulersContainer(optimizers, lr_lambda)
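
Driving the containers from a training loop, sketched under the same assumptions as above (`model_parts`, `job_config`, `dataloader`, and `compute_loss` are placeholders, not torchtitan APIs):

```python
optimizers = build_optimizers(model_parts, job_config)
lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)

for batch in dataloader:
    optimizers.zero_grad()                   # no-op for the in-backward container
    loss = compute_loss(model_parts, batch)  # placeholder loss computation
    loss.backward()
    optimizers.step()                        # also a no-op when stepping in backward
    lr_schedulers.step()
```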
