update usage of deprecated automatic_optimization #5011

Merged: 8 commits, Dec 10, 2020
pytorch_lightning/core/lightning.py: 4 additions & 2 deletions
@@ -1394,8 +1394,10 @@ def get_progress_bar_dict(self):

     def _verify_is_manual_optimization(self, fn_name):
         if self.trainer.train_loop.automatic_optimization:
-            m = f'to use {fn_name}, please disable automatic optimization: Trainer(automatic_optimization=False)'
-            raise MisconfigurationException(m)
+            raise MisconfigurationException(
+                f'to use {fn_name}, please disable automatic optimization:'
+                ' set model property `automatic_optimization` as False'
+            )

     @classmethod
     def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]:
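For context, the pattern the new message points users to looks roughly like this: the model itself opts out of automatic optimization instead of the deprecated `Trainer(automatic_optimization=False)` argument. A minimal sketch assuming the 1.1-era API (`self.optimizers()` and that release's `manual_backward(loss, optimizer)` signature); the model name and layer sizes are illustrative only.

```python
import torch
from pytorch_lightning import LightningModule, Trainer


class ManualOptModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    @property
    def automatic_optimization(self) -> bool:
        # Opt out on the model rather than via the deprecated Trainer flag.
        return False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        # Around the 1.1 API, manual_backward still takes the optimizer explicitly.
        self.manual_backward(loss, opt)
        opt.step()
        opt.zero_grad()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


# No automatic_optimization argument is passed to the Trainer anymore.
trainer = Trainer(max_epochs=1)
```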
pytorch_lightning/trainer/configuration_validator.py: 5 additions & 4 deletions
@@ -79,7 +79,7 @@ def __verify_train_loop_configuration(self, model):
         if trainer.overriden_optimizer_step and not enable_pl_optimizer and automatic_optimization:
             rank_zero_warn(
                 "When overriding `LightningModule` optimizer_step with"
-                " `Trainer(..., enable_pl_optimizer=False, automatic_optimization=True, ...)`,"
+                " `Trainer(..., enable_pl_optimizer=False, ...)`,"
                 " we won't be calling `.zero_grad` we can't assume when you call your `optimizer.step()`."
                 " For Lightning to take care of it, please use `Trainer(enable_pl_optimizer=True)`."
             )
@@ -90,14 +90,15 @@ def __verify_train_loop_configuration(self, model):
         if (has_overriden_optimization_functions) and going_to_accumulate_grad_batches and automatic_optimization:
             raise MisconfigurationException(
                 'When overriding `LightningModule` optimizer_step or optimizer_zero_grad with '
-                '`Trainer(automatic_optimization=True, ...)`, `accumulate_grad_batches` should to be 1.'
+                ' `Trainer(...)`, `accumulate_grad_batches` should to be 1.'
                 ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
             )

         if (enable_pl_optimizer) and trainer.overriden_optimizer_zero_grad and not automatic_optimization:
             raise MisconfigurationException(
-                'When overriding `LightningModule` optimizer_zero_grad with '
-                '`Trainer(automatic_optimization=False, enable_pl_optimizer=True, ...) is not supported'
+                'When overriding `LightningModule` optimizer_zero_grad'
+                ' and preserving model property `automatic_optimization` as True with'
+                ' `Trainer(enable_pl_optimizer=True, ...) is not supported'
             )

     def __verify_eval_loop_configuration(self, model, eval_loop_name):
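The two checks above reduce to one rule: if `optimizer_step` or `optimizer_zero_grad` is overridden while the model keeps automatic optimization on, gradient accumulation must stay at 1, since Lightning can no longer tell on which batches the hooks should fire. A self-contained sketch of that rule, not the validator's actual implementation; the helper name `_check_manual_hooks` and the override detection via class attributes are assumptions for illustration.

```python
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def _check_manual_hooks(model: LightningModule, accumulate_grad_batches: int) -> None:
    # "Overridden" here simply means the subclass replaced the default hook.
    overridden = (
        type(model).optimizer_step is not LightningModule.optimizer_step
        or type(model).optimizer_zero_grad is not LightningModule.optimizer_zero_grad
    )
    if overridden and accumulate_grad_batches != 1 and model.automatic_optimization:
        raise MisconfigurationException(
            'When overriding `optimizer_step` or `optimizer_zero_grad`,'
            ' `accumulate_grad_batches` should be 1 so the hooks run on every batch.'
        )
```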
pytorch_lightning/trainer/trainer.py: 1 addition & 1 deletion
@@ -355,7 +355,7 @@ def __init__(
         )

         # init train loop related flags
-        # TODO: deprecate in 1.2.0
+        # TODO: remove in 1.3.0
         if automatic_optimization is None:
             automatic_optimization = True
         else:
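The `None` default above is what keeps the deprecated argument alive until the removal the TODO refers to. A hedged sketch of that back-compat resolution; the standalone function name `_resolve_automatic_optimization` and the exact warning text are illustrative, and the real handling lives inline in `Trainer.__init__`.

```python
from pytorch_lightning.utilities import rank_zero_warn


def _resolve_automatic_optimization(automatic_optimization):
    # None means the argument was untouched, so automatic optimization stays on.
    if automatic_optimization is None:
        return True
    # An explicit value still works, but users are nudged toward the model-level
    # `automatic_optimization` property before the Trainer argument is removed.
    rank_zero_warn(
        '`Trainer(automatic_optimization=...)` is deprecated;'
        ' override the `automatic_optimization` property on your LightningModule instead.',
        DeprecationWarning,
    )
    return automatic_optimization
```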
tests/core/test_lightning_module.py: 0 additions & 2 deletions
@@ -38,7 +38,6 @@ def optimizer_step(self, *_, **__):
         default_root_dir=tmpdir,
         limit_train_batches=2,
         accumulate_grad_batches=2,
-        automatic_optimization=True
     )

     trainer.fit(model)
@@ -90,7 +89,6 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
         default_root_dir=tmpdir,
         limit_train_batches=8,
         accumulate_grad_batches=1,
-        automatic_optimization=True,
         enable_pl_optimizer=enable_pl_optimizer
     )

tests/core/test_lightning_optimizer.py: 10 additions & 7 deletions
@@ -112,6 +112,10 @@ def configure_optimizers(self):
             lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
             return [optimizer_1, optimizer_2], [lr_scheduler]

+        @property
+        def automatic_optimization(self) -> bool:
+            return False
+
     model = TestModel()
     model.training_step_end = None
     model.training_epoch_end = None
@@ -121,8 +125,8 @@ def configure_optimizers(self):
         limit_val_batches=1,
         max_epochs=1,
         weights_summary=None,
-        automatic_optimization=False,
-        enable_pl_optimizer=True)
+        enable_pl_optimizer=True,
+    )
     trainer.fit(model)

     assert len(mock_sgd_step.mock_calls) == 2
@@ -161,6 +165,10 @@ def configure_optimizers(self):
             lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
             return [optimizer_1, optimizer_2], [lr_scheduler]

+        @property
+        def automatic_optimization(self) -> bool:
+            return False
+
     model = TestModel()
     model.training_step_end = None
     model.training_epoch_end = None
@@ -170,7 +178,6 @@ def configure_optimizers(self):
         limit_val_batches=1,
         max_epochs=1,
         weights_summary=None,
-        automatic_optimization=False,
         accumulate_grad_batches=2,
         enable_pl_optimizer=True,
     )
@@ -237,7 +244,6 @@ def configure_optimizers(self):
         max_epochs=1,
         weights_summary=None,
         enable_pl_optimizer=True,
-        automatic_optimization=True
     )
     trainer.fit(model)

@@ -291,7 +297,6 @@ def configure_optimizers(self):
         max_epochs=1,
         weights_summary=None,
         enable_pl_optimizer=True,
-        automatic_optimization=True
     )
     trainer.fit(model)

@@ -352,7 +357,6 @@ def configure_optimizers(self):
         max_epochs=1,
         weights_summary=None,
         enable_pl_optimizer=True,
-        automatic_optimization=True
     )
     trainer.fit(model)

@@ -406,7 +410,6 @@ def configure_optimizers(self):
         max_epochs=1,
         weights_summary=None,
         enable_pl_optimizer=True,
-        automatic_optimization=True,
     )
     trainer.fit(model)

tests/trainer/dynamic_args/test_multiple_optimizers.py: 4 additions & 1 deletion
@@ -97,11 +97,14 @@ def configure_optimizers(self):
             optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
             return optimizer, optimizer_2

+        @property
+        def automatic_optimization(self) -> bool:
+            return False
+
     model = TestModel()
     model.val_dataloader = None

     trainer = Trainer(
-        automatic_optimization=False,
         default_root_dir=tmpdir,
         limit_train_batches=2,
         limit_val_batches=2,
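Taken together, a test model after this change roughly takes the following shape: it overrides the `automatic_optimization` property and drives both optimizers by hand in `training_step`. A rough sketch under the same 1.1-era assumptions as above (the `manual_backward(loss, optimizer)` signature; the `optimizer_idx=None` default is a hedge because the manual-optimization `training_step` signature varied around this release); names and sizes are illustrative.

```python
import torch
from pytorch_lightning import LightningModule


class TwoOptimizerModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    @property
    def automatic_optimization(self) -> bool:
        return False

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        opt_a, opt_b = self.optimizers()

        # Each optimizer gets its own forward/backward pass and explicit step.
        loss_a = self.layer(batch).sum()
        self.manual_backward(loss_a, opt_a)
        opt_a.step()
        opt_a.zero_grad()

        loss_b = self.layer(batch).sum()
        self.manual_backward(loss_b, opt_b)
        opt_b.step()
        opt_b.zero_grad()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        return optimizer, optimizer_2
```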