Remove default optimizer, add None optimizer option (Lightning-AI#1279)
* Add warning when using default optimizer
* Refactor optimizer tests to test_optimizers
* Remove default optimizer, add option to use no optimizer
* Update CHANGELOG.md
* Update pytorch_lightning/trainer/optimizers.py
* Fix style

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
1 parent: c0199e5
Commit: fc59fca
Showing 11 changed files with 267 additions and 168 deletions.
pytorch_lightning/trainer/optimizers.py
@@ -0,0 +1,135 @@
import warnings
from abc import ABC
from typing import List, Tuple

import torch
from torch import optim
from torch.optim.optimizer import Optimizer

from pytorch_lightning.core.lightning import LightningModule


class TrainerOptimizersMixin(ABC):

    def init_optimizers(
            self,
            model: LightningModule
    ) -> Tuple[List, List, List]:
        optim_conf = model.configure_optimizers()

        if optim_conf is None:
            warnings.warn('`LightningModule.configure_optimizers` returned `None`, '
                          'this fit will run with no optimizer', UserWarning)
            optim_conf = _MockOptimizer()

        # single output, single optimizer
        if isinstance(optim_conf, Optimizer):
            return [optim_conf], [], []

        # two lists, optimizers + lr schedulers
        elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \
                and isinstance(optim_conf[0], list):
            optimizers, lr_schedulers = optim_conf
            lr_schedulers = self.configure_schedulers(lr_schedulers)
            return optimizers, lr_schedulers, []

        # single dictionary
        elif isinstance(optim_conf, dict):
            optimizer = optim_conf["optimizer"]
            lr_scheduler = optim_conf.get("lr_scheduler", [])
            # build the scheduler list only if a scheduler was given, otherwise keep it empty
            lr_schedulers = self.configure_schedulers([lr_scheduler]) if lr_scheduler else []
            return [optimizer], lr_schedulers, []

        # multiple dictionaries
        elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
            optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
            # take only the lr schedulers that exist and are defined (not None)
            lr_schedulers = [
                opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")
            ]
            # take only the frequencies that exist and are defined (not None)
            optimizer_frequencies = [
                opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency")
            ]

            # clean scheduler list
            if lr_schedulers:
                lr_schedulers = self.configure_schedulers(lr_schedulers)
            # assert that if frequencies are present, they are given for all optimizers
            if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
                raise ValueError("A frequency must be given to each optimizer.")
            return optimizers, lr_schedulers, optimizer_frequencies

        # single list or tuple, multiple optimizers
        elif isinstance(optim_conf, (list, tuple)):
            return list(optim_conf), [], []

        # unknown configuration
        else:
            raise ValueError(
                'Unknown configuration for model optimizers.'
                ' Output from `model.configure_optimizers()` should either be:'
                ' * single output, single `torch.optim.Optimizer`'
                ' * single output, list of `torch.optim.Optimizer`'
                ' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)'
                ' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)'
                ' * two outputs, first being a list of `torch.optim.Optimizer` second being'
                ' a list of `torch.optim.lr_scheduler`'
                ' * multiple outputs, dictionaries as described with an optional `frequency` key (int)')

    def configure_schedulers(self, schedulers: list):
        # Convert each scheduler into a dict structure with the relevant information
        lr_schedulers = []
        default_config = {'interval': 'epoch',  # default every epoch
                          'frequency': 1,  # default every epoch/batch
                          'reduce_on_plateau': False,  # most schedulers are not ReduceLROnPlateau
                          'monitor': 'val_loss'}  # default value to monitor for ReduceLROnPlateau
        for scheduler in schedulers:
            if isinstance(scheduler, dict):
                if 'scheduler' not in scheduler:
                    raise ValueError('Lr scheduler should have key `scheduler`'
                                     ' with item being a lr scheduler')
                scheduler['reduce_on_plateau'] = isinstance(
                    scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)

                lr_schedulers.append({**default_config, **scheduler})

            elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                lr_schedulers.append({**default_config, 'scheduler': scheduler,
                                      'reduce_on_plateau': True})

            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                lr_schedulers.append({**default_config, 'scheduler': scheduler})
            else:
                raise ValueError(f'Input {scheduler} to lr schedulers '
                                 'is an invalid input.')
        return lr_schedulers


class _MockOptimizer(Optimizer):
    """The `_MockOptimizer` is used in place of an optimizer in the event that `None`
    is returned from `configure_optimizers`.
    """

    def __init__(self):
        super().__init__([torch.zeros(1)], {})

    def add_param_group(self, param_group):
        pass  # Do Nothing

    def load_state_dict(self, state_dict):
        pass  # Do Nothing

    def state_dict(self):
        return {}  # Return Empty

    def step(self, closure=None):
        if closure is not None:
            closure()

    def zero_grad(self):
        pass  # Do Nothing

    def __repr__(self):
        return 'No Optimizer'
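For reference, the `configure_optimizers` return formats accepted by `init_optimizers` above can be sketched as follows. Each `def configure_optimizers` is an alternative, and the optimizers, schedulers, and submodule names (`self.model_a`, `self.model_b`) are illustrative placeholders, not part of this commit:

    from torch import optim

    def configure_optimizers(self):
        # 1) single optimizer
        return optim.SGD(self.parameters(), lr=0.01)

    def configure_optimizers(self):
        # 2) two lists: optimizers + lr schedulers
        opt = optim.SGD(self.parameters(), lr=0.01)
        return [opt], [optim.lr_scheduler.StepLR(opt, step_size=10)]

    def configure_optimizers(self):
        # 3) single dictionary with an optional `lr_scheduler` key
        opt = optim.Adam(self.parameters(), lr=1e-3)
        return {'optimizer': opt,
                'lr_scheduler': optim.lr_scheduler.ExponentialLR(opt, gamma=0.95)}

    def configure_optimizers(self):
        # 4) multiple dictionaries, each with an optional `frequency` key
        opt_a = optim.Adam(self.model_a.parameters(), lr=1e-3)
        opt_b = optim.SGD(self.model_b.parameters(), lr=0.01)
        return [
            {'optimizer': opt_a, 'frequency': 5},
            {'optimizer': opt_b, 'frequency': 1},
        ]

    def configure_optimizers(self):
        # 5) None: run with no optimizer (_MockOptimizer is substituted with a warning)
        return None

Returning `None` (the last variant) triggers the new warning and swaps in `_MockOptimizer`, which performs no parameter updates.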