Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an option to overwrite the existing checkpoint file #17320

Merged
merged 14 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))


- Enabled optional file versioning of model checkpoints ([#17320](https://github.com/Lightning-AI/lightning/pull/17320))


- Added the process group timeout argument `FSDPStrategy(timeout=...)` for the FSDP strategy ([#17274](https://github.com/Lightning-AI/lightning/pull/17274))


Expand Down
25 changes: 16 additions & 9 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ class ModelCheckpoint(Checkpoint):
Please note that the monitors are checked every ``every_n_epochs`` epochs.
if ``save_top_k >= 2`` and the callback is called multiple
times inside an epoch, the name of the saved file will be
appended with a version count starting with ``v1``.
appended with a version count starting with ``v1``
unless ``enable_version_counter`` is set to False.
mode: one of {min, max}.
If ``save_top_k != 0``, the decision to overwrite the current save file is made
based on either the maximization or the minimization of the monitored quantity.
Expand Down Expand Up @@ -129,6 +130,8 @@ class ModelCheckpoint(Checkpoint):
where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch.
If this is ``False``, then the check runs at the end of the validation.
enable_version_counter: Whether to append a version to the existing file name.
If this is ``False``, then the checkpoint files will be overwritten.

Note:
For extra customization, ModelCheckpoint includes the following attributes:
Expand Down Expand Up @@ -217,6 +220,7 @@ def __init__(
train_time_interval: Optional[timedelta] = None,
every_n_epochs: Optional[int] = None,
save_on_train_epoch_end: Optional[bool] = None,
enable_version_counter: bool = True,
):
super().__init__()
self.monitor = monitor
Expand All @@ -226,6 +230,7 @@ def __init__(
self.save_weights_only = save_weights_only
self.auto_insert_metric_name = auto_insert_metric_name
self._save_on_train_epoch_end = save_on_train_epoch_end
self._enable_version_counter = enable_version_counter
self._last_global_step_saved = 0 # no need to save when no steps were taken
self._last_time_checked: Optional[float] = None
self.current_score: Optional[Tensor] = None
Expand Down Expand Up @@ -617,10 +622,11 @@ def _get_metric_interpolated_filepath_name(
) -> str:
filepath = self.format_checkpoint_name(monitor_candidates)

version_cnt = self.STARTING_VERSION
while self.file_exists(filepath, trainer) and filepath != del_filepath:
filepath = self.format_checkpoint_name(monitor_candidates, ver=version_cnt)
version_cnt += 1
if self._enable_version_counter:
version_cnt = self.STARTING_VERSION
while self.file_exists(filepath, trainer) and filepath != del_filepath:
filepath = self.format_checkpoint_name(monitor_candidates, ver=version_cnt)
version_cnt += 1

return filepath

Expand All @@ -640,10 +646,11 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[

filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)

version_cnt = self.STARTING_VERSION
while self.file_exists(filepath, trainer) and filepath != self.last_model_path:
filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST, ver=version_cnt)
version_cnt += 1
if self._enable_version_counter:
version_cnt = self.STARTING_VERSION
while self.file_exists(filepath, trainer) and filepath != self.last_model_path:
filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST, ver=version_cnt)
version_cnt += 1

# set the last model path before saving because it will be part of the state.
previous, self.last_model_path = self.last_model_path, filepath
Expand Down
34 changes: 34 additions & 0 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,40 @@ def test_ckpt_version_after_rerun_same_trainer(tmpdir):
assert set(os.listdir(tmpdir)) == expected


def test_ckpt_version_counter_disabled_after_rerun_new_trainer(tmpdir):
    """Check that previous checkpoints get overwritten and no suffixes are generated when new trainer instances are
    used."""
    num_epochs = 2
    # Run two independent fit loops over the same directory; with the version
    # counter disabled, the second run must overwrite the first run's files
    # instead of producing "-v1"-suffixed siblings.
    for _ in range(num_epochs):
        checkpoint_cb = ModelCheckpoint(
            dirpath=tmpdir,
            save_top_k=-1,
            save_last=True,
            monitor="epoch",
            filename="{epoch}",
            enable_version_counter=False,
        )
        trainer = Trainer(
            max_epochs=num_epochs,
            limit_train_batches=1,
            limit_val_batches=1,
            default_root_dir=tmpdir,
            callbacks=[checkpoint_cb],
            logger=False,
            enable_progress_bar=False,
            enable_model_summary=False,
        )
        trainer.fit(BoringModel())

    # check best_k_models and last state
    best_names = {Path(path).name for path in checkpoint_cb.best_k_models}
    assert best_names == {"epoch=0.ckpt", "epoch=1.ckpt"}
    assert Path(checkpoint_cb.last_model_path).name == "last.ckpt"

    # check created ckpts
    actual = {f.basename for f in tmpdir.listdir()}
    assert actual == {"epoch=0.ckpt", "epoch=1.ckpt", "last.ckpt"}


def test_model_checkpoint_mode_options():
    """An unrecognized ``mode`` value must be rejected at construction time."""
    bad_mode = "unknown_option"
    with pytest.raises(MisconfigurationException, match="`mode` can be .* but got unknown_option"):
        ModelCheckpoint(mode=bad_mode)
Expand Down