diff --git a/examples/llm/megatron_gpt_pretraining.py b/examples/llm/megatron_gpt_pretraining.py index 180561d03bac..d3d049e4296e 100644 --- a/examples/llm/megatron_gpt_pretraining.py +++ b/examples/llm/megatron_gpt_pretraining.py @@ -65,7 +65,6 @@ def get_args(): checkpoint_callback = ModelCheckpoint( every_n_train_steps=5000, enable_nemo_ckpt_io=False, - async_save=False, ) callbacks = [checkpoint_callback] diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index ed8ac25185f3..eee3850dfb37 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -51,14 +51,12 @@ def __init__( save_best_model: bool = False, save_on_train_epoch_end: Optional[bool] = False, # Save after training, not after validation enable_nemo_ckpt_io: bool = True, - async_save: bool = False, try_restore_best_ckpt: bool = True, **kwargs, ): self.save_best_model = save_best_model self.previous_best_path = "" self.enable_nemo_ckpt_io = enable_nemo_ckpt_io - self.async_save = async_save # Checkpoints which removal is deferred until async save is done. # Each element of `deferred_ckpts_to_remove` is a growing list # that `self._remove_checkpoint` adds to. Once `self._save_checkpoint` @@ -221,7 +219,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: super().load_state_dict(state_dict) self._remove_invalid_entries_from_topk() - def setup(self, *args, **kwargs) -> None: + def setup(self, trainer, *args, **kwargs) -> None: from nemo.utils.get_rank import is_global_rank_zero if is_global_rank_zero(): @@ -230,7 +228,9 @@ def setup(self, *args, **kwargs) -> None: # Ensure that all ranks continue with unfinished checkpoints removed if torch.distributed.is_initialized(): torch.distributed.barrier() - super().setup(*args, **kwargs) + + self.async_save = getattr(trainer.strategy, "async_save", False) + super().setup(trainer, *args, **kwargs) def on_save_checkpoint(self, trainer, pl_module, checkpoint): output = super().on_save_checkpoint(trainer, pl_module, checkpoint) diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index 2219324f6b67..9adfb7801f2f 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -105,6 +105,7 @@ def __init__( lazy_init: bool = False, pipeline_dtype: Optional[torch.dtype] = None, save_ckpt_format='torch_dist', + ckpt_async_save=False, ckpt_torch_dist_multiproc=None, ## TODO(ashors): put elsewhere? ckpt_assume_constant_structure=False, ckpt_parallel_save=True, @@ -142,6 +143,7 @@ def __init__( self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0))) self.save_ckpt_format = save_ckpt_format + self.async_save = ckpt_async_save self.torch_dist_multiproc = ckpt_torch_dist_multiproc self.assume_constant_structure = ckpt_assume_constant_structure self.parallel_save = ckpt_parallel_save @@ -253,6 +255,16 @@ def setup(self, trainer: pl.Trainer) -> None: assert self.model is not None _sync_module_states(self.model) + ## add AsyncFinalizerCallback if using async + if self.async_save: + have_async_callback = False + for callback in self.trainer.callbacks: + if isinstance(callback, AsyncFinalizerCallback): + have_async_callback = True + break + if not have_async_callback: + self.trainer.callbacks.append(AsyncFinalizerCallback()) + @override def setup_distributed(self) -> None: self._setup_parallel_ranks() @@ -577,11 +589,9 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr @override def checkpoint_io(self) -> CheckpointIO: if self._checkpoint_io is None: - checkpoint_callback = self.trainer.checkpoint_callback - async_save = getattr(checkpoint_callback, "async_save", False) self._checkpoint_io = MegatronCheckpointIO( save_ckpt_format=self.save_ckpt_format, - async_save=async_save, + async_save=self.async_save, torch_dist_multiproc=self.torch_dist_multiproc, assume_constant_structure=self.assume_constant_structure, parallel_save=self.parallel_save, @@ -589,15 +599,8 @@ def checkpoint_io(self) -> CheckpointIO: parallel_load=self.parallel_load, load_directly_on_device=self.load_directly_on_device, ) - if async_save: + if self.async_save: self._checkpoint_io = AsyncFinalizableCheckpointIO(self._checkpoint_io) - have_async_callback = False - for callback in self.trainer.callbacks: - if isinstance(callback, AsyncFinalizerCallback): - have_async_callback = True - break - if not have_async_callback: - self.trainer.callbacks.append(AsyncFinalizerCallback()) elif isinstance(self._checkpoint_io, _WrappingCheckpointIO): self._checkpoint_io.checkpoint_io = MegatronCheckpointIO()