Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the condition for calling update_learning_rates #7032

Merged
merged 7 commits into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Remove hardcoding of local rank in accelerator connector ([#6878](https://github.com/PyTorchLightning/pytorch-lightning/pull/6878))


- Fixed incorrect number of calls to LR scheduler when `check_val_every_n_epoch > 1` ([#7032](https://github.com/PyTorchLightning/pytorch-lightning/pull/7032))


## [1.2.7] - 2021-04-06

### Fixed
Expand Down
4 changes: 1 addition & 3 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,6 @@ def run_training_epoch(self):

train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
dataloader_idx = 0
val_loop_called = False

batch_idx = None
is_last_batch = None
Expand Down Expand Up @@ -514,7 +513,6 @@ def run_training_epoch(self):
self.trainer.validating = True
self.trainer._run_evaluation()
self.trainer.training = True
val_loop_called = True

# -----------------------------------------
# SAVE LOGGERS (ie: Tensorboard, etc...)
Expand Down Expand Up @@ -563,7 +561,7 @@ def run_training_epoch(self):
should_train_only = self.trainer.disable_validation or should_skip_eval

# update epoch level lr_schedulers if no val loop outside train loop is triggered
if (val_loop_called and not should_check_val) or should_train_only:
if not should_check_val or should_train_only:
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

if should_train_only:
Expand Down
42 changes: 31 additions & 11 deletions tests/trainer/optimization/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

import pytest
import torch
from torch import optim
Expand Down Expand Up @@ -577,21 +579,21 @@ def configure_optimizers(self):
trainer.fit(model)


class TestModel(BoringModel):
@RunIf(min_gpus=2, special=True)
def test_optimizer_state_on_device(tmpdir):
""" Test that optimizers that create state initially at instantiation still end up with the state on the GPU. """

def configure_optimizers(self):
# Adagrad creates state tensors immediately, model is not yet on GPU.
return optim.Adagrad(self.parameters())
class TestModel(BoringModel):

def on_train_start(self, *args, **kwargs):
opt = self.optimizers()
_, state = next(iter(opt.state.items()))
assert state["sum"].device == torch.device("cuda", self.local_rank) == self.device
def configure_optimizers(self):
# Adagrad creates state tensors immediately, model is not yet on GPU.
return optim.Adagrad(self.parameters())

def on_train_start(self, *args, **kwargs):
opt = self.optimizers()
_, state = next(iter(opt.state.items()))
assert state["sum"].device == torch.device("cuda", self.local_rank) == self.device

@RunIf(min_gpus=2, special=True)
def test_optimizer_state_on_device(tmpdir):
""" Test that optimizers that create state initially at instantiation still end up with the state on the GPU. """
model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
Expand All @@ -600,3 +602,21 @@ def test_optimizer_state_on_device(tmpdir):
fast_dev_run=True,
)
trainer.fit(model)


@pytest.mark.parametrize("check_val_every_n_epoch", [1, 2])
@mock.patch("torch.optim.lr_scheduler.StepLR.step")
def test_lr_scheduler_epoch_step_frequency(mocked_sched, check_val_every_n_epoch, tmpdir):
epochs = 4
expected_steps = epochs + 1 # every LRScheduler gets called once at init

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
check_val_every_n_epoch=check_val_every_n_epoch,
max_epochs=epochs,
)
trainer.fit(model)
assert mocked_sched.call_count == expected_steps
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noob question: why do we want to update schedulers all epochs even if we only run validation some of them?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

running validation on some epoch is just a choice. Ideally, a scheduler should update after every epoch if scheduler['frequency] = 1(default) and scheduler['interval'] == 'epoch'. If someone wants it to align it with validation, they can set scheduler['frequency'] = check_val_every_n_epoch.