
Commit

Fix setting up of ReduceLROnPlateau learning rate scheduler (NVIDIA#5444)

* Fix tests

Signed-off-by: PeganovAnton <peganoff2@mail.ru>

* Add accidentally lost changes

Signed-off-by: PeganovAnton <peganoff2@mail.ru>

Signed-off-by: PeganovAnton <peganoff2@mail.ru>
Signed-off-by: Hainan Xu <hainanx@nvidia.com>
PeganovAnton authored and Hainan Xu committed Nov 29, 2022
1 parent 0663229 commit 32d0727
Showing 2 changed files with 46 additions and 0 deletions.
1 change: 1 addition & 0 deletions nemo/core/config/schedulers.py
@@ -284,4 +284,5 @@ def get_scheduler_config(name: str, **kwargs: Optional[Dict[str, Any]]) -> Sched
    'WarmupAnnealingParams': WarmupAnnealingParams,
    'PolynomialDecayAnnealingParams': PolynomialDecayAnnealingParams,
    'PolynomialHoldDecayAnnealingParams': PolynomialHoldDecayAnnealingParams,
    'ReduceLROnPlateauParams': ReduceLROnPlateauParams,
}
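For context: the added registry entry is what get_scheduler_config (shown in the hunk header above) uses to resolve scheduler parameter classes by name. A minimal usage sketch, assuming the function simply looks the name up in this dictionary and that calling it without keyword overrides is valid; neither is confirmed by this diff.

# Sketch only: resolve the newly registered parameter config by name.
# Before this commit, 'ReduceLROnPlateauParams' was absent from the registry,
# so this lookup could not succeed.
from nemo.core.config.schedulers import get_scheduler_config

reduce_on_plateau_params = get_scheduler_config('ReduceLROnPlateauParams')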
45 changes: 45 additions & 0 deletions tests/core/test_optimizers_schedulers.py
@@ -309,6 +309,51 @@ def test_sched_config_parse_from_cls(self):
        scheduler_setup = optim.lr_scheduler.prepare_lr_scheduler(opt, dict_config)
        assert isinstance(scheduler_setup['scheduler'], optim.lr_scheduler.CosineAnnealing)

    @pytest.mark.unit
    def test_sched_config_parse_reduce_on_plateau(self):
        model = TempModel()
        opt_cls = optim.get_optimizer('novograd')
        opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)
        reduce_on_plateau_parameters = {
            'mode': 'min',
            'factor': 0.5,
            'patience': 1,
            'threshold': 1e-4,
            'threshold_mode': 'rel',
            'min_lr': 1e-6,
            'eps': 1e-7,
            'verbose': True,
            'cooldown': 1,
        }
        basic_sched_config = {
            'name': 'ReduceLROnPlateau',
            'monitor': 'val_loss',
            'reduce_on_plateau': True,
            'max_steps': self.MAX_STEPS,
        }
        basic_sched_config.update(reduce_on_plateau_parameters)
        scheduler_setup = optim.lr_scheduler.prepare_lr_scheduler(opt, basic_sched_config)
        assert isinstance(scheduler_setup['scheduler'], torch.optim.lr_scheduler.ReduceLROnPlateau)
        for k, v in reduce_on_plateau_parameters.items():
            if k == 'min_lr':
                k += 's'
                v = [v]
            found_v = getattr(scheduler_setup['scheduler'], k)
            assert (
                found_v == v
            ), f"Wrong value `{repr(found_v)}` for `ReduceLROnPlateau` parameter `{k}`. Expected `{repr(v)}`."
        dict_config = omegaconf.OmegaConf.create(basic_sched_config)
        scheduler_setup = optim.lr_scheduler.prepare_lr_scheduler(opt, dict_config)
        assert isinstance(scheduler_setup['scheduler'], torch.optim.lr_scheduler.ReduceLROnPlateau)
        for k, v in reduce_on_plateau_parameters.items():
            if k == 'min_lr':
                k += 's'
                v = [v]
            found_v = getattr(scheduler_setup['scheduler'], k)
            assert (
                found_v == v
            ), f"Wrong value `{repr(found_v)}` for `ReduceLROnPlateau` parameter `{k}`. Expected `{repr(v)}`."

    @pytest.mark.unit
    def test_WarmupPolicy(self):
        model = TempModel()
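One detail worth noting in test_sched_config_parse_reduce_on_plateau above: when comparing the 'min_lr' value, the key is renamed to 'min_lrs' and the value is wrapped in a list. That mirrors how PyTorch's ReduceLROnPlateau stores a scalar min_lr internally, as one entry per optimizer parameter group. A small standalone sketch of that PyTorch behavior (independent of NeMo):

import torch

# A single parameter group, so min_lrs holds exactly one entry.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, min_lr=1e-6)

# The scalar min_lr argument is stored as a per-group list named min_lrs.
assert scheduler.min_lrs == [1e-6]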
