Skip to content

Commit

Permalink
Save ModelCheckpoint's last.ckpt as symlink if possible (#18748)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
  • Loading branch information
awaelchli and carmocca authored Oct 11, 2023
1 parent 7434c47 commit c5e3c45
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 30 deletions.
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `ModelCheckpoint` no longer deletes files under the save-top-k mechanism when resuming from a folder that is not the same as the current checkpoint folder ([#18750](https://github.com/Lightning-AI/lightning/pull/18750))
- The `ModelCheckpoint` no longer deletes the file that was passed to `Trainer.fit(ckpt_path=...)` ([#18750](https://github.com/Lightning-AI/lightning/pull/18750))
- Calling `trainer.fit()` twice now raises an error with strategies that spawn subprocesses through `multiprocessing` (ddp_spawn, xla) ([#18776](https://github.com/Lightning-AI/lightning/pull/18776))
- The `ModelCheckpoint` now saves a symbolic link if `save_last=True` and `save_top_k != 0` ([#18748](https://github.com/Lightning-AI/lightning/pull/18748))

### Deprecated

Expand Down
37 changes: 22 additions & 15 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ class ModelCheckpoint(Checkpoint):
the number of finished epoch and optimizer steps respectively.
monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch.
verbose: verbosity mode. Default: ``False``.
save_last: When ``True``, saves an exact copy of the checkpoint to a file `last.ckpt` whenever a checkpoint
file gets saved. This allows accessing the latest checkpoint in a deterministic manner. Default: ``None``.
save_last: When ``True``, saves a `last.ckpt` whenever a checkpoint file gets saved. On a local filesystem,
this will be a symbolic link, and otherwise a copy of the checkpoint file. This allows accessing the latest
checkpoint in a deterministic manner. Default: ``None``.
save_top_k: if ``save_top_k == k``,
the best k models according to the quantity monitored will be saved.
if ``save_top_k == 0``, no models are saved.
Expand Down Expand Up @@ -241,6 +242,7 @@ def __init__(
self.best_model_score: Optional[Tensor] = None
self.best_model_path = ""
self.last_model_path = ""
self._last_checkpoint_saved = ""

self.kth_value: Tensor
self.dirpath: Optional[_PATH]
Expand Down Expand Up @@ -371,12 +373,21 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
trainer.save_checkpoint(filepath, self.save_weights_only)

self._last_global_step_saved = trainer.global_step
self._last_checkpoint_saved = filepath

# notify loggers
if trainer.is_global_zero:
for logger in trainer.loggers:
logger.after_save_checkpoint(proxy(self))

@staticmethod
def _link_checkpoint(trainer: "pl.Trainer", filepath: str, linkpath: str) -> None:
if trainer.is_global_zero:
if os.path.lexists(linkpath):
os.remove(linkpath)
os.symlink(filepath, linkpath)
trainer.strategy.barrier()

def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
from lightning.pytorch.trainer.states import TrainerFn

Expand Down Expand Up @@ -427,19 +438,12 @@ def __validate_init_configuration(self) -> None:
"should be mutually exclusive."
)

if self.monitor is None:
if self.monitor is None and self.save_top_k not in (-1, 0, 1):
# -1: save all epochs, 0: nothing is saved, 1: save last epoch
if self.save_top_k not in (-1, 0, 1):
raise MisconfigurationException(
f"ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid"
" configuration. No quantity for top_k to track."
)

if self.save_top_k == -1 and self.save_last:
rank_zero_info(
"ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)"
" will duplicate the last checkpoint saved."
)
raise MisconfigurationException(
f"ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid"
" configuration. No quantity for top_k to track."
)

def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None:
self._fs = get_filesystem(dirpath if dirpath else "")
Expand Down Expand Up @@ -662,7 +666,10 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[

# set the last model path before saving because it will be part of the state.
previous, self.last_model_path = self.last_model_path, filepath
self._save_checkpoint(trainer, filepath)
if self._fs.protocol == "file" and self._last_checkpoint_saved and self.save_top_k != 0:
self._link_checkpoint(trainer, self._last_checkpoint_saved, filepath)
else:
self._save_checkpoint(trainer, filepath)
if previous and self._should_remove_checkpoint(trainer, previous, filepath):
self._remove_checkpoint(trainer, previous)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self):
self.last_coeff = 10.0

def training_step(self, batch, batch_idx):
loss = self.step(torch.ones(32))
loss = self.step(torch.ones(32, device=self.device))
loss = loss / (loss + 0.0000001)
loss += self.last_coeff
self.log("my_loss", loss)
Expand All @@ -80,8 +80,7 @@ def training_step(self, batch, batch_idx):
trainer.fit(model)

if save_last:
# last epochs are saved every step (so double the save calls)
expected = expected * 2
expected = expected
assert save_mock.call_count == expected


Expand Down
19 changes: 11 additions & 8 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import time
from argparse import Namespace
from datetime import timedelta
from logging import INFO
from pathlib import Path
from typing import Union
from unittest import mock
Expand Down Expand Up @@ -510,7 +509,8 @@ def test_model_checkpoint_save_last(tmpdir):
assert set(os.listdir(tmpdir)) == set(
[f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20, 30])] + [last_filename]
)

assert os.path.islink(tmpdir / last_filename)
assert os.path.realpath(tmpdir / last_filename) == model_checkpoint._last_checkpoint_saved
ModelCheckpoint.CHECKPOINT_NAME_LAST = "last"


Expand Down Expand Up @@ -589,10 +589,7 @@ def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
max_epochs=epochs,
logger=False,
)

with caplog.at_level(INFO):
trainer.fit(model)
assert "will duplicate the last checkpoint saved" in caplog.text
trainer.fit(model)

# these should not be set if monitor is None
assert checkpoint_callback.monitor is None
Expand All @@ -606,6 +603,7 @@ def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
expected = [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20])]
expected.append("last.ckpt")
assert set(os.listdir(tmpdir)) == set(expected)
assert os.path.islink(tmpdir / "last.ckpt")


@pytest.mark.parametrize("every_n_epochs", list(range(4)))
Expand Down Expand Up @@ -709,6 +707,8 @@ def test_model_checkpoint_topk_zero(tmpdir):
# check that only the last ckpt was created
assert os.listdir(tmpdir) == ["last.ckpt"]
assert checkpoint_callback.last_model_path == tmpdir / "last.ckpt"
# 'last.ckpt' is not a symlink because there are no top-k checkpoints to link
assert not os.path.islink(checkpoint_callback.last_model_path)


def test_model_checkpoint_topk_all(tmpdir):
Expand Down Expand Up @@ -814,6 +814,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
path_last = str(tmpdir / "last.ckpt")
assert path_last == model_checkpoint.last_model_path
assert os.path.isfile(path_last_epoch)
assert os.path.islink(path_last)

ckpt_last_epoch = torch.load(path_last_epoch)
ckpt_last = torch.load(path_last)
Expand Down Expand Up @@ -1343,7 +1344,7 @@ def test_save_last_saves_correct_last_model_path(tmpdir):
trainer = Trainer(callbacks=mc)
trainer.strategy.connect(BoringModel())

mc._save_last_checkpoint(trainer, {"foo": 1})
mc._save_last_checkpoint(trainer, {"foo": torch.tensor(1)})
expected = "foo=1-last.ckpt"
assert os.listdir(tmpdir) == [expected]
full_path = str(tmpdir / expected)
Expand All @@ -1366,6 +1367,8 @@ def test_save_last_versioning(tmpdir):
)
trainer.fit(model)
assert {"last.ckpt", "last-v1.ckpt"} == set(os.listdir(tmpdir))
# 'last.ckpt' is not a symlink since `save_top_k=0` didn't save any other checkpoints to link to
assert all(not os.path.islink(tmpdir / path) for path in set(os.listdir(tmpdir)))


def test_none_monitor_saves_correct_best_model_path(tmpdir):
Expand All @@ -1385,7 +1388,7 @@ def test_last_global_step_saved():
# this should not save anything
model_checkpoint = ModelCheckpoint(save_top_k=0, save_last=False, monitor="foo")
trainer = Mock()
monitor_candidates = {"foo": 123}
monitor_candidates = {"foo": torch.tensor(123)}
model_checkpoint._save_topk_checkpoint(trainer, monitor_candidates)
model_checkpoint._save_last_checkpoint(trainer, monitor_candidates)
assert model_checkpoint._last_global_step_saved == 0
Expand Down
6 changes: 4 additions & 2 deletions tests/tests_pytorch/models/test_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,11 @@ def get_trainer_args():
"best_k_models",
"kth_best_model_path",
"kth_value",
"last_model_path",
):
assert getattr(before, attribute) == getattr(after, attribute)
assert getattr(before, attribute) == getattr(after, attribute), f"{attribute}"
# `before.last_model_path` is a symlink pointing to a checkpoint saved before that symlink was created,
# hence reloading that checkpoint will restore `after.last_model_path = ""`
assert after.last_model_path == ""


@RunIf(sklearn=True)
Expand Down
6 changes: 4 additions & 2 deletions tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_checkpoint_plugin_called(tmpdir):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
accelerator="cpu",
strategy=SingleDeviceStrategy("cpu", checkpoint_io=checkpoint_plugin),
callbacks=ck,
max_epochs=2,
Expand All @@ -60,7 +61,7 @@ def test_checkpoint_plugin_called(tmpdir):
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt"}
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2.ckpt"
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last.ckpt"
assert checkpoint_plugin.save_checkpoint.call_count == 4
assert checkpoint_plugin.save_checkpoint.call_count == 2
assert checkpoint_plugin.remove_checkpoint.call_count == 1

trainer.test(model, ckpt_path=ck.last_model_path)
Expand All @@ -72,6 +73,7 @@ def test_checkpoint_plugin_called(tmpdir):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
accelerator="cpu",
strategy=SingleDeviceStrategy("cpu"),
plugins=[checkpoint_plugin],
callbacks=ck,
Expand All @@ -86,7 +88,7 @@ def test_checkpoint_plugin_called(tmpdir):
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt", "epoch=1-step=2-v1.ckpt", "last-v1.ckpt"}
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2-v1.ckpt"
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last-v1.ckpt"
assert checkpoint_plugin.save_checkpoint.call_count == 4
assert checkpoint_plugin.save_checkpoint.call_count == 2
assert checkpoint_plugin.remove_checkpoint.call_count == 1

trainer.test(model, ckpt_path=ck.last_model_path)
Expand Down

0 comments on commit c5e3c45

Please sign in to comment.