Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable optional checkpoint at loading #819

Merged
merged 12 commits into from
Feb 7, 2025
9 changes: 9 additions & 0 deletions docs/checkpoint.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ Finally, once you have obtained the last checkpoint, you can use the following c
python -m torch.distributed.checkpoint.format_utils dcp_to_torch torchtitan/outputs/checkpoint/step-1000 checkpoint.pt
```

7. EXCLUDING SPECIFIC KEYS FROM CHECKPOINT LOADING
In some cases, you may want to partially load from a previous-trained checkpoint and modify certain settings, such as the number of GPUs or the current step. To achieve this, you can use the `exclude_from_loading` parameter to specify which keys should be excluded from loading.
This parameter takes a comma-separated list of keys that should be excluded from loading.
```
[checkpoint]
enable_checkpoint = true
exclude_from_loading = "data_loader,lr_scheduler"
```

That's it. You have now successfully converted a sharded torchtitan checkpoint for use in torchtune.


Expand Down
18 changes: 18 additions & 0 deletions tests/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,24 @@ def build_test_list():
"fsdp_reshard_always",
ngpu=2,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--training.steps 10",
],
# Save at [dp:4] and load at [dp:2, tp:2]. Note that the dataloader should be
# excluded during loading to avoid errors caused by mismatched dp_degree.
[
"--checkpoint.enable_checkpoint",
"--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer",
"--training.tensor_parallel_degree 2",
"--training.steps 20",
],
],
"Optional checkpoint",
"optional_checkpoint",
),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add integration test here, especially for that optional checkpoint at dataloader could avoid dp_degree mismatch error before and after checkpoint

]
return integration_tests_flavors

Expand Down
71 changes: 71 additions & 0 deletions tests/unit_tests/test_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,77 @@ def test_parse_pp_split_points(self):
config.experimental.pipeline_parallel_split_points == cmdline_splits
), config.experimental.pipeline_parallel_split_points

def test_parse_exclude_from_loading(self):

toml_splits = ["optimizer", "dataloader"]
toml_split_str = ",".join(toml_splits)
cmdline_splits = ["optimizer", "lr_scheduler"]
cmdline_split_str = ",".join(cmdline_splits)
# no split points specified
config = JobConfig()
config.parse_args(
[
"--job.config_file",
"./train_configs/debug_model.toml",
]
)
assert config.checkpoint.exclude_from_loading == []

# toml has no split points, but cmdline splits are specified
config = JobConfig()
config.parse_args(
[
"--job.config_file",
"./train_configs/debug_model.toml",
"--checkpoint.exclude_from_loading",
f"{cmdline_split_str}",
]
)
assert (
config.checkpoint.exclude_from_loading == cmdline_splits
), config.checkpoint.exclude_from_loading

# toml has split points, cmdline does not
with tempfile.NamedTemporaryFile() as fp:
with open(fp.name, "wb") as f:
tomli_w.dump(
{
"checkpoint": {
"exclude_from_loading": toml_split_str,
}
},
f,
)
config = JobConfig()
config.parse_args(["--job.config_file", fp.name])
assert (
config.checkpoint.exclude_from_loading == toml_splits
), config.checkpoint.exclude_from_loading

# toml has split points, cmdline overrides them
with tempfile.NamedTemporaryFile() as fp:
with open(fp.name, "wb") as f:
tomli_w.dump(
{
"checkpoint": {
"exclude_from_loading": toml_split_str,
}
},
f,
)
config = JobConfig()
config.parse_args(
[
"--job.config_file",
fp.name,
"--checkpoint.exclude_from_loading",
f"{cmdline_split_str}",
]
)
assert (
config.checkpoint.exclude_from_loading == cmdline_splits
), config.checkpoint.exclude_from_loading

def test_print_help(self):
config = JobConfig()
parser = config.parser
Expand Down
17 changes: 11 additions & 6 deletions torchtitan/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,8 @@ def __init__(
which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening
support described in (1).

3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
Copy link
Contributor Author

@mori360 mori360 Feb 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lr_scheduler flatten at #794

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add a comment here to say the lr_scheduler resharding assumes that all lr_schedulers are the same.

resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
optimizers do, so it's hard to write a generic 'flattener' utility.

TODO: This is currently unsolved and needs a fix.
3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers with the assumption that
all lr_schedulers have the same state_dict.
"""
self.states = states

Expand Down Expand Up @@ -203,6 +200,7 @@ def __init__(

self.model_weights_only = ckpt_config.model_weights_only
self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
self.exclude_from_loading = ckpt_config.exclude_from_loading

self.mp = None
if async_mode == AsyncMode.DISABLED:
Expand Down Expand Up @@ -435,10 +433,17 @@ def load(self, step: int = -1) -> bool:
}
logger.info(f"Loading the checkpoint at step {step}.")
begin = time.monotonic()
states_to_load = {
k: v for k, v in states.items() if k not in self.exclude_from_loading
}
for exclude_key in self.exclude_from_loading:
if exclude_key not in states:
raise ValueError(f"{exclude_key} not found in state_dict.")
dcp.load(
states,
states_to_load,
checkpoint_id=self._create_checkpoint_id(step),
)
states.update(states_to_load)
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
)
Expand Down
23 changes: 22 additions & 1 deletion torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


def string_list(raw_arg):
return raw_arg.split(",")
return [s.strip() for s in raw_arg.split(",") if s.strip()]


class JobConfig:
Expand Down Expand Up @@ -529,6 +529,17 @@ def __init__(self):
default=-1,
help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
)
self.parser.add_argument(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently checkpoint.exclude only support excluding at loading, shall we use argument like exclude_from_loading?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, exclude_from_loading is more explicit.

"--checkpoint.exclude_from_loading",
type=string_list,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we still do .strip and empty check in string_list?

nargs="*",
default=[],
help="""
Exclude specific keys from being loaded from the checkpoint.
Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
This will load the model only, excluding the specified keys.
""",
)
# activation checkpointing configs
self.parser.add_argument(
"--activation_checkpoint.mode",
Expand Down Expand Up @@ -636,6 +647,13 @@ def parse_args(self, args_list: list = sys.argv[1:]):
exp["pipeline_parallel_split_points"] = string_list(
exp["pipeline_parallel_split_points"]
)
if (
"checkpoint" in args_dict
and "exclude_from_loading" in args_dict["checkpoint"]
and isinstance(args_dict["checkpoint"]["exclude_from_loading"], str)
):
ckpt = args_dict["checkpoint"]
ckpt["exclude_from_loading"] = string_list(ckpt["exclude_from_loading"])

# override args dict with cmd_args
cmd_args_dict = self._args_to_two_level_dict(cmd_args)
Expand Down Expand Up @@ -683,6 +701,9 @@ def parse_args_from_command_line(
# since the inferred type is just 'list' and it ends up flattening
# e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
aux_parser.add_argument("--" + arg, type=string_list)
elif arg == "checkpoint.exclude_from_loading":
# similar to the case above
aux_parser.add_argument("--" + arg, type=string_list)
else:
aux_parser.add_argument("--" + arg, type=type(val))

Expand Down