Ensure restarting from checkpoints leads to consistent internal counters #20379
Conversation
⚡ Required checks status: All passing 🟢

Groups summary:
- 🟢 pytorch_lightning: Tests workflow
- 🟢 pytorch_lightning: Azure GPU
- 🟢 pytorch_lightning: Benchmarks
- 🟢 pytorch_lightning: Docs
- 🟢 mypy
- 🟢 install

These checks are required after the changes. Thank you for your contribution! 💜
Force-pushed 5e313ee to 0012dcb
Force-pushed d303d27 to c3469be
Codecov Report — Attention: Patch coverage is …

Additional details and impacted files:

@@            Coverage Diff            @@
##           master   #20379     +/-   ##
=========================================
- Coverage      89%      81%       -8%
=========================================
  Files         267      264        -3
  Lines       23070    23147       +77
=========================================
- Hits        20579    18739     -1840
- Misses       2491     4408     +1917
…ers (Lightning-AI#20379)
* Fix checkpoint progress for fit loop and batch loop
* Check loss parity
* Rename test
* Fix validation loop handling on restart
* Fix loop reset test
* Avoid skipping to val end if saved mid validation
* Fix type checks in compare state dicts
* Fix edge cases and start from last with and without val
* Clean up
* Formatting
* Avoid running validation when restarting from last
* Fix type annotations
* Fix formatting
* Ensure int max_batch
* Fix condition on batches that stepped
* Remove expected on_train_epoch_start when restarting mid epoch
- removed `_maybe_sync_loops` after Lightning-AI/pytorch-lightning#20379 obviated the need for it
Does this also fix the same issue appearing when …
What does this PR do?
Fixes #14579
The following code will produce skewed progress information in the checkpoints, compared to the case where there is no restart.
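The original snippet is not shown here; as a stand-in, here is a minimal sketch of the kind of setup described, assuming a mid-epoch `ModelCheckpoint` cadence and Lightning's default checkpoint filenames (`BoringModel` is Lightning's demo model; the cadence, batch counts, and checkpoint path are illustrative, not from the PR):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel

# Save a checkpoint every few training batches so some checkpoints land mid-epoch.
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints",
    every_n_train_steps=2,  # hypothetical cadence; any mid-epoch save triggers the issue
    save_top_k=-1,          # keep all checkpoints
)

trainer = Trainer(max_epochs=1, limit_train_batches=6, callbacks=[checkpoint_cb])
trainer.fit(BoringModel())

# Resume from a mid-epoch checkpoint; before this PR, the loop counters in
# subsequent checkpoints drift relative to an uninterrupted run.
resumed = Trainer(max_epochs=2, limit_train_batches=6, callbacks=[checkpoint_cb])
resumed.fit(BoringModel(), ckpt_path="checkpoints/epoch=0-step=2.ckpt")
```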
This is due to the fact that when `ModelCheckpoint` is triggered on `on_train_batch_end`, it won't see `batch_progress.total.completed` updated to the latest batch that was processed, because progress is updated right after the hook is called. However, upon restart there is no opportunity to register the actual completion of that batch, causing a skew that is proportional to the number of restarts. This affects the time at which validation is called, which itself becomes dependent on the number of restarts.
This PR addresses this issue by reconciling progress upon restart.
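As a rough illustration of what such reconciliation means (a hypothetical helper, not the PR's actual implementation), the idea is that on restart any batch already recorded as processed should also count as completed:

```python
# Illustrative sketch of the reconciliation idea: after loading loop state,
# bump `completed` up to `processed` so the counters match what an
# uninterrupted run would have recorded.
def reconcile_progress(total: dict) -> dict:
    """Hypothetical helper applied to a progress dict on restart."""
    if total["completed"] < total["processed"]:
        total["completed"] = total["processed"]
    return total

print(reconcile_progress({"ready": 2, "started": 2, "processed": 2, "completed": 1}))
# -> {'ready': 2, 'started': 2, 'processed': 2, 'completed': 2}
```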
It adds tests and tightens the behavior when restarting in multiple cases, depending on when the checkpoints were saved.
📚 Documentation preview 📚: https://pytorch-lightning--20379.org.readthedocs.build/en/20379/
cc @Borda