[RFC] Add self.lr_schedulers() to LightningModule for manual optimization (#6567)

* Add test for lr_schedulers()

* Add lr_schedulers to LightningModule

* Update test comment

* Update CHANGELOG
akihironitta authored Apr 9, 2021
1 parent 9c9e2a0 commit 5e4dfd7
Showing 3 changed files with 55 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `model` parameter to precision plugins' `clip_gradients` signature ([#6764](https://github.com/PyTorchLightning/pytorch-lightning/pull/6764))


- Added `LightningModule.lr_schedulers()` for manual optimization ([#6567](https://github.com/PyTorchLightning/pytorch-lightning/pull/6567))


### Changed

- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
14 changes: 14 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -119,6 +119,20 @@ def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Opt
        # multiple opts
        return opts

    def lr_schedulers(self) -> Optional[Union[Any, List[Any]]]:
        if not self.trainer.lr_schedulers:
            return None

        # ignore other keys "interval", "frequency", etc.
        lr_schedulers = [s["scheduler"] for s in self.trainer.lr_schedulers]

        # single scheduler
        if len(lr_schedulers) == 1:
            return lr_schedulers[0]

        # multiple schedulers
        return lr_schedulers

    @property
    def example_input_array(self) -> Any:
        return self._example_input_array
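For context, a minimal sketch of how the new accessor is intended to be used together with `self.optimizers()` inside a manually optimized `training_step`. The model, loss, and scheduler choice below are illustrative assumptions and are not part of this commit:

```python
import torch
from pytorch_lightning import LightningModule


class ManualOptModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # opt into manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        sch = self.lr_schedulers()  # single scheduler is returned directly

        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        sch.step()  # in manual optimization, the user steps the scheduler

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]
```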
38 changes: 38 additions & 0 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -1147,3 +1147,41 @@ def dis_closure():
@RunIf(min_gpus=2, special=True)
def test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model(tmpdir):
    train_manual_optimization(tmpdir, "ddp", model_cls=TestManualOptimizationDDPModelToggleModel)


def test_lr_schedulers(tmpdir):
    """
    Test that `lr_schedulers()` returns the same scheduler objects,
    in the same order, as `configure_optimizers()` returns them.
    """

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.automatic_optimization = False

        def training_step(self, batch, batch_idx):
            scheduler_1, scheduler_2 = self.lr_schedulers()
            assert scheduler_1 is self.scheduler_1
            assert scheduler_2 is self.scheduler_2

        def configure_optimizers(self):
            optimizer_1 = torch.optim.SGD(self.parameters(), lr=0.1)
            optimizer_2 = torch.optim.SGD(self.parameters(), lr=0.1)
            self.scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
            self.scheduler_2 = torch.optim.lr_scheduler.StepLR(optimizer_2, step_size=1)
            return [optimizer_1, optimizer_2], [self.scheduler_1, self.scheduler_2]

    model = TestModel()
    model.training_epoch_end = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
    )

    trainer.fit(model)
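The test above returns plain scheduler instances from `configure_optimizers()`. The `# ignore other keys "interval", "frequency", etc.` comment in `lr_schedulers()` refers to the scheduler dictionary format, which the accessor also handles by unwrapping the `"scheduler"` entry. A hedged sketch of that variant, written as a drop-in replacement for the `configure_optimizers` in the test model above (only the keys `"scheduler"`, `"interval"`, and `"frequency"` come from the commit itself; the rest is illustrative):

```python
def configure_optimizers(self):
    optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
    # lr_schedulers() unwraps the "scheduler" entry and ignores the
    # remaining configuration keys such as "interval" and "frequency".
    return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]
```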
