[BE] LR scheduler flatten #794

Merged
merged 9 commits into from Jan 31, 2025
2 changes: 1 addition & 1 deletion torchtitan/checkpoint.py
@@ -183,9 +183,9 @@ def __init__(
"model": ModelWrapper(model_parts),
"optimizer": optimizers,
"dataloader": dataloader,
"lr_scheduler": lr_schedulers,
Contributor
I think it won't be this simple. Both OptimizersContainer and ModelWrapper define state_dict and load_state_dict to handle flattening and unflattening. Since we don't have things like get_model_state_dict and set_model_state_dict for lr scheduler in torch.distributed.checkpoint.state_dict, we likely will need to manually write something for the LambdaLR we are using. See #738 (comment)

Let's work with @fegin on this.
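
A rough illustration (not part of this PR) of what a LambdaLR exposes through state_dict(), using a toy optimizer and lambda; this is the state any manual handling would have to round-trip:

```python
# Minimal sketch: inspect the plain-Python state a LambdaLR carries.
# The lr_lambda callables themselves are not serialized; only the plain values
# below are, which is what a manual Stateful wrapper would have to carry.
import torch
from torch.optim.lr_scheduler import LambdaLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=3e-4)
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / 10))

scheduler.step()
# Keys include e.g. `last_epoch`, `base_lrs`, `_last_lr` -- all picklable values.
print(scheduler.state_dict())
```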

Contributor Author

Compared lr_schedulers before and after flattening, with and without checkpointing; the lr_scheduler values are consistent with the changes here.
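
A hedged sketch of that kind of consistency check (not the actual comparison used here), with a toy LambdaLR:

```python
# Sketch: round-trip a scheduler's state and compare the fields that matter
# for resuming, e.g. `last_epoch` and the last computed learning rates.
import torch
from torch.optim.lr_scheduler import LambdaLR

def make_scheduler():
    p = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.AdamW([p], lr=3e-4)
    return LambdaLR(opt, lr_lambda=lambda step: min(1.0, (step + 1) / 10))

src, dst = make_scheduler(), make_scheduler()
for _ in range(5):
    src.step()

dst.load_state_dict(src.state_dict().copy())
assert dst.last_epoch == src.last_epoch
assert dst.get_last_lr() == src.get_last_lr()
```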

Contributor

Does it support DCP resharding? E.g., changing the PP degree from 2 to 4 across two jobs.

Contributor

I think this PR doesn't address the resharding issue, hence the [BE] prefix. Supporting lr resharding deserves a separate PR.

}
)
self.states.update(lr_schedulers.get_lr_scheduler_state())

self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.interval_type = (
27 changes: 16 additions & 11 deletions torchtitan/optimizer.py
@@ -167,7 +167,7 @@ def linear_warmup_linear_decay(
return curr_adjustment


class SchedulersContainer:
class SchedulersContainer(Stateful):
"""Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages"""

def __init__(self, optimizers, lr_lambda) -> None:
@@ -179,16 +179,21 @@ def step(self) -> None:
for scheduler in self.schedulers:
scheduler.step()

def get_lr_scheduler_state(self) -> Dict[str, Any]:
state_dict = {}
if len(self.schedulers) == 1:
state_dict["lr_scheduler"] = self.schedulers[0]
else:
# For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
# It should only support saving and loading a distributed checkpoint with the same number of pp ranks
for idx, lr_scheduler in enumerate(self.schedulers):
state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
return state_dict
def state_dict(self) -> Dict[str, Any]:
# Currently, we have one scheduler per optimizer. However, when using MultiSchedule PP or optimizer-in-backward,
# there are multiple optimizers and schedulers, but the scheduler state_dict remains the same for all.
# Therefore, we only save the first one and later load it for all.
assert (
len(self.schedulers) > 0
), "Must have at least one scheduler to save state_dict"
return self.schedulers[0].state_dict()

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# Load the same state_dict for all schedulers. The key value we're concerned with in scheduler.state_dict() is `last_epoch`,
# which is an integer that will be automatically copied. As long as `training.steps` and `training.warmup_steps` remain
# unchanged when resuming from a checkpoint, this approach is safe. We call `.copy()` here to ensure extra safety.
for scheduler in self.schedulers:
scheduler.load_state_dict(state_dict.copy())


def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer:
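
For context, a sketch (not code from this repo) of how the now-Stateful container is consumed: torch.distributed.checkpoint calls state_dict()/load_state_dict() on Stateful values in the checkpoint state dict, so registering the container under a single "lr_scheduler" key, as checkpoint.py now does, is sufficient. All variable names below are placeholders.

```python
# Sketch only; assumes an initialized process group and the objects torchtitan
# already builds (placeholders here, mirroring the keys used in checkpoint.py).
import torch.distributed.checkpoint as dcp

states = {
    "model": model_wrapper,        # placeholder for ModelWrapper(model_parts)
    "optimizer": optimizers,       # placeholder for OptimizersContainer
    "lr_scheduler": lr_schedulers, # SchedulersContainer, now Stateful
}

# DCP invokes .state_dict() on Stateful values when saving ...
dcp.save(states, checkpoint_id="outputs/checkpoint/step-1000")
# ... and .load_state_dict() on them when loading.
dcp.load(states, checkpoint_id="outputs/checkpoint/step-1000")
```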