From 9358898c6e454cd712dd100f54415b81469d8cde Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 13 Nov 2024 11:51:40 +0100 Subject: [PATCH] Ensure restarting from checkpoints leads to consistent internal counters (#20379) * Fix checkpoint progress for fit loop and batch loop * Check loss parity * Rename test * Fix validation loop handling on restart * Fix loop reset test * Avoid skipping to val end if saved mid validation * Fix type checks in compare state dicts * Fix edge cases and start from last with and without val * Clean up * Formatting * Avoid running validation when restarting from last * Fix type annotations * Fix formatting * Ensure int max_batch * Fix condition on batches that stepped * Remove expected on_train_epoch_start when restarting mid epoch --- .../pytorch/loops/evaluation_loop.py | 38 +- src/lightning/pytorch/loops/fit_loop.py | 107 +++++- src/lightning/pytorch/loops/loop.py | 10 + src/lightning/pytorch/loops/progress.py | 22 ++ .../pytorch/loops/training_epoch_loop.py | 83 +++- tests/tests_pytorch/loops/test_loops.py | 354 +++++++++++++++++- tests/tests_pytorch/models/test_hooks.py | 2 - 7 files changed, 584 insertions(+), 32 deletions(-) diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 0ab3901cf072d..78573c45aab76 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -15,6 +15,7 @@ import shutil import sys from collections import ChainMap, OrderedDict, defaultdict +from dataclasses import dataclass from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union from lightning_utilities.core.apply_func import apply_to_collection @@ -45,6 +46,12 @@ from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature +@dataclass +class RestartStage: + NONE = "none" + RESTARTED_MID_EVALUATION = "restarted_mid_evaluation" + + class _EvaluationLoop(_Loop): """Top-level loop where validation/testing starts.""" @@ -73,6 +80,7 @@ def __init__( self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int) self._last_val_dl_reload_epoch = float("-inf") self._module_mode = _ModuleMode() + self._restart_stage = RestartStage.NONE @property def num_dataloaders(self) -> int: @@ -137,7 +145,7 @@ def run(self) -> List[_OUT_DICT]: # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support break finally: - self._restarting = False + self.on_iteration_done() self._store_dataloader_outputs() return self.on_run_end() @@ -197,6 +205,24 @@ def setup_data(self) -> None: # this depends on the data used, so reset it too self._seen_batches_per_dataloader = defaultdict(int) + @property + def restarted_mid_evaluation(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_MID_EVALUATION + + def update_restart_stage(self) -> None: + if ( + self.restarting + and self.batch_progress.total.started == self.batch_progress.total.ready + and self.batch_progress.total.processed == self.batch_progress.total.started - 1 + and self.batch_progress.total.completed == self.batch_progress.total.processed + ): + self._restart_stage = RestartStage.RESTARTED_MID_EVALUATION + else: + self._restart_stage = RestartStage.NONE + + def reset_restart_stage(self) -> None: + self._restart_stage = RestartStage.NONE + def reset(self) -> None: """Resets the internal state of the loop.""" trainer = self.trainer @@ -236,6 +262,16 @@ def reset(self) -> None: data_fetcher._stop_profiler = self._on_after_fetch 
self._data_fetcher = data_fetcher + def increment_progress_to_evaluation_end(self) -> None: + self.setup_data() + if self.skip: + return + self.reset() + max_batch = int(max(self.max_batches)) + if max_batch == -1: + return + self.batch_progress.increment_by(max_batch, True) + def on_run_start(self) -> None: """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start`` hooks.""" diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index eb30e32757c9a..e20088acd0af3 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union import torch @@ -45,6 +46,15 @@ log = logging.getLogger(__name__) +@dataclass +class RestartStage: + NONE = "none" + RESTARTED_ON_EPOCH_START = "restarted_on_epoch_start" + RESTARTED_MID_EPOCH = "restarted_mid_epoch" + RESTARTED_ON_EPOCH_END = "restarted_on_epoch_end" + RESUMED_ON_EPOCH_END = "resumed_on_epoch_end" + + class _FitLoop(_Loop): """This loop is the top-level loop where training starts. @@ -97,6 +107,7 @@ def __init__( self._combined_loader_states_to_load: List[Dict[str, Any]] = [] self._data_fetcher: Optional[_DataFetcher] = None self._last_train_dl_reload_epoch = float("-inf") + self._restart_stage = RestartStage.NONE @property def total_batch_idx(self) -> int: @@ -204,9 +215,10 @@ def run(self) -> None: self.on_advance_start() self.advance() self.on_advance_end() - self._restarting = False except StopIteration: break + finally: + self.on_iteration_done() self._restarting = False self.on_run_end() @@ -302,14 +314,92 @@ def setup_data(self) -> None: category=PossibleUserWarning, ) + @property + def restarted_on_epoch_start(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_START + + @property + def restarted_mid_epoch(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_MID_EPOCH + + @property + def restarted_on_epoch_end(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_END + + @property + def resumed_on_epoch_end(self) -> bool: + # This case happens when restarting from last without validation at + # the end of epoch. In this case self.restarting is False. 
+ return self._restart_stage == RestartStage.RESUMED_ON_EPOCH_END + + def update_restart_stage(self) -> None: + if ( + self.restarting + and self.epoch_progress.total.started == self.epoch_progress.total.ready - 1 + and self.epoch_progress.total.processed == self.epoch_progress.total.started + and self.epoch_progress.total.completed == self.epoch_progress.total.processed + ): + self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_START + elif ( + self.restarting + and self.epoch_progress.total.started == self.epoch_progress.total.ready + and self.epoch_progress.total.processed == self.epoch_progress.total.started - 1 + and self.epoch_progress.total.completed == self.epoch_progress.total.processed + ): + self._restart_stage = RestartStage.RESTARTED_MID_EPOCH + elif ( + self.restarting + and self.epoch_progress.total.started == self.epoch_progress.total.ready + and self.epoch_progress.total.processed == self.epoch_progress.total.started + and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1 + ): + self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_END + elif ( + self._loaded_from_state_dict + and self.epoch_progress.total.started == self.epoch_progress.total.ready + and self.epoch_progress.total.processed == self.epoch_progress.total.started + and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1 + ): + self._restart_stage = RestartStage.RESUMED_ON_EPOCH_END + else: + self._restart_stage = RestartStage.NONE + + self.epoch_loop.update_restart_stage() + + def reset_restart_stage(self) -> None: + self._restart_stage = RestartStage.NONE + def reset(self) -> None: """Resets the internal state of this loop.""" assert self.trainer.model is not None torch.set_grad_enabled(True) - if self.restarting: + self.update_restart_stage() + + if self.restarted_on_epoch_start: self.epoch_progress.reset_on_restart() + if self.resumed_on_epoch_end: + # when restarting from last without validation at end of epoch, + # self.restarting is False but it's still resuming + self.epoch_progress.increment_completed() + + if ( + self.epoch_loop.restarted_on_train_batch_end + and self.restarted_mid_epoch + and self.epoch_loop.batch_progress.is_last_batch + ): + self.epoch_progress.increment_processed() + self.epoch_progress.increment_completed() + + if ( + self.epoch_loop.restarted_on_train_batch_end + and self.epoch_loop.batch_progress.is_last_batch + and not self.restarted_mid_epoch + and not self.epoch_loop.val_loop.batch_progress.is_last_batch + ): + self.epoch_progress.increment_completed() + def on_run_start(self) -> None: """Calls the ``on_train_start`` hook.""" # update the current_epoch in-case of checkpoint reload @@ -340,12 +430,14 @@ def on_advance_start(self) -> None: for i, dl in enumerate(self._combined_loader.flattened): _set_sampler_epoch(dl, self.epoch_progress.current.processed) - self.epoch_progress.increment_ready() + if not self.restarted_mid_epoch and not self.restarted_on_epoch_end: + if not self.restarted_on_epoch_start: + self.epoch_progress.increment_ready() - call._call_callback_hooks(trainer, "on_train_epoch_start") - call._call_lightning_module_hook(trainer, "on_train_epoch_start") + call._call_callback_hooks(trainer, "on_train_epoch_start") + call._call_lightning_module_hook(trainer, "on_train_epoch_start") - self.epoch_progress.increment_started() + self.epoch_progress.increment_started() def advance(self) -> None: """Runs one whole epoch.""" @@ -379,8 +471,7 @@ def on_advance_end(self) -> None: 
trainer._logger_connector.on_epoch_end() - if self.epoch_loop._num_ready_batches_reached(): - # if we are restarting and the above condition holds, it's because we are reloading an epoch-end checkpoint. + if not self.restarting and self.epoch_loop._num_ready_batches_reached(): # since metric-based schedulers require access to metrics and those are not currently saved in the # checkpoint, the plateau schedulers shouldn't be updated self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting) diff --git a/src/lightning/pytorch/loops/loop.py b/src/lightning/pytorch/loops/loop.py index 56d520800c447..111377a222b3f 100644 --- a/src/lightning/pytorch/loops/loop.py +++ b/src/lightning/pytorch/loops/loop.py @@ -22,6 +22,7 @@ class _Loop: def __init__(self, trainer: "pl.Trainer") -> None: self._restarting = False + self._loaded_from_state_dict = False self.trainer = trainer @property @@ -37,6 +38,9 @@ def restarting(self, restarting: bool) -> None: if isinstance(loop, _Loop): loop.restarting = restarting + def reset_restart_stage(self) -> None: + pass + def on_save_checkpoint(self) -> Dict: """Called when saving a model checkpoint, use to persist loop state. @@ -82,6 +86,7 @@ def load_state_dict( if isinstance(v, _Loop): v.load_state_dict(state_dict.copy(), prefix + k + ".") self.restarting = True + self._loaded_from_state_dict = True def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None: for k, v in self.__dict__.items(): @@ -93,3 +98,8 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None: v.load_state_dict(state_dict[key]) if prefix + "state_dict" in state_dict: # compatibility with old checkpoints self.on_load_checkpoint(state_dict[prefix + "state_dict"]) + + def on_iteration_done(self) -> None: + self._restarting = False + self._loaded_from_state_dict = False + self.reset_restart_stage() diff --git a/src/lightning/pytorch/loops/progress.py b/src/lightning/pytorch/loops/progress.py index 3d34653122329..6880b24f70c65 100644 --- a/src/lightning/pytorch/loops/progress.py +++ b/src/lightning/pytorch/loops/progress.py @@ -68,6 +68,10 @@ def reset_on_restart(self) -> None: """ self.ready = self.completed + def increment_by(self, n: int) -> None: + self.ready += n + self.completed += n + @dataclass class _StartedTracker(_ReadyCompletedTracker): @@ -94,6 +98,11 @@ def reset_on_restart(self) -> None: super().reset_on_restart() self.started = self.completed + @override + def increment_by(self, n: int) -> None: + super().increment_by(n) + self.started += n + @dataclass class _ProcessedTracker(_StartedTracker): @@ -121,6 +130,11 @@ def reset_on_restart(self) -> None: super().reset_on_restart() self.processed = self.completed + @override + def increment_by(self, n: int) -> None: + super().increment_by(n) + self.processed += n + @dataclass class _Progress(_BaseProgress): @@ -175,6 +189,10 @@ def reset_on_run(self) -> None: def reset_on_restart(self) -> None: self.current.reset_on_restart() + def increment_by(self, n: int) -> None: + self.total.increment_by(n) + self.current.increment_by(n) + @override def load_state_dict(self, state_dict: dict) -> None: self.total.load_state_dict(state_dict["total"]) @@ -206,6 +224,10 @@ def reset_on_run(self) -> None: super().reset_on_run() self.is_last_batch = False + def increment_by(self, n: int, is_last_batch: bool = False) -> None: + super().increment_by(n) + self.is_last_batch = is_last_batch + @override def load_state_dict(self, state_dict: dict) -> None: super().load_state_dict(state_dict) diff 
--git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 9e36ee65176c8..1c749de3a1b6d 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -13,6 +13,7 @@ # limitations under the License. import math from collections import OrderedDict +from dataclasses import dataclass from typing import Any, Dict, Optional, Union from typing_extensions import override @@ -37,6 +38,13 @@ _BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] +@dataclass +class RestartStage: + NONE = "none" + RESTARTED_ON_TRAIN_BATCH_END = "restarted_on_train_batch_end" + RESTARTED_ON_LAST = "restarted_on_last" + + class _TrainingEpochLoop(loops._Loop): """Iterates over all batches in the dataloader (one epoch) that the user returns in their :meth:`~lightning.pytorch.core.LightningModule.train_dataloader` method. @@ -81,6 +89,8 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s self._results = _ResultCollection(training=True) self._warning_cache = WarningCache() self._batches_that_stepped: int = 0 + self._restart_stage = RestartStage.NONE + self._skip_next_val = False @property def total_batch_idx(self) -> int: @@ -139,13 +149,63 @@ def run(self, data_fetcher: _DataFetcher) -> None: try: self.advance(data_fetcher) self.on_advance_end(data_fetcher) - self._restarting = False except StopIteration: break - self._restarting = False + finally: + self.on_iteration_done() + + @property + def restarted_on_train_batch_end(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_ON_TRAIN_BATCH_END + + @property + def restarted_on_last(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_ON_LAST + + def update_restart_stage(self) -> None: + if ( + self.restarting + and self.batch_progress.total.started == self.batch_progress.total.ready + and self.batch_progress.total.processed == self.batch_progress.total.started + and self.batch_progress.total.completed == self.batch_progress.total.processed - 1 + ): + self._restart_stage = RestartStage.RESTARTED_ON_TRAIN_BATCH_END + elif ( + self.restarting + and self.batch_progress.total.started == self.batch_progress.total.ready + and self.batch_progress.total.processed == self.batch_progress.total.started + and self.batch_progress.total.completed == self.batch_progress.total.processed + ): + self._restart_stage = RestartStage.RESTARTED_ON_LAST + else: + self._restart_stage = RestartStage.NONE + + self.val_loop.update_restart_stage() + + def reset_restart_stage(self) -> None: + self._restart_stage = RestartStage.NONE def reset(self) -> None: """Resets the internal state of the loop for a new run.""" + if ( + self.restarting + and not self._should_accumulate() + and (self.restarted_on_train_batch_end or not self.restarted_on_last) + ): + # batches_that_stepped is never set prior to saving a checkpoint, even when saving + # happens on_validation_end + # we could set it in the checkpoint but we prefer to keep checkpoints backward compatible + self._batches_that_stepped += 1 + + if self.restarted_on_train_batch_end: + self.batch_progress.increment_completed() + # handle situation in which save happened on_train_batch_end and epoch is at end + if self.batch_progress.current.completed >= self.trainer.num_training_batches: + self.batch_progress.reset_on_run() + self.scheduler_progress.reset_on_run() + self.automatic_optimization.optim_progress.reset_on_run() + 
self.val_loop.batch_progress.total.reset() + if self.restarting: self.batch_progress.reset_on_restart() self.scheduler_progress.reset_on_restart() @@ -197,8 +257,18 @@ def advance(self, data_fetcher: _DataFetcher) -> None: """ if self.restarting and self._should_check_val_fx(data_fetcher): - # skip training and run validation in `on_advance_end` - return + if self.val_loop.restarted_mid_evaluation: + # Go back and finish running validation + return + + if self.restarted_on_last: + # Avoid running validation again if we saved on last + self._skip_next_val = True + return + + # fast forward progress counters to end of validation + self.val_loop.increment_progress_to_evaluation_end() + # we are going to train first so the val loop does not need to restart self.val_loop.restarting = False @@ -282,6 +352,11 @@ def on_advance_end(self, data_fetcher: _DataFetcher) -> None: # VALIDATE IF NEEDED # ----------------------------------------- should_check_val = self._should_check_val_fx(data_fetcher) + + if self._skip_next_val: + should_check_val = False + self._skip_next_val = False + if should_check_val: # this needs to be set so the correct `trainer._active_loop` is picked self.trainer.validating = True diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index ff317cd2e18ba..8d94275b5b245 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -14,7 +14,7 @@ import os from copy import deepcopy from dataclasses import dataclass -from typing import Dict, Iterator +from typing import Any, Dict, Iterator from unittest.mock import ANY, Mock import pytest @@ -25,6 +25,7 @@ from lightning.pytorch.loops import _Loop from lightning.pytorch.loops.progress import _BaseProgress from lightning.pytorch.utilities import CombinedLoader +from lightning.pytorch.utilities.types import STEP_OUTPUT from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter from tests_pytorch.helpers.runif import RunIf @@ -396,12 +397,13 @@ def training_step(self, batch, batch_idx): assert state_dict == checkpoint["loops"]["fit_loop"] trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"]) - # test resetting manually, we expect all `ready` counters to be reset to `completed` + # test resetting manually, we expect the `ready` counter for batch to be reset to `completed` + # but the `ready` counter for epoch to not be reset, since we are still mid epoch trainer.fit_loop.reset() trainer.fit_loop.epoch_loop.reset() epoch_progress = trainer.fit_loop.epoch_progress - assert epoch_progress.current.ready == stop_epoch + assert epoch_progress.current.ready == stop_epoch + 1 assert epoch_progress.current.completed == stop_epoch batch_progress = trainer.fit_loop.epoch_loop.batch_progress @@ -417,7 +419,7 @@ def training_step(self, batch, batch_idx): state_dict = trainer.fit_loop.state_dict() assert state_dict != checkpoint["loops"]["fit_loop"] assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1 - assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch + assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch + 1 def test_loop_state_on_complete_run(tmp_path): @@ -557,23 +559,38 @@ def test_fit_loop_reset(tmp_path): # we load exactly what was saved - no reset yet fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"]) - # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0 + + assert fit_loop.restarting + assert 
fit_loop.epoch_progress.total.ready == 1 + assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint was saved mid epoch + assert fit_loop.epoch_progress.current.ready == 1 + assert fit_loop.epoch_progress.current.completed == 0 + + assert epoch_loop.batch_progress.total.ready == 2 + assert epoch_loop.batch_progress.total.processed == 2 + assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end + assert epoch_loop.batch_progress.current.ready == 2 # currents get set to the completed value + assert epoch_loop.batch_progress.current.processed == 2 + assert epoch_loop.batch_progress.current.completed == 1 + fit_loop.reset() epoch_loop.reset() + # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0 assert fit_loop.restarting assert fit_loop.epoch_progress.total.ready == 1 assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint was saved mid epoch - assert fit_loop.epoch_progress.current.ready == 0 + assert fit_loop.epoch_progress.current.ready == 1 assert fit_loop.epoch_progress.current.completed == 0 + # however it should increment completed batch progress, since it was saved immediately prior assert epoch_loop.restarting assert epoch_loop.batch_progress.total.ready == 2 assert epoch_loop.batch_progress.total.processed == 2 - assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end - assert epoch_loop.batch_progress.current.ready == 1 # currents get set to the completed value - assert epoch_loop.batch_progress.current.processed == 1 - assert epoch_loop.batch_progress.current.completed == 1 + assert epoch_loop.batch_progress.total.completed == 2 + assert epoch_loop.batch_progress.current.ready == 2 + assert epoch_loop.batch_progress.current.processed == 2 + assert epoch_loop.batch_progress.current.completed == 2 assert optimizer_loop.restarting @@ -587,23 +604,326 @@ def test_fit_loop_reset(tmp_path): # we load exactly what was saved - no reset yet fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"]) + + assert fit_loop.restarting + assert fit_loop.epoch_progress.total.ready == 1 + assert fit_loop.epoch_progress.total.completed == 0 + assert fit_loop.epoch_progress.current.ready == 1 + assert fit_loop.epoch_progress.current.completed == 0 + # resetting from a end-of-epoch checkpoint SHOULD reset the current counters to 0 fit_loop.reset() epoch_loop.reset() + # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0 + # since we are restarting at the end of epoch, we need to see `completed` being updated after reset assert fit_loop.restarting assert fit_loop.epoch_progress.total.ready == 1 - assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint saves before the epoch completes - assert fit_loop.epoch_progress.current.ready == 0 - assert fit_loop.epoch_progress.current.completed == 0 + assert fit_loop.epoch_progress.total.completed == 1 + assert fit_loop.epoch_progress.current.ready == 1 + assert fit_loop.epoch_progress.current.completed == 1 + # however it should increment completed batch progress, since it was saved immediately prior assert epoch_loop.restarting assert epoch_loop.batch_progress.total.ready == 4 assert epoch_loop.batch_progress.total.processed == 4 - assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end - assert epoch_loop.batch_progress.current.ready == 3 # currents get set to the completed value - assert 
epoch_loop.batch_progress.current.processed == 3 - assert epoch_loop.batch_progress.current.completed == 3 + assert epoch_loop.batch_progress.total.completed == 4 + assert epoch_loop.batch_progress.current.ready == 0 + assert epoch_loop.batch_progress.current.processed == 0 + assert epoch_loop.batch_progress.current.completed == 0 + + +def compare_state_dicts(dict1, dict2): + def compare_leaves(d1, d2): + result = {} + all_keys = set(d1.keys()).union(d2.keys()) + + for key in all_keys: + val1 = d1.get(key, None) + val2 = d2.get(key, None) + + if isinstance(val1, dict) and isinstance(val2, dict): + res = compare_leaves(val1, val2) + if res: + result[key] = res + elif isinstance(val1, dict) or isinstance(val2, dict): + raise ValueError("dicts have different leaves") + elif isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor): + diff = torch.norm(val1 - val2) + if diff > 1e-8: + result[key] = f"{diff} > 1e-8" + elif isinstance(val1, float) and isinstance(val2, float): + if abs(val1 - val2) > 1e-8: + result[key] = f"{val1} != {val2}" + elif val1 != val2: + result[key] = f"{val1} != {val2}" + return result + + return compare_leaves(dict1, dict2) + + +class RangeDataset(torch.utils.data.Dataset): + def __init__(self, size: int, length: int): + self.len = length + data = torch.arange(0, size) / size + self.data = data.unsqueeze(0).repeat(length, 1) + + def __getitem__(self, index: int) -> torch.Tensor: + return self.data[index] + + def __len__(self) -> int: + return self.len + + +class PredictableBoringModel(BoringModel): + def __init__(self) -> None: + super().__init__() + self.last_loss = float("inf") + + def train_dataloader(self) -> DataLoader: + return DataLoader(RangeDataset(32, 64)) + + def val_dataloader(self) -> DataLoader: + return DataLoader(RangeDataset(32, 64)) + + def test_dataloader(self) -> DataLoader: + return DataLoader(RangeDataset(32, 64)) + + def predict_dataloader(self) -> DataLoader: + return DataLoader(RangeDataset(32, 64)) + + def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT: + loss = self.step(batch) + self.last_loss = loss + return {"loss": loss} + + +def test_restart_parity(tmp_path): + model = PredictableBoringModel() + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + every_n_train_steps=2, + save_top_k=-1, + ) + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=4, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + trainer.fit(model) + loss = model.last_loss + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=4, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt")) + loss_v1 = model.last_loss + + assert abs(loss - loss_v1) < 1e-8 + + end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"), weights_only=True) + end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=0-step=4-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(end_of_epoch_ckpt["loops"], end_of_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"] + assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], 
end_of_epoch_ckpt_v1["state_dict"]) == {} + + mid_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=6.ckpt"), weights_only=True) + mid_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=6-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(mid_epoch_ckpt["loops"], mid_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(mid_epoch_ckpt["lr_schedulers"][0], mid_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert mid_epoch_ckpt["epoch"] == mid_epoch_ckpt_v1["epoch"] + assert mid_epoch_ckpt["global_step"] == mid_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(mid_epoch_ckpt["state_dict"], mid_epoch_ckpt_v1["state_dict"]) == {} + + end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=8.ckpt"), weights_only=True) + end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=8-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(end_of_epoch_ckpt["loops"], end_of_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"] + assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {} + + +def test_restart_with_val_parity(tmp_path): + model = PredictableBoringModel() + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + every_n_train_steps=2, + save_top_k=-1, + ) + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=4, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + limit_val_batches=4, + val_check_interval=2, + ) + trainer.fit(model) + loss = model.last_loss + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=4, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + limit_val_batches=4, + val_check_interval=2, + ) + trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt")) + loss_v1 = model.last_loss + + assert abs(loss - loss_v1) < 1e-8 + + end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"), weights_only=True) + end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=0-step=4-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(end_of_epoch_ckpt["loops"], end_of_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"] + assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {} + + mid_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=6.ckpt"), weights_only=True) + mid_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=6-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(mid_epoch_ckpt["loops"], mid_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(mid_epoch_ckpt["lr_schedulers"][0], mid_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert mid_epoch_ckpt["epoch"] == mid_epoch_ckpt_v1["epoch"] + assert mid_epoch_ckpt["global_step"] == mid_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(mid_epoch_ckpt["state_dict"], mid_epoch_ckpt_v1["state_dict"]) == {} + + end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=8.ckpt"), 
weights_only=True) + end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=8-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(end_of_epoch_ckpt["loops"], end_of_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"] + assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {} + + +def test_restart_from_last_parity(tmp_path): + model = PredictableBoringModel() + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + save_last=True, + save_top_k=-1, + ) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + trainer.fit(model) + + last_ckpt_1 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=2, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + trainer.fit(model) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + trainer.fit(model, ckpt_path=str(tmp_path / "last.ckpt")) + + last_ckpt_2 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True) + + assert compare_state_dicts(last_ckpt_1["loops"], last_ckpt_2["loops"]) == {} + + +def test_restart_from_last_with_val_parity(tmp_path): + model = PredictableBoringModel() + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + save_last=True, + save_top_k=-1, + ) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + limit_val_batches=2, + val_check_interval=2, + ) + trainer.fit(model) + + last_ckpt_1 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=2, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + limit_val_batches=2, + val_check_interval=2, + ) + trainer.fit(model) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + limit_val_batches=2, + val_check_interval=2, + ) + trainer.fit(model, ckpt_path=str(tmp_path / "last.ckpt")) + + last_ckpt_2 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True) + + assert compare_state_dicts(last_ckpt_1["loops"], last_ckpt_2["loops"]) == {} @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 5a175e181dd9e..685bd6c0bdaef 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -660,8 +660,6 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path): {"name": "train_dataloader"}, {"name": "Callback.on_train_start", "args": (trainer, model)}, {"name": "on_train_start"}, - {"name": "Callback.on_train_epoch_start", "args": 
(trainer, model)}, - {"name": "on_train_epoch_start"}, *model._train_batch(trainer, model, steps_after_reload, trainer.strategy.root_device, current_batch=1), {"name": "Callback.on_train_epoch_end", "args": (trainer, model)}, {"name": "on_train_epoch_end"}, # before ModelCheckpoint because it's a "monitoring callback"
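
Note (illustrative sketch, not part of the patch): the restart handling above infers where a checkpoint was written by comparing the ready/started/processed/completed progress counters. The standalone snippet below mirrors the counter conditions from _TrainingEpochLoop.update_restart_stage (the real implementation additionally requires self.restarting to be set); Counters and classify_restart are hypothetical names used only for illustration, not Lightning API.

    from dataclasses import dataclass

    @dataclass
    class Counters:
        ready: int = 0
        started: int = 0
        processed: int = 0
        completed: int = 0

    def classify_restart(c: Counters) -> str:
        # Checkpoint written in on_train_batch_end: the batch was fully processed,
        # but `completed` had not yet been incremented when the state was saved.
        if c.started == c.ready and c.processed == c.started and c.completed == c.processed - 1:
            return "restarted_on_train_batch_end"
        # All four counters agree: the checkpoint describes a fully completed batch,
        # e.g. a last.ckpt written after the iteration finished.
        if c.started == c.ready and c.processed == c.started and c.completed == c.processed:
            return "restarted_on_last"
        return "none"

    # Counter values asserted for the mid-epoch checkpoint in test_fit_loop_reset:
    print(classify_restart(Counters(ready=2, started=2, processed=2, completed=1)))
    # -> "restarted_on_train_batch_end"

In the "restarted_on_train_batch_end" case, reset() above increments the completed batch counter and _batches_that_stepped by one after restoring, which is what the updated assertions in test_fit_loop_reset check.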