-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into issue/wandb_global_step
- Loading branch information
Showing
17 changed files
with
275 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
r""" | ||
Logging of learning rates | ||
========================= | ||
Log learning rate for lr schedulers during training | ||
""" | ||
|
||
from pytorch_lightning.callbacks.base import Callback | ||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||
|
||
|
||
class LearningRateLogger(Callback): | ||
r""" | ||
Automatically logs learning rate for learning rate schedulers during training. | ||
Example:: | ||
>>> from pytorch_lightning import Trainer | ||
>>> from pytorch_lightning.callbacks import LearningRateLogger | ||
>>> lr_logger = LearningRateLogger() | ||
>>> trainer = Trainer(callbacks=[lr_logger]) | ||
Logging names are automatically determined based on optimizer class name. | ||
In case of multiple optimizers of same type, they will be named `Adam`, | ||
`Adam-1` etc. If a optimizer has multiple parameter groups they will | ||
be named `Adam/pg1`, `Adam/pg2` etc. To control naming, pass in a | ||
`name` keyword in the construction of the learning rate schdulers | ||
Example:: | ||
def configure_optimizer(self): | ||
optimizer = torch.optim.Adam(...) | ||
lr_scheduler = {'scheduler': torch.optim.lr_schedulers.LambdaLR(optimizer, ...) | ||
'name': 'my_logging_name'} | ||
return [optimizer], [lr_scheduler] | ||
""" | ||
def __init__(self): | ||
self.lrs = None | ||
self.lr_sch_names = [] | ||
|
||
def on_train_start(self, trainer, pl_module): | ||
""" Called before training, determines unique names for all lr | ||
schedulers in the case of multiple of the same type or in | ||
the case of multiple parameter groups | ||
""" | ||
if trainer.lr_schedulers == []: | ||
raise MisconfigurationException( | ||
'Cannot use LearningRateLogger callback with models that have no' | ||
' learning rate schedulers. Please see documentation for' | ||
' `configure_optimizers` method.') | ||
|
||
if not trainer.logger: | ||
raise MisconfigurationException( | ||
'Cannot use LearningRateLogger callback with Trainer that has no logger.') | ||
|
||
# Find names for schedulers | ||
names = self._find_names(trainer.lr_schedulers) | ||
|
||
# Initialize for storing values | ||
self.lrs = dict.fromkeys(names, []) | ||
|
||
def on_batch_start(self, trainer, pl_module): | ||
latest_stat = self._extract_lr(trainer, 'step') | ||
if trainer.logger and latest_stat: | ||
trainer.logger.log_metrics(latest_stat, step=trainer.global_step) | ||
|
||
def on_epoch_start(self, trainer, pl_module): | ||
latest_stat = self._extract_lr(trainer, 'epoch') | ||
if trainer.logger and latest_stat: | ||
trainer.logger.log_metrics(latest_stat, step=trainer.global_step) | ||
|
||
def _extract_lr(self, trainer, interval): | ||
""" Extracts learning rates for lr schedulers and saves information | ||
into dict structure. """ | ||
latest_stat = {} | ||
for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers): | ||
if scheduler['interval'] == interval: | ||
param_groups = scheduler['scheduler'].optimizer.param_groups | ||
if len(param_groups) != 1: | ||
for i, pg in enumerate(param_groups): | ||
lr, key = pg['lr'], f'{name}/{i + 1}' | ||
self.lrs[key].append(lr) | ||
latest_stat[key] = lr | ||
else: | ||
self.lrs[name].append(param_groups[0]['lr']) | ||
latest_stat[name] = param_groups[0]['lr'] | ||
return latest_stat | ||
|
||
def _find_names(self, lr_schedulers): | ||
# Create uniqe names in the case we have multiple of the same learning | ||
# rate schduler + multiple parameter groups | ||
names = [] | ||
for scheduler in lr_schedulers: | ||
sch = scheduler['scheduler'] | ||
if 'name' in scheduler: | ||
name = scheduler['name'] | ||
else: | ||
opt_name = 'lr-' + sch.optimizer.__class__.__name__ | ||
i, name = 1, opt_name | ||
# Multiple schduler of the same type | ||
while True: | ||
if name not in names: | ||
break | ||
i, name = i + 1, f'{opt_name}-{i}' | ||
|
||
# Multiple param groups for the same schduler | ||
param_groups = sch.optimizer.param_groups | ||
if len(param_groups) != 1: | ||
for i, pg in enumerate(param_groups): | ||
temp = name + '/pg' + str(i + 1) | ||
names.append(temp) | ||
else: | ||
names.append(name) | ||
|
||
self.lr_sch_names.append(name) | ||
return names |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.