Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove .nemo instead of renaming #9281

Merged
merged 10 commits into from
May 23, 2024
19 changes: 6 additions & 13 deletions nemo/utils/callbacks/nemo_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
logging.warning(f'always_save_nemo will slow down training for model_parallel > 1.')
# since we are creating tarfile artifacts we need to update .nemo path
self._backup_existing_nemo_ckpt(trainer)
self._maybe_remove_existing_nemo_ckpt(trainer)
app_state.model_restore_path = self._format_nemo_checkpoint_name()
if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
maybe_injected_best_model_path = inject_model_parallel_rank(self.best_model_path)
Expand Down Expand Up @@ -268,10 +268,10 @@ def on_train_end(self, trainer, pl_module):
trainer._checkpoint_connector.restore(self.best_model_path)

if self.save_nemo_on_train_end:
self._backup_existing_nemo_ckpt(trainer)
self._maybe_remove_existing_nemo_ckpt(trainer)
pl_module.save_to(save_path=self._format_nemo_checkpoint_name())

def _backup_existing_nemo_ckpt(self, trainer) -> str:
def _maybe_remove_existing_nemo_ckpt(self, trainer):
"""Search for an available name with version infix and rename existing checkpoint.

NOTE: this behavior is slightly different from regular checkpoints.
Expand All @@ -280,18 +280,11 @@ def _backup_existing_nemo_ckpt(self, trainer) -> str:
and create a backup under the first available name.
"""
base_path = self._format_nemo_checkpoint_name()
available_path = base_path
if self._enable_version_counter:
version_cnt = self.STARTING_VERSION
while self.file_exists(available_path, trainer, check_dist_ckpt=False):
available_path = self._format_nemo_checkpoint_name(version_cnt)
version_cnt += 1
if available_path != base_path:
if self.file_exists(base_path, trainer, check_dist_ckpt=False):
if trainer.is_global_zero:
logging.info(f'{base_path} already exists, moving existing checkpoint to {available_path}')
shutil.move(base_path, available_path)
logging.info(f'removing existing .nemo checkpoint {base_path}')
shutil.rmtree(base_path, ignore_errors=True)
trainer.strategy.barrier()
return available_path

def _format_nemo_checkpoint_name(self, ver: Optional[int] = None) -> str:
version_infix = '' if ver is None else f'{self.CHECKPOINT_JOIN_CHAR}v{ver}'
Expand Down
Loading