Ensure restarting from checkpoints leads to consistent internal counters (#20379)

* Fix checkpoint progress for fit loop and batch loop

* Check loss parity

* Rename test

* Fix validation loop handling on restart

* Fix loop reset test

* Avoid skipping to val end if saved mid validation

* Fix type checks in compare state dicts

* Fix edge cases and start from last with and without val

* Clean up

* Formatting

* Avoid running validation when restarting from last

* Fix type annotations

* Fix formatting

* Ensure int max_batch

* Fix condition on batches that stepped

* Remove expected on_train_epoch_start when restarting mid epoch
lantiga authored Nov 13, 2024
1 parent 7038b8d commit 9358898
Showing 7 changed files with 584 additions and 32 deletions.
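
The sketch below is not part of this commit; it only illustrates the scenario the fix targets: training is interrupted after a mid-epoch checkpoint and then resumed, and the restored loop counters are expected to line up with an uninterrupted run. The checkpoint directory and the use of BoringModel are assumptions made for illustration.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel

# Save a checkpoint every couple of training batches so a restart can land mid-epoch
# (hypothetical dirpath).
ckpt = ModelCheckpoint(dirpath="ckpts", save_last=True, every_n_train_steps=2)

trainer = Trainer(max_epochs=2, limit_train_batches=4, limit_val_batches=2, callbacks=[ckpt])
trainer.fit(BoringModel())

# Resume from the last checkpoint; with this commit the restored epoch/batch progress
# (ready/started/processed/completed) matches the uninterrupted run.
resumed = Trainer(max_epochs=2, limit_train_batches=4, limit_val_batches=2, callbacks=[ckpt])
resumed.fit(BoringModel(), ckpt_path="ckpts/last.ckpt")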
38 changes: 37 additions & 1 deletion src/lightning/pytorch/loops/evaluation_loop.py
@@ -15,6 +15,7 @@
import shutil
import sys
from collections import ChainMap, OrderedDict, defaultdict
from dataclasses import dataclass
from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union

from lightning_utilities.core.apply_func import apply_to_collection
@@ -45,6 +46,12 @@
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature


@dataclass
class RestartStage:
NONE = "none"
RESTARTED_MID_EVALUATION = "restarted_mid_evaluation"


class _EvaluationLoop(_Loop):
"""Top-level loop where validation/testing starts."""

@@ -73,6 +80,7 @@ def __init__(
self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int)
self._last_val_dl_reload_epoch = float("-inf")
self._module_mode = _ModuleMode()
self._restart_stage = RestartStage.NONE

@property
def num_dataloaders(self) -> int:
@@ -137,7 +145,7 @@ def run(self) -> List[_OUT_DICT]:
# this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
break
finally:
self._restarting = False
self.on_iteration_done()
self._store_dataloader_outputs()
return self.on_run_end()

@@ -197,6 +205,24 @@ def setup_data(self) -> None:
# this depends on the data used, so reset it too
self._seen_batches_per_dataloader = defaultdict(int)

@property
def restarted_mid_evaluation(self) -> bool:
return self._restart_stage == RestartStage.RESTARTED_MID_EVALUATION

def update_restart_stage(self) -> None:
if (
self.restarting
and self.batch_progress.total.started == self.batch_progress.total.ready
and self.batch_progress.total.processed == self.batch_progress.total.started - 1
and self.batch_progress.total.completed == self.batch_progress.total.processed
):
self._restart_stage = RestartStage.RESTARTED_MID_EVALUATION
else:
self._restart_stage = RestartStage.NONE

def reset_restart_stage(self) -> None:
self._restart_stage = RestartStage.NONE

def reset(self) -> None:
"""Resets the internal state of the loop."""
trainer = self.trainer
@@ -236,6 +262,16 @@ def reset(self) -> None:
data_fetcher._stop_profiler = self._on_after_fetch
self._data_fetcher = data_fetcher

def increment_progress_to_evaluation_end(self) -> None:
self.setup_data()
if self.skip:
return
self.reset()
max_batch = int(max(self.max_batches))
if max_batch == -1:
return
self.batch_progress.increment_by(max_batch, True)

def on_run_start(self) -> None:
"""Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
hooks."""
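
For context, a minimal sketch (assuming the private lightning.pytorch.loops.progress API stays importable) of the counter pattern that update_restart_stage classifies as RESTARTED_MID_EVALUATION: a validation batch was fetched and started, but the checkpoint was written before it was processed and completed.

from lightning.pytorch.loops.progress import _BatchProgress

progress = _BatchProgress()
progress.increment_ready()    # batch fetched from the dataloader
progress.increment_started()  # the *_step call was entered but has not finished

# This is the state update_restart_stage() looks for when self.restarting is True:
assert progress.total.started == progress.total.ready
assert progress.total.processed == progress.total.started - 1
assert progress.total.completed == progress.total.processed
# -> RestartStage.RESTARTED_MID_EVALUATION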
107 changes: 99 additions & 8 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
@@ -45,6 +46,15 @@
log = logging.getLogger(__name__)


@dataclass
class RestartStage:
NONE = "none"
RESTARTED_ON_EPOCH_START = "restarted_on_epoch_start"
RESTARTED_MID_EPOCH = "restarted_mid_epoch"
RESTARTED_ON_EPOCH_END = "restarted_on_epoch_end"
RESUMED_ON_EPOCH_END = "resumed_on_epoch_end"


class _FitLoop(_Loop):
"""This loop is the top-level loop where training starts.
@@ -97,6 +107,7 @@ def __init__(
self._combined_loader_states_to_load: List[Dict[str, Any]] = []
self._data_fetcher: Optional[_DataFetcher] = None
self._last_train_dl_reload_epoch = float("-inf")
self._restart_stage = RestartStage.NONE

@property
def total_batch_idx(self) -> int:
@@ -204,9 +215,10 @@ def run(self) -> None:
self.on_advance_start()
self.advance()
self.on_advance_end()
self._restarting = False
except StopIteration:
break
finally:
self.on_iteration_done()
self._restarting = False
self.on_run_end()

@@ -302,14 +314,92 @@ def setup_data(self) -> None:
category=PossibleUserWarning,
)

@property
def restarted_on_epoch_start(self) -> bool:
return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_START

@property
def restarted_mid_epoch(self) -> bool:
return self._restart_stage == RestartStage.RESTARTED_MID_EPOCH

@property
def restarted_on_epoch_end(self) -> bool:
return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_END

@property
def resumed_on_epoch_end(self) -> bool:
# This case happens when restarting from last without validation at
# the end of epoch. In this case self.restarting is False.
return self._restart_stage == RestartStage.RESUMED_ON_EPOCH_END

def update_restart_stage(self) -> None:
if (
self.restarting
and self.epoch_progress.total.started == self.epoch_progress.total.ready - 1
and self.epoch_progress.total.processed == self.epoch_progress.total.started
and self.epoch_progress.total.completed == self.epoch_progress.total.processed
):
self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_START
elif (
self.restarting
and self.epoch_progress.total.started == self.epoch_progress.total.ready
and self.epoch_progress.total.processed == self.epoch_progress.total.started - 1
and self.epoch_progress.total.completed == self.epoch_progress.total.processed
):
self._restart_stage = RestartStage.RESTARTED_MID_EPOCH
elif (
self.restarting
and self.epoch_progress.total.started == self.epoch_progress.total.ready
and self.epoch_progress.total.processed == self.epoch_progress.total.started
and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1
):
self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_END
elif (
self._loaded_from_state_dict
and self.epoch_progress.total.started == self.epoch_progress.total.ready
and self.epoch_progress.total.processed == self.epoch_progress.total.started
and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1
):
self._restart_stage = RestartStage.RESUMED_ON_EPOCH_END
else:
self._restart_stage = RestartStage.NONE

self.epoch_loop.update_restart_stage()

def reset_restart_stage(self) -> None:
self._restart_stage = RestartStage.NONE

def reset(self) -> None:
"""Resets the internal state of this loop."""
assert self.trainer.model is not None
torch.set_grad_enabled(True)

if self.restarting:
self.update_restart_stage()

if self.restarted_on_epoch_start:
self.epoch_progress.reset_on_restart()

if self.resumed_on_epoch_end:
# when restarting from last without validation at end of epoch,
# self.restarting is False but it's still resuming
self.epoch_progress.increment_completed()

if (
self.epoch_loop.restarted_on_train_batch_end
and self.restarted_mid_epoch
and self.epoch_loop.batch_progress.is_last_batch
):
self.epoch_progress.increment_processed()
self.epoch_progress.increment_completed()

if (
self.epoch_loop.restarted_on_train_batch_end
and self.epoch_loop.batch_progress.is_last_batch
and not self.restarted_mid_epoch
and not self.epoch_loop.val_loop.batch_progress.is_last_batch
):
self.epoch_progress.increment_completed()

def on_run_start(self) -> None:
"""Calls the ``on_train_start`` hook."""
# update the current_epoch in-case of checkpoint reload
@@ -340,12 +430,14 @@ def on_advance_start(self) -> None:
for i, dl in enumerate(self._combined_loader.flattened):
_set_sampler_epoch(dl, self.epoch_progress.current.processed)

self.epoch_progress.increment_ready()
if not self.restarted_mid_epoch and not self.restarted_on_epoch_end:
if not self.restarted_on_epoch_start:
self.epoch_progress.increment_ready()

call._call_callback_hooks(trainer, "on_train_epoch_start")
call._call_lightning_module_hook(trainer, "on_train_epoch_start")
call._call_callback_hooks(trainer, "on_train_epoch_start")
call._call_lightning_module_hook(trainer, "on_train_epoch_start")

self.epoch_progress.increment_started()
self.epoch_progress.increment_started()

def advance(self) -> None:
"""Runs one whole epoch."""
@@ -379,8 +471,7 @@ def on_advance_end(self) -> None:

trainer._logger_connector.on_epoch_end()

if self.epoch_loop._num_ready_batches_reached():
# if we are restarting and the above condition holds, it's because we are reloading an epoch-end checkpoint.
if not self.restarting and self.epoch_loop._num_ready_batches_reached():
# since metric-based schedulers require access to metrics and those are not currently saved in the
# checkpoint, the plateau schedulers shouldn't be updated
self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting)
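
The restart stages added to _FitLoop correspond to distinct epoch_progress states. A hedged illustration, with counters constructed by hand rather than read from a real checkpoint:

from lightning.pytorch.loops.progress import _Progress

epoch_progress = _Progress()

epoch_progress.increment_ready()
# ready is one ahead of started/processed/completed
# -> RESTARTED_ON_EPOCH_START (checkpoint saved right after the epoch became ready)

epoch_progress.increment_started()
# ready == started, processed lags by one
# -> RESTARTED_MID_EPOCH (checkpoint saved while batches were still being processed)

epoch_progress.increment_processed()
# ready == started == processed, completed lags by one
# -> RESTARTED_ON_EPOCH_END, or RESUMED_ON_EPOCH_END when the loop was loaded
#    from a state dict without self.restarting being set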
10 changes: 10 additions & 0 deletions src/lightning/pytorch/loops/loop.py
@@ -22,6 +22,7 @@ class _Loop:

def __init__(self, trainer: "pl.Trainer") -> None:
self._restarting = False
self._loaded_from_state_dict = False
self.trainer = trainer

@property
Expand All @@ -37,6 +38,9 @@ def restarting(self, restarting: bool) -> None:
if isinstance(loop, _Loop):
loop.restarting = restarting

def reset_restart_stage(self) -> None:
pass

def on_save_checkpoint(self) -> Dict:
"""Called when saving a model checkpoint, use to persist loop state.
@@ -82,6 +86,7 @@ def load_state_dict(
if isinstance(v, _Loop):
v.load_state_dict(state_dict.copy(), prefix + k + ".")
self.restarting = True
self._loaded_from_state_dict = True

def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None:
for k, v in self.__dict__.items():
@@ -93,3 +98,8 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None:
v.load_state_dict(state_dict[key])
if prefix + "state_dict" in state_dict: # compatibility with old checkpoints
self.on_load_checkpoint(state_dict[prefix + "state_dict"])

def on_iteration_done(self) -> None:
self._restarting = False
self._loaded_from_state_dict = False
self.reset_restart_stage()
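
A small illustration, under the assumption that the base _Loop can be exercised standalone, of the new flag lifecycle: load_state_dict marks the loop as restarting and as loaded from a state dict, and on_iteration_done clears both and resets the restart stage.

import lightning.pytorch as pl
from lightning.pytorch.loops.loop import _Loop

loop = _Loop(pl.Trainer())
state = loop.state_dict()

loop.load_state_dict(state)
assert loop.restarting and loop._loaded_from_state_dict

loop.on_iteration_done()
assert not loop.restarting and not loop._loaded_from_state_dict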
22 changes: 22 additions & 0 deletions src/lightning/pytorch/loops/progress.py
@@ -68,6 +68,10 @@ def reset_on_restart(self) -> None:
"""
self.ready = self.completed

def increment_by(self, n: int) -> None:
self.ready += n
self.completed += n


@dataclass
class _StartedTracker(_ReadyCompletedTracker):
@@ -94,6 +98,11 @@ def reset_on_restart(self) -> None:
super().reset_on_restart()
self.started = self.completed

@override
def increment_by(self, n: int) -> None:
super().increment_by(n)
self.started += n


@dataclass
class _ProcessedTracker(_StartedTracker):
@@ -121,6 +130,11 @@ def reset_on_restart(self) -> None:
super().reset_on_restart()
self.processed = self.completed

@override
def increment_by(self, n: int) -> None:
super().increment_by(n)
self.processed += n


@dataclass
class _Progress(_BaseProgress):
@@ -175,6 +189,10 @@ def reset_on_run(self) -> None:
def reset_on_restart(self) -> None:
self.current.reset_on_restart()

def increment_by(self, n: int) -> None:
self.total.increment_by(n)
self.current.increment_by(n)

@override
def load_state_dict(self, state_dict: dict) -> None:
self.total.load_state_dict(state_dict["total"])
@@ -206,6 +224,10 @@ def reset_on_run(self) -> None:
super().reset_on_run()
self.is_last_batch = False

def increment_by(self, n: int, is_last_batch: bool = False) -> None:
super().increment_by(n)
self.is_last_batch = is_last_batch

@override
def load_state_dict(self, state_dict: dict) -> None:
super().load_state_dict(state_dict)
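
A sketch of the new increment_by helper (values are illustrative): it advances ready, started, processed and completed together and can mark the last batch, which is what lets _EvaluationLoop.increment_progress_to_evaluation_end fast-forward validation progress when a restart should skip straight past validation.

from lightning.pytorch.loops.progress import _BatchProgress

progress = _BatchProgress()
progress.increment_by(5, is_last_batch=True)

# Every tracker field moved forward together and the last-batch flag was set.
assert progress.total.ready == progress.total.completed == 5
assert progress.current.processed == 5 and progress.is_last_batch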