Skip to content

Commit

Permalink
Fix toggle optimizer (#5775)
Browse files Browse the repository at this point in the history
* Update lightning.py

* update changelog

* add a 3 optimizer test

* resolve flake8

* remove extra code

* typo

* resolve typo

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Feb 4, 2021
1 parent e8c1755 commit 0b7f5a8
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 15 deletions.
15 changes: 14 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,19 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [unreleased] - YYYY-MM-DD

### Added

### Changed

### Deprecated

### Removed

### Fixed

- Fixed `toggle_optimizers` not handling all optimizer parameters ([#5775](https://github.com/PyTorchLightning/pytorch-lightning/pull/5775))

## [1.1.7] - 2021-02-03

Expand Down Expand Up @@ -32,7 +45,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed FileNotFoundError for best checkpoint when using DDP with Hydra ([#5629](https://github.com/PyTorchLightning/pytorch-lightning/pull/5629))
- Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620))
- Fixed `Metric`'s `state_dict` not included when child modules ([#5614](https://github.com/PyTorchLightning/pytorch-lightning/pull/5614))
- Fixed Neptune logger creating multiple experiments when GPUs > 1 ([#3256](https://github.com/PyTorchLightning/pytorch-lightning/pull/3256))
- Fixed Neptune logger creating multiple experiments when GPUs > 1 ([#3256](https://github.com/PyTorchLightning/pytorch-lightning/pull/3256))
- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509))
- Fixed tensor printing in `trainer.test()` ([#5138](https://github.com/PyTorchLightning/pytorch-lightning/pull/5138))
- Fixed not using dataloader when `hparams` present ([#4559](https://github.com/PyTorchLightning/pytorch-lightning/pull/4559))
Expand Down
28 changes: 15 additions & 13 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,22 +1176,24 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
optimizer: Current optimizer used in training_loop
optimizer_idx: Current optimizer idx in training_loop
"""

# Iterate over all optimizer parameters to preserve their `requires_grad` information
# in case these are pre-defined during `configure_optimizers`
param_requires_grad_state = {}
# make sure current optimizer is latest to be iterated over.
optimizers = [opt for opt in self.optimizers(use_pl_optimizer=False) if opt != optimizer] + [optimizer]
num_optimizers = len(optimizers) - 1
for opt_idx, opt in enumerate(optimizers):
for opt in self.optimizers(use_pl_optimizer=False):
for group in opt.param_groups:
for param in group['params']:
if num_optimizers == opt_idx:
# If a param appears in 2 optimizers, revert `requires_grad` to before toggle.
if param in param_requires_grad_state:
param.requires_grad = param_requires_grad_state[param]
else:
# save requires_grad for later restoration
param_requires_grad_state[param] = param.requires_grad
param.requires_grad = False

# If a param already appear in param_requires_grad_state, continue
if param in param_requires_grad_state:
continue
param_requires_grad_state[param] = param.requires_grad
param.requires_grad = False

# Then iterate over the current optimizer's parameters and set its `requires_grad`
# properties accordingly
for group in optimizer.param_groups:
for param in group['params']:
param.requires_grad = param_requires_grad_state[param]
self._param_requires_grad_state = param_requires_grad_state

def untoggle_optimizer(self, optimizer_idx: int):
Expand Down
146 changes: 145 additions & 1 deletion tests/core/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, clos
trainer.fit(model)


def test_toggle_untoggle(tmpdir):
def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmpdir):

class TestModel(BoringModel):

Expand Down Expand Up @@ -198,8 +198,152 @@ def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, clos
assert self.layer_2[1].weight.requires_grad is False
assert self.layer_2[3].weight.requires_grad is False
assert self.layer_2[5].weight.requires_grad is True

optimizer.step(closure=closure)

model = TestModel()
model.training_epoch_end = None

trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
limit_train_batches=8,
accumulate_grad_batches=1,
limit_val_batches=0,
)

results = trainer.fit(model)
assert results


def test_toggle_untoggle_3_optimizers_shared_parameters(tmpdir):

class TestModel(BoringModel):

def __init__(self):
super().__init__()
self.layer_1 = nn.Sequential(
nn.Linear(32, 32),
nn.ReLU(),
nn.Linear(32, 32),
nn.ReLU(),
nn.Linear(32, 32),
)

self.layer_2 = nn.Sequential(
nn.ReLU(),
nn.Linear(32, 32),
nn.ReLU(),
nn.Linear(32, 32),
nn.ReLU(),
nn.Linear(32, 2)
)

self.layer_3 = nn.Sequential(
nn.ReLU(),
nn.Linear(32, 32),
nn.ReLU(),
nn.Linear(32, 32),
nn.ReLU(),
nn.Linear(32, 2)
)

# set some weights to False to check untoggle works as expected.
self.layer_1[2].weight.requires_grad = False
self.layer_1[4].weight.requires_grad = False

self.layer_2[1].weight.requires_grad = False
self.layer_2[3].weight.requires_grad = False

self.layer_3[1].weight.requires_grad = False
self.layer_3[5].weight.requires_grad = False

def optimizer_step(
self,
current_epoch,
batch_nb,
optimizer,
optimizer_idx,
closure,
on_tpu=False,
using_native_amp=False,
using_lbfgs=False
):
if optimizer_idx == 0:
assert self.layer_1[0].weight.requires_grad is True
assert self.layer_1[2].weight.requires_grad is False
assert self.layer_1[4].weight.requires_grad is False

assert self.layer_2[1].weight.requires_grad is False
assert self.layer_2[3].weight.requires_grad is False
assert self.layer_2[5].weight.requires_grad is True

assert self.layer_3[1].weight.requires_grad is False
assert self.layer_3[3].weight.requires_grad is False
assert self.layer_3[5].weight.requires_grad is False

if optimizer_idx == 1:
assert self.layer_1[0].weight.requires_grad is False
assert self.layer_1[2].weight.requires_grad is False
assert self.layer_1[4].weight.requires_grad is False

assert self.layer_2[1].weight.requires_grad is False
assert self.layer_2[3].weight.requires_grad is False
assert self.layer_2[5].weight.requires_grad is True

assert self.layer_3[1].weight.requires_grad is False
assert self.layer_3[3].weight.requires_grad is True
assert self.layer_3[5].weight.requires_grad is False

if optimizer_idx == 2:
assert self.layer_1[0].weight.requires_grad is True
assert self.layer_1[2].weight.requires_grad is False
assert self.layer_1[4].weight.requires_grad is False

assert self.layer_2[1].weight.requires_grad is False
assert self.layer_2[3].weight.requires_grad is False
assert self.layer_2[5].weight.requires_grad is False

assert self.layer_3[1].weight.requires_grad is False
assert self.layer_3[3].weight.requires_grad is True
assert self.layer_3[5].weight.requires_grad is False

optimizer.step(closure=closure)

def training_step(self, batch, batch_idx, optimizer_idx=None):
return super().training_step(batch, batch_idx)

@staticmethod
def combine_generators(gen_1, gen_2):
for p in gen_1:
yield p
for p in gen_2:
yield p

def configure_optimizers(self):
optimizer_1 = SGD(
self.combine_generators(
self.layer_1.parameters(),
self.layer_2.parameters()
),
lr=0.1
)
optimizer_2 = Adam(
self.combine_generators(
self.layer_2.parameters(),
self.layer_3.parameters()
),
lr=0.1
)
optimizer_3 = SGD(
self.combine_generators(
self.layer_3.parameters(),
self.layer_1.parameters()
),
lr=0.1
)
return [optimizer_1, optimizer_2, optimizer_3]

model = TestModel()
model.training_epoch_end = None

Expand Down

0 comments on commit 0b7f5a8

Please sign in to comment.