Check if optimizer supports closure (Lightning-AI#4981)
* check if optimizer supports closure

* cleanup test

* resolve tests

* resolve flake

* update test due to patch limit

* update

* update dep

* Update tests/core/test_lightning_optimizer.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update tests/core/test_lightning_optimizer.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* resolve bug

* update test

* resolve tests

* Update requirements/extra.txt

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* remove bolts dep

* remove bolts

* add missing bolts dep for tests

* remove need for bolts

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
3 people authored Dec 11, 2020
1 parent 4e6a871 commit 7755572
Showing 6 changed files with 81 additions and 22 deletions.
10 changes: 7 additions & 3 deletions pytorch_lightning/core/optimizer.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import types
from typing import Any, Callable, Optional
from weakref import proxy
@@ -60,7 +61,7 @@ def __init__(self,
self._trainer = None
self._optimizer = optimizer
self._accumulate_grad_batches = accumulate_grad_batches
self._automatic_optimization = None
self._support_closure = 'closure' in inspect.signature(optimizer.step).parameters
self._optimizer_idx = None

@property
@@ -73,7 +74,6 @@ def accumulate_grad_batches(self, accumulate_grad_batches):

def _on_trainer_init(self, trainer):
self._trainer = proxy(trainer)
self._automatic_optimization = trainer.train_loop.automatic_optimization
for opt_idx, opt in enumerate(trainer.optimizers):
if opt == self._optimizer:
self._optimizer_idx = opt_idx
@@ -111,7 +111,11 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n

else:
with trainer.profiler.profile(profiler_name):
optimizer.step(closure=closure, *args, **kwargs)
if self._support_closure:
optimizer.step(closure=closure, *args, **kwargs)
else:
closure()
optimizer.step(*args, **kwargs)

accelerator_backend = trainer.accelerator_backend
if accelerator_backend is not None and accelerator_backend.rpc_enabled:
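For reference, here is a standalone sketch of the behaviour this diff adds, simplified from `__optimizer_step` (the `step_with_optional_closure` helper name and the toy optimizer are mine, not Lightning's): detect via `inspect.signature` whether `optimizer.step` accepts a `closure` argument, forward the closure when it does, and otherwise run the closure manually before a plain `step()` call.

```python
import inspect

import torch

# Sketch of the new behaviour: an optimizer "supports closure" when its `step`
# signature declares a `closure` parameter; otherwise the closure is executed
# manually before a plain `step()` call, mirroring the branch added above.
def step_with_optional_closure(optimizer, closure):
    supports_closure = 'closure' in inspect.signature(optimizer.step).parameters
    if supports_closure:
        optimizer.step(closure=closure)
    else:
        closure()
        optimizer.step()


class StepWithoutClosure(torch.optim.SGD):
    def step(self):  # deliberately drops the `closure` parameter
        return super().step()


params = [torch.nn.Parameter(torch.randn(2, 2))]
step_with_optional_closure(torch.optim.SGD(params, lr=0.1), closure=lambda: None)     # closure forwarded to step()
step_with_optional_closure(StepWithoutClosure(params, lr=0.1), closure=lambda: None)  # closure run manually first
```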
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -54,6 +54,7 @@ def _module_available(module_path: str) -> bool:
OMEGACONF_AVAILABLE = _module_available("omegaconf")
HYDRA_AVAILABLE = _module_available("hydra")
HOROVOD_AVAILABLE = _module_available("horovod.torch")
BOLTS_AVAILABLE = _module_available("pl_bolts")

TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')
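The new `BOLTS_AVAILABLE` flag reuses the existing `_module_available` helper shown in the hunk header. A rough approximation of what such a check boils down to, assuming an illustrative `module_available` function rather than Lightning's actual implementation:

```python
import importlib.util

# Hedged approximation of the availability flag added above: "can this module
# be imported?" answered without actually importing it.
def module_available(module_path: str) -> bool:
    try:
        return importlib.util.find_spec(module_path) is not None
    except ModuleNotFoundError:
        # a missing parent package (for dotted paths) also counts as unavailable
        return False


BOLTS_AVAILABLE = module_available("pl_bolts")
print(f"pl_bolts importable: {BOLTS_AVAILABLE}")
```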
2 changes: 1 addition & 1 deletion requirements/extra.txt
@@ -7,4 +7,4 @@ torchtext>=0.3.1, <0.7 # TODO: temporary fix fix for compatibility
onnx>=1.7.0
onnxruntime>=1.3.0
hydra-core>=1.0
https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip
https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip
5 changes: 5 additions & 0 deletions tests/core/test_lightning_module.py
@@ -55,6 +55,11 @@ def test_automatic_optimization_num_calls(enable_pl_optimizer, tmpdir):

class TestModel(BoringModel):

def training_step(self, batch, batch_idx, optimizer_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"loss": loss}

def configure_optimizers(self):
optimizer = SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = Adam(self.layer.parameters(), lr=0.1)
74 changes: 62 additions & 12 deletions tests/core/test_lightning_optimizer.py
@@ -19,10 +19,12 @@
import torch.nn as nn
from torch.optim import Adam, Optimizer

import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
from tests.base.boring_model import BoringModel, RandomDataset, RandomDictDataset, RandomDictStringDataset


def test_lightning_optimizer(tmpdir):
@@ -80,8 +82,8 @@ def configure_optimizers(self):
assert trainer.optimizers[0].__repr__() == expected


@patch("torch.optim.Adam.step")
@patch("torch.optim.SGD.step")
@patch("torch.optim.Adam.step", autospec=True)
@patch("torch.optim.SGD.step", autospec=True)
def test_lightning_optimizer_manual_optimization(mock_sgd_step, mock_adam_step, tmpdir):
"""
Test that the user can use our LightningOptimizer. Not recommended for now.
@@ -96,13 +98,13 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
output = self.layer(batch)
loss_1 = self.loss(batch, output)
self.manual_backward(loss_1, opt_1)
opt_1.step(idx="1")
opt_1.step()

def closure():
output = self.layer(batch)
loss_2 = self.loss(batch, output)
self.manual_backward(loss_2, opt_2)
opt_2.step(closure=closure, idx="2")
opt_2.step(closure=closure)

def configure_optimizers(self):
optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -133,8 +135,8 @@ def automatic_optimization(self) -> bool:
assert len(mock_adam_step.mock_calls) == 8


@patch("torch.optim.Adam.step")
@patch("torch.optim.SGD.step")
@patch("torch.optim.Adam.step", autospec=True)
@patch("torch.optim.SGD.step", autospec=True)
def test_lightning_optimizer_manual_optimization_and_accumulated_gradients(mock_sgd_step, mock_adam_step, tmpdir):
"""
Test that the user can use our LightningOptimizer. Not recommended.
@@ -149,13 +151,13 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
output = self.layer(batch)
loss_1 = self.loss(batch, output)
self.manual_backward(loss_1, opt_1)
opt_1.step(idx="1")
opt_1.step()

def closure():
output = self.layer(batch)
loss_2 = self.loss(batch, output)
self.manual_backward(loss_2, opt_2)
opt_2.step(closure=closure, idx="2")
opt_2.step(closure=closure)

def configure_optimizers(self):
optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -195,9 +197,8 @@ def test_state(tmpdir):
assert isinstance(lightning_optimizer, Adam)
assert isinstance(lightning_optimizer, Optimizer)
lightning_dict = {}
special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx",
"_trainer", "_use_accumulate_grad_batches_from_trainer", "_automatic_optimization",
"_accumulate_grad_batches"]
special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx", "_support_closure",
"_trainer"]
for k, v in lightning_optimizer.__dict__.items():
if k not in special_attrs:
lightning_dict[k] = v
@@ -206,6 +207,55 @@
assert optimizer.state == lightning_optimizer.state


def test_lightning_optimizer_with_wrong_optimizer_interface(tmpdir):
class OptimizerWrapper(object):
def __init__(self, optimizer):
self.optim = optimizer
self.state_dict = self.optim.state_dict
self.load_state_dict = self.optim.load_state_dict
self.zero_grad = self.optim.zero_grad
self.add_param_group = self.optim.add_param_group
self.__setstate__ = self.optim.__setstate__
self.__getstate__ = self.optim.__getstate__
self.__repr__ = self.optim.__repr__

@property
def __class__(self):
return Optimizer

@property
def state(self):
return self.optim.state

@property
def param_groups(self):
return self.optim.param_groups

@param_groups.setter
def param_groups(self, value):
self.optim.param_groups = value

def step(self):
# wrongly defined step. Should contain closure
self.optim.step(closure=None)

class TestLightningOptimizerModel(BoringModel):

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=0.1)
optimizer = OptimizerWrapper(optimizer)
return [optimizer]

model = TestLightningOptimizerModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
weights_summary=None,
log_every_n_steps=1,
)
trainer.fit(model)


def test_lightning_optimizer_automatic_optimization(tmpdir):
"""
Test lightning optimize works with make_optimizer_step in automatic_optimization
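A side note on the `autospec=True` added to the `torch.optim.Adam.step` and `torch.optim.SGD.step` patches in this file: with autospec, the mock keeps the real `step` signature, so calls with invented keyword arguments (like the removed `idx="1"`) raise instead of being recorded as if they were valid. A small illustrative sketch, independent of the Lightning test code:

```python
from unittest.mock import patch

import torch

# Illustrative only: with autospec=True the patched `step` keeps SGD's real
# signature (self, closure=None), so a call with a made-up keyword such as
# `idx="1"` raises TypeError.
params = [torch.nn.Parameter(torch.randn(2, 2))]
opt = torch.optim.SGD(params, lr=0.1)

with patch("torch.optim.SGD.step", autospec=True) as step_mock:
    opt.step()                      # matches the real signature
    opt.step(closure=lambda: None)  # `closure` is a real parameter, also fine
    try:
        opt.step(idx="1")           # not a real parameter -> rejected
    except TypeError as err:
        print(f"autospec rejected the call: {err}")

print(step_mock.call_count)  # only the signature-valid calls were recorded
```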
11 changes: 5 additions & 6 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -825,7 +825,7 @@ def optimizer_closure():
retain_graph = num_backward != backward_idx # noqa E225
self.manual_backward(loss_1, opt, retain_graph=retain_graph)

opt.step(1, closure=optimizer_closure, something="new")
opt.step(closure=optimizer_closure)

def training_epoch_end(self, outputs) -> None:
# outputs should be an array with an entry per optimizer
@@ -855,7 +855,7 @@ def automatic_optimization(self) -> bool:
)

trainer.fit(model)
expected_calls = [call(1, closure=ANY, something="new") for s in range(2)]
expected_calls = [call() for s in range(2)]
step_mock.assert_has_calls(expected_calls)


@@ -902,7 +902,7 @@ def dis_closure():
if batch_idx % 4 == 0 :
# Note: Set make_optimizer_step to True or it will use by default
# Trainer(accumulate_grad_batches=x)
opt_dis.step(closure=dis_closure, make_optimizer_step=True, optim='adam')
opt_dis.step(closure=dis_closure, make_optimizer_step=True)

def training_epoch_end(self, outputs) -> None:
# outputs should be an array with an entry per optimizer
@@ -933,10 +933,9 @@ def automatic_optimization(self) -> bool:
)

trainer.fit(model)
expected_calls = [call(closure=ANY, optim='sgd') for s in range(4)]
expected_calls = [call(optim='sgd') for s in range(4)]
mock_sgd_step.assert_has_calls(expected_calls)

expected_calls = [call(closure=ANY, optim='adam') for s in range(2)]
expected_calls = [call() for s in range(2)]
mock_adam_step.assert_has_calls(expected_calls)


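For context on why these manual-optimization tests now pass `opt.step` only the closure (and documented kwargs such as `make_optimizer_step`): in plain PyTorch the closure given to `Optimizer.step` is a callable that re-evaluates the loss and backpropagates, and optimizers such as LBFGS invoke it themselves. A minimal sketch of that contract, unrelated to the Lightning test code:

```python
import torch

# Plain-PyTorch sketch of the closure contract: `Optimizer.step(closure)` may
# call the closure (possibly several times for LBFGS) to re-compute the loss,
# so arbitrary made-up keyword arguments have no place in the call.
model = torch.nn.Linear(4, 1)
data, target = torch.randn(8, 4), torch.randn(8, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)


def closure():
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(data), target)
    loss.backward()
    return loss


for _ in range(3):
    loss = optimizer.step(closure)  # LBFGS requires the closure; SGD/Adam accept it too
    print(float(loss))
```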
