[NeMo-UX] Set async_save from strategy rather than ModelCheckpoint (#9800)

* set async_save from strategy to make checkpoint_io more robust

Signed-off-by: ashors1 <ashors@nvidia.com>

* fix 2.0 test

Signed-off-by: ashors1 <ashors@nvidia.com>

---------

Signed-off-by: ashors1 <ashors@nvidia.com>
ashors1 authored Jul 23, 2024
1 parent b7a494e commit b901138
Showing 3 changed files with 18 additions and 16 deletions.
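
For orientation, here is a minimal sketch of the configuration surface after this change: async saving is now requested on the strategy via ckpt_async_save, and ModelCheckpoint no longer takes async_save (it reads trainer.strategy.async_save during setup). The ckpt_async_save, enable_nemo_ckpt_io, and save_ckpt_format names come from the diffs below; the nl.MegatronStrategy / nl.Trainer entry points and import paths are assumptions and may differ in your NeMo version.

# Sketch only; entry-point names are assumed, adjust to your NeMo install.
from nemo import lightning as nl
from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

strategy = nl.MegatronStrategy(
    save_ckpt_format="torch_dist",
    ckpt_async_save=True,  # replaces the former ModelCheckpoint(async_save=True)
)

checkpoint_callback = ModelCheckpoint(
    every_n_train_steps=5000,
    enable_nemo_ckpt_io=False,
    # async_save is no longer a ModelCheckpoint argument; the callback picks it
    # up from trainer.strategy.async_save in setup().
)

trainer = nl.Trainer(
    devices=1,
    strategy=strategy,
    callbacks=[checkpoint_callback],
)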
1 change: 0 additions & 1 deletion examples/llm/megatron_gpt_pretraining.py
@@ -65,7 +65,6 @@ def get_args():
     checkpoint_callback = ModelCheckpoint(
         every_n_train_steps=5000,
         enable_nemo_ckpt_io=False,
-        async_save=False,
     )
     callbacks = [checkpoint_callback]

8 changes: 4 additions & 4 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -51,14 +51,12 @@ def __init__(
         save_best_model: bool = False,
         save_on_train_epoch_end: Optional[bool] = False,  # Save after training, not after validation
         enable_nemo_ckpt_io: bool = True,
-        async_save: bool = False,
         try_restore_best_ckpt: bool = True,
         **kwargs,
     ):
         self.save_best_model = save_best_model
         self.previous_best_path = ""
         self.enable_nemo_ckpt_io = enable_nemo_ckpt_io
-        self.async_save = async_save
         # Checkpoints which removal is deferred until async save is done.
         # Each element of `deferred_ckpts_to_remove` is a growing list
         # that `self._remove_checkpoint` adds to. Once `self._save_checkpoint`
@@ -221,7 +219,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         super().load_state_dict(state_dict)
         self._remove_invalid_entries_from_topk()

-    def setup(self, *args, **kwargs) -> None:
+    def setup(self, trainer, *args, **kwargs) -> None:
         from nemo.utils.get_rank import is_global_rank_zero

         if is_global_rank_zero():
@@ -230,7 +228,9 @@ def setup(self, *args, **kwargs) -> None:
         # Ensure that all ranks continue with unfinished checkpoints removed
         if torch.distributed.is_initialized():
             torch.distributed.barrier()
-        super().setup(*args, **kwargs)
+
+        self.async_save = getattr(trainer.strategy, "async_save", False)
+        super().setup(trainer, *args, **kwargs)

     def on_save_checkpoint(self, trainer, pl_module, checkpoint):
         output = super().on_save_checkpoint(trainer, pl_module, checkpoint)

25 changes: 14 additions & 11 deletions nemo/lightning/pytorch/strategies.py
@@ -105,6 +105,7 @@ def __init__(
         lazy_init: bool = False,
         pipeline_dtype: Optional[torch.dtype] = None,
         save_ckpt_format='torch_dist',
+        ckpt_async_save=False,
         ckpt_torch_dist_multiproc=None,  ## TODO(ashors): put elsewhere?
         ckpt_assume_constant_structure=False,
         ckpt_parallel_save=True,
@@ -142,6 +143,7 @@ def __init__(
         self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0)))

         self.save_ckpt_format = save_ckpt_format
+        self.async_save = ckpt_async_save
         self.torch_dist_multiproc = ckpt_torch_dist_multiproc
         self.assume_constant_structure = ckpt_assume_constant_structure
         self.parallel_save = ckpt_parallel_save
@@ -253,6 +255,16 @@ def setup(self, trainer: pl.Trainer) -> None:
             assert self.model is not None
             _sync_module_states(self.model)

+        ## add AsyncFinalizerCallback if using async
+        if self.async_save:
+            have_async_callback = False
+            for callback in self.trainer.callbacks:
+                if isinstance(callback, AsyncFinalizerCallback):
+                    have_async_callback = True
+                    break
+            if not have_async_callback:
+                self.trainer.callbacks.append(AsyncFinalizerCallback())
+
     @override
     def setup_distributed(self) -> None:
         self._setup_parallel_ranks()
@@ -577,27 +589,18 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr
     @override
     def checkpoint_io(self) -> CheckpointIO:
         if self._checkpoint_io is None:
-            checkpoint_callback = self.trainer.checkpoint_callback
-            async_save = getattr(checkpoint_callback, "async_save", False)
             self._checkpoint_io = MegatronCheckpointIO(
                 save_ckpt_format=self.save_ckpt_format,
-                async_save=async_save,
+                async_save=self.async_save,
                 torch_dist_multiproc=self.torch_dist_multiproc,
                 assume_constant_structure=self.assume_constant_structure,
                 parallel_save=self.parallel_save,
                 parallel_save_within_dp=self.parallel_save_within_dp,
                 parallel_load=self.parallel_load,
                 load_directly_on_device=self.load_directly_on_device,
             )
-            if async_save:
+            if self.async_save:
                 self._checkpoint_io = AsyncFinalizableCheckpointIO(self._checkpoint_io)
-                have_async_callback = False
-                for callback in self.trainer.callbacks:
-                    if isinstance(callback, AsyncFinalizerCallback):
-                        have_async_callback = True
-                        break
-                if not have_async_callback:
-                    self.trainer.callbacks.append(AsyncFinalizerCallback())
         elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
             self._checkpoint_io.checkpoint_io = MegatronCheckpointIO()
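
A small follow-up sketch of the wiring this establishes once strategy.setup() has run with async saving on: strategy.async_save drives both the AsyncFinalizableCheckpointIO wrapper around MegatronCheckpointIO and the one-time registration of AsyncFinalizerCallback on the trainer. The attribute and class names come from the diff above; the helper itself is hypothetical and purely illustrative.

# Hypothetical inspection helper; class names are checked by string to avoid
# assuming any particular import path.
def describe_async_wiring(trainer) -> None:
    strategy = trainer.strategy
    # Set in the strategy constructor from ckpt_async_save.
    print("strategy.async_save:", getattr(strategy, "async_save", False))
    # Wrapped in AsyncFinalizableCheckpointIO when async_save is True.
    print("checkpoint_io type:", type(strategy.checkpoint_io).__name__)
    # Appended at most once by strategy.setup() when async_save is True.
    print("AsyncFinalizerCallback count:",
          sum(type(cb).__name__ == "AsyncFinalizerCallback" for cb in trainer.callbacks))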
