# Enable optional checkpoint at loading #819

This PR adds a `--checkpoint.exclude_from_loading` option so that selected keys (e.g. `optimizer`, `lr_scheduler`, `dataloader`) can be skipped when loading a checkpoint, loading the model while excluding the specified components.
Changes to the checkpoint manager:

```diff
@@ -170,11 +170,8 @@ def __init__(
         which is guaranteed for the model by correct pipeline splitting and for the optimizer by the flattening
         support described in (1).
 
-        3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
-        resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
-        optimizers do, so it's hard to write a generic 'flattener' utility.
-
-        TODO: This is currently unsolved and needs a fix.
+        3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers with the assumption that
+        all lr_schedulers have the same state_dict.
         """
         self.states = states
```

Review comment: lr_scheduler flatten at #794

Review comment: We should add a comment here to say the …
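For intuition, here is a minimal sketch of the "same state_dict" assumption the new docstring relies on (illustrative code, not the torchtitan implementation): if every scheduler holds identical state, one representative state_dict can stand in for the whole group and be broadcast back on load.

```python
import copy
import torch

# Two parameter groups, each with its own optimizer and LR scheduler.
params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(2)]
optimizers = [torch.optim.SGD([p], lr=0.1) for p in params]
schedulers = [
    torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda step: 0.9**step)
    for opt in optimizers
]

# "Flatten": save a single representative state_dict...
flat_state = schedulers[0].state_dict()

# ...and on load, apply it to every scheduler, relying on the assumption
# that all schedulers share the same state.
for sched in schedulers:
    sched.load_state_dict(copy.deepcopy(flat_state))
```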
```diff
@@ -203,6 +200,7 @@ def __init__(
 
         self.model_weights_only = ckpt_config.model_weights_only
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
+        self.exclude_from_loading = ckpt_config.exclude_from_loading
 
         self.mp = None
         if async_mode == AsyncMode.DISABLED:
```
```diff
@@ -435,10 +433,17 @@ def load(self, step: int = -1) -> bool:
         }
         logger.info(f"Loading the checkpoint at step {step}.")
         begin = time.monotonic()
+        states_to_load = {
+            k: v for k, v in states.items() if k not in self.exclude_from_loading
+        }
+        for exclude_key in self.exclude_from_loading:
+            if exclude_key not in states:
+                raise ValueError(f"{exclude_key} not found in state_dict.")
         dcp.load(
-            states,
+            states_to_load,
             checkpoint_id=self._create_checkpoint_id(step),
         )
+        states.update(states_to_load)
         logger.info(
             f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
         )
```
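For reference, a minimal standalone sketch of the filtering semantics introduced here (illustrative names, not the torchtitan implementation): every excluded key must exist in the state dict so typos fail loudly, and only the surviving entries are handed to `dcp.load`.

```python
# Sketch of the exclude-from-loading semantics (illustrative, not the
# torchtitan implementation).
def build_states_to_load(states: dict, exclude_from_loading: list[str]) -> dict:
    # Validate: every excluded key must exist, so a typo raises instead of
    # silently doing nothing.
    for exclude_key in exclude_from_loading:
        if exclude_key not in states:
            raise ValueError(f"{exclude_key} not found in state_dict.")
    # Only the remaining entries are passed to dcp.load; since dcp.load fills
    # them in place, merging them back with states.update restores the full dict.
    return {k: v for k, v in states.items() if k not in exclude_from_loading}

states = {"model": {}, "optimizer": {}, "dataloader": {}}
assert set(build_states_to_load(states, ["dataloader"])) == {"model", "optimizer"}
```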
Changes to the job configuration and argument parsing:

```diff
@@ -26,7 +26,7 @@
 
 
 def string_list(raw_arg):
-    return raw_arg.split(",")
+    return [s.strip() for s in raw_arg.split(",") if s.strip()]
 
 
 class JobConfig:
```
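The behavior change in `string_list` is easiest to see side by side: the new version trims whitespace around each item and drops empty segments, so spaces after commas and trailing commas no longer produce bogus keys.

```python
raw = "optimizer, lr_scheduler,"
print(raw.split(","))                                    # old: ['optimizer', ' lr_scheduler', '']
print([s.strip() for s in raw.split(",") if s.strip()])  # new: ['optimizer', 'lr_scheduler']
```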
```diff
@@ -529,6 +529,17 @@ def __init__(self):
             default=-1,
             help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
         )
+        self.parser.add_argument(
+            "--checkpoint.exclude_from_loading",
+            type=string_list,
+            nargs="*",
+            default=[],
+            help="""
+                Exclude specific keys from being loaded from the checkpoint.
+                Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
+                This will load the model only, excluding the specified keys.
+            """,
+        )
         # activation checkpointing configs
         self.parser.add_argument(
             "--activation_checkpoint.mode",
```

Review comment: Currently checkpoint.exclude only supports excluding at loading; shall we use an argument like …

Reply: Yes, exclude_from_loading is more explicit.

Review comment: Shall we still do …
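For usage, the help text above implies an invocation like the following on the command line, passing a comma-separated string that `string_list` splits into keys (the surrounding training entry point is outside this diff):

```
--checkpoint.exclude_from_loading optimizer,lr_scheduler,dataloader
```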
```diff
@@ -636,6 +647,13 @@ def parse_args(self, args_list: list = sys.argv[1:]):
                 exp["pipeline_parallel_split_points"] = string_list(
                     exp["pipeline_parallel_split_points"]
                 )
+            if (
+                "checkpoint" in args_dict
+                and "exclude_from_loading" in args_dict["checkpoint"]
+                and isinstance(args_dict["checkpoint"]["exclude_from_loading"], str)
+            ):
+                ckpt = args_dict["checkpoint"]
+                ckpt["exclude_from_loading"] = string_list(ckpt["exclude_from_loading"])
 
         # override args dict with cmd_args
         cmd_args_dict = self._args_to_two_level_dict(cmd_args)
```
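This branch handles the case where the value arrives from a TOML config file as a plain string rather than a list. Assuming the two-level TOML layout implied by `args_dict["checkpoint"]`, a config like the following would be normalized to `["optimizer", "lr_scheduler"]`:

```toml
[checkpoint]
exclude_from_loading = "optimizer,lr_scheduler"
```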
```diff
@@ -683,6 +701,9 @@ def parse_args_from_command_line(
                 # since the inferred type is just 'list' and it ends up flattening
                 # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
                 aux_parser.add_argument("--" + arg, type=string_list)
+            elif arg == "checkpoint.exclude_from_loading":
+                # similar to the case above
+                aux_parser.add_argument("--" + arg, type=string_list)
             else:
                 aux_parser.add_argument("--" + arg, type=type(val))
```
Review comment: Please add an integration test here, especially one showing that optionally excluding the dataloader state from checkpoint loading avoids the dp_degree mismatch error when the degree differs before and after checkpointing.
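A rough sketch of what such a test could exercise (hypothetical: every flag below except `--checkpoint.exclude_from_loading` is an assumption about the surrounding training CLI, not part of this diff): save with one data-parallel degree, then resume with a different one while excluding the dataloader state.

```python
# Hypothetical integration-test sketch; flag names other than
# --checkpoint.exclude_from_loading are assumptions, not part of this diff.

# Step 1: train a few steps and save a checkpoint at one dp degree.
save_run = [
    "--training.data_parallel_degree", "4",   # assumed flag
    "--checkpoint.enable_checkpoint",         # assumed flag
    "--training.steps", "10",                 # assumed flag
]

# Step 2: resume at a different dp degree, excluding the dataloader state so
# its dp_degree-indexed entries are rebuilt fresh instead of being loaded.
resume_run = [
    "--training.data_parallel_degree", "2",
    "--checkpoint.enable_checkpoint",
    "--checkpoint.exclude_from_loading", "dataloader",  # from this PR
    "--training.steps", "20",
]
# Expected: the resumed run starts cleanly with no dp_degree mismatch error,
# since the dataloader state is never loaded from the old checkpoint.
```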