Learning rate stepping option (#941)
* remove deprecated args to learning rate step function

* step based scheduler

* mixing models for testing

* fix styling

* tests

* update documentation

* smaller fix

* update to dict structure

* updated test

* update documentation

* update CHANGELOG.md

* fix styling

* fix problems with trainer io

* fix tests

* simplification of code

* fix styling

* change from batch to step

* update to tests

* fix styling

* fixed some logic

* Update pytorch_lightning/core/lightning.py

* duplicated test

* fix test on amp

* small update to tests

* added monitor key for ReduceLROnPlateau

* Update trainer.py

* Update training_loop.py

* fix test after introducing monitor keyword

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
williamFalcon and Borda authored Mar 5, 2020
1 parent bcb45d9 commit 969e929
Showing 11 changed files with 377 additions and 32 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849))
- Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
- Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
- Added support for step-based learning rate scheduling ([#941](https://github.com/PyTorchLightning/pytorch-lightning/pull/941))
- Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
- Checkpoint and early stopping now work without val step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041))

11 changes: 11 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -758,6 +758,15 @@ def configure_optimizers(self):
discriminator_sched = CosineAnnealingLR(discriminator_opt, T_max=10)
return [generator_opt, discriminator_opt], [discriminator_sched]
# example with step-based learning_rate schedulers
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99),
'interval': 'step'} # called after each training step
dis_sched = CosineAnnealingLR(dis_opt, T_max=10) # called after each epoch
return [gen_opt, dis_opt], [gen_sched, dis_sched]
.. note:: Lightning calls .backward() and .step() on each optimizer and learning rate scheduler as needed.
.. note:: If you use 16-bit precision (use_amp=True), Lightning will automatically
@@ -773,6 +782,8 @@ def configure_optimizers(self):
.. note:: If you need to control how often those optimizers step or override the default .step() schedule,
override the `optimizer_step` hook.
.. note:: If you only want to call a learning rate scheduler every `x` steps or epochs,
you can input this as the 'frequency' key: dict(scheduler=lr_scheduler, interval='step' or 'epoch', frequency=x)
"""
return Adam(self.parameters(), lr=1e-3)
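Taken together, the lightning.py changes above define the new user-facing contract: a scheduler may be returned either bare or wrapped in a dict carrying 'interval' and 'frequency' keys. A minimal self-contained sketch of that usage (the module, layer size, and hyperparameters here are hypothetical, not taken from the diff):

from pytorch_lightning import LightningModule
from torch import nn, optim

class LitModel(LightningModule):  # hypothetical model exercising the new API
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def configure_optimizers(self):
        opt = optim.Adam(self.parameters(), lr=1e-3)
        # step the scheduler every 100 training steps instead of once per epoch
        sched = {'scheduler': optim.lr_scheduler.ExponentialLR(opt, gamma=0.99),
                 'interval': 'step',
                 'frequency': 100}
        return [opt], [sched]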
55 changes: 42 additions & 13 deletions pytorch_lightning/trainer/trainer.py
@@ -6,6 +6,7 @@
from argparse import ArgumentParser

import torch
from torch import optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
@@ -743,8 +744,6 @@ def on_train_end(self):
# creates a default one if none passed in
self.configure_early_stopping(early_stop_callback)

self.reduce_lr_on_plateau_scheduler = None

# configure checkpoint callback
self.checkpoint_callback = checkpoint_callback
self.weights_save_path = weights_save_path
@@ -1079,26 +1078,56 @@ def init_optimizers(
optimizers: Union[Optimizer, Tuple[List, List], List[Optimizer], Tuple[Optimizer]]
) -> Tuple[List, List]:

# single optimizer
# single output, single optimizer
if isinstance(optimizers, Optimizer):
return [optimizers], []

# two lists
if len(optimizers) == 2 and isinstance(optimizers[0], list):
# two lists: optimizers + lr schedulers
elif len(optimizers) == 2 and isinstance(optimizers[0], list):
optimizers, lr_schedulers = optimizers
lr_schedulers, self.reduce_lr_on_plateau_scheduler = self.configure_schedulers(lr_schedulers)
lr_schedulers = self.configure_schedulers(lr_schedulers)
return optimizers, lr_schedulers

# single list or tuple
if isinstance(optimizers, (list, tuple)):
# single list or tuple, multiple optimizers
elif isinstance(optimizers, (list, tuple)):
return optimizers, []

# unknown configuration
else:
raise ValueError('Unknown configuration for model optimizers. Output '
'from model.configure_optimizers() should either be: '
'* single output, single torch.optim.Optimizer '
'* single output, list of torch.optim.Optimizer '
'* two outputs, first being a list of torch.optim.Optimizer, '
'second being a list of torch.optim.lr_scheduler')

def configure_schedulers(self, schedulers: list):
for i, scheduler in enumerate(schedulers):
if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
reduce_lr_on_plateau_scheduler = schedulers.pop(i)
return schedulers, reduce_lr_on_plateau_scheduler
return schedulers, None
# Convert each scheduler into dict structure with relevant information
lr_schedulers = []
default_config = {'interval': 'epoch', # default every epoch
'frequency': 1, # default every epoch/batch
'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler
'monitor': 'val_loss'} # default value to monitor for ReduceLROnPlateau
for scheduler in schedulers:
if isinstance(scheduler, dict):
if 'scheduler' not in scheduler:
raise ValueError('Lr scheduler should have key `scheduler` '
'with item being a lr scheduler')
scheduler['reduce_on_plateau'] = \
isinstance(scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)

lr_schedulers.append({**default_config, **scheduler})

elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
lr_schedulers.append({**default_config, 'scheduler': scheduler,
'reduce_on_plateau': True})

elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, 'scheduler': scheduler})
else:
raise ValueError(f'Input {scheduler} to lr schedulers '
'is an invalid input.')
return lr_schedulers

def run_pretrain_routine(self, model: LightningModule):
"""Sanity check a few things before starting actual training.
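To make the normalization above concrete, here is a rough standalone illustration of what configure_schedulers would produce for each accepted input shape (a sketch based on the default_config in this diff, not code from the repository):

import torch
from torch import optim

opt = optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

# A bare _LRScheduler would come back wrapped with every default key:
# {'scheduler': <StepLR>, 'interval': 'epoch', 'frequency': 1,
#  'reduce_on_plateau': False, 'monitor': 'val_loss'}
bare = optim.lr_scheduler.StepLR(opt, step_size=10)

# A partial dict keeps its own keys and inherits the remaining defaults,
# with 'reduce_on_plateau' derived from the wrapped scheduler's type:
# {'scheduler': <ReduceLROnPlateau>, 'interval': 'epoch', 'frequency': 1,
#  'reduce_on_plateau': True, 'monitor': 'val_acc'}
partial = {'scheduler': optim.lr_scheduler.ReduceLROnPlateau(opt),
           'monitor': 'val_acc'}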
97 changes: 94 additions & 3 deletions pytorch_lightning/trainer/training_io.py
@@ -1,3 +1,94 @@
"""
Lightning can automate saving and loading checkpoints
=====================================================
Checkpointing is enabled by default to the current working directory.
To change the checkpoint path pass in::
Trainer(default_save_path='/your/path/to/save/checkpoints')
To modify the behavior of checkpointing pass in your own callback.
.. code-block:: python
from pytorch_lightning.callbacks import ModelCheckpoint
# DEFAULTS used by the Trainer
checkpoint_callback = ModelCheckpoint(
filepath=os.getcwd(),
save_best_only=True,
verbose=True,
monitor='val_loss',
mode='min',
prefix=''
)
trainer = Trainer(checkpoint_callback=checkpoint_callback)
Restoring training session
--------------------------
You might want to not only load a model but also continue training it. Use this method to
restore the trainer state as well. This will continue from the epoch and global step you last left off.
However, the dataloaders will start from the first batch again (if you shuffled, this shouldn't matter).
Lightning will restore the session if you pass a logger with the same version and there's a saved checkpoint.
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TestTubeLogger
logger = TestTubeLogger(
save_dir='./savepath',
version=1 # An existing version with a saved checkpoint
)
trainer = Trainer(
logger=logger,
default_save_path='./savepath'
)
# this fit call loads model weights and trainer state
# the trainer continues seamlessly from where you left off
# without having to do anything else.
trainer.fit(model)
The trainer restores:
- global_step
- current_epoch
- All optimizers
- All lr_schedulers
- Model weights
You can even change the logic of your model as long as the weights and "architecture" of
the system aren't different. If you add a layer, for instance, it might not work.
At a rough level, here's what happens inside Trainer :py:mod:`pytorch_lightning.base_module.model_saving.py`:
.. code-block:: python
self.global_step = checkpoint['global_step']
self.current_epoch = checkpoint['epoch']
# restore the optimizers
optimizer_states = checkpoint['optimizer_states']
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
optimizer.load_state_dict(opt_state)
# restore the lr schedulers
lr_schedulers = checkpoint['lr_schedulers']
for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
scheduler['scheduler'].load_state_dict(lrs_state)
# uses the model you passed into trainer
model.load_state_dict(checkpoint['state_dict'])
"""

import logging as log
import os
import re
@@ -228,8 +319,8 @@ def dump_checkpoint(self):

# save lr schedulers
lr_schedulers = []
for i, scheduler in enumerate(self.lr_schedulers):
lr_schedulers.append(scheduler.state_dict())
for scheduler in self.lr_schedulers:
lr_schedulers.append(scheduler['scheduler'].state_dict())

checkpoint['lr_schedulers'] = lr_schedulers

@@ -320,7 +411,7 @@ def restore_training_state(self, checkpoint):
# restore the lr schedulers
lr_schedulers = checkpoint['lr_schedulers']
for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
scheduler.load_state_dict(lrs_state)
scheduler['scheduler'].load_state_dict(lrs_state)

# ----------------------------------
# PRIVATE OPS
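Because self.lr_schedulers is now a list of dicts, checkpoint state is written from and restored into the wrapped scheduler object, as the two hunks above show. A minimal standalone sketch of that round trip (the optimizer and scheduler are hypothetical stand-ins):

import torch
from torch import optim

opt = optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
schedulers = [{'scheduler': optim.lr_scheduler.StepLR(opt, step_size=10),
               'interval': 'epoch', 'frequency': 1,
               'reduce_on_plateau': False, 'monitor': 'val_loss'}]

# dump: state comes from the wrapped scheduler, not from the config dict
checkpoint = {'lr_schedulers': [s['scheduler'].state_dict() for s in schedulers]}

# restore: load the saved state back into the wrapped scheduler
for s, state in zip(schedulers, checkpoint['lr_schedulers']):
    s['scheduler'].load_state_dict(state)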
43 changes: 32 additions & 11 deletions pytorch_lightning/trainer/training_loop.py
@@ -361,17 +361,7 @@ def train(self):
self.run_training_epoch()

# update LR schedulers
if self.lr_schedulers is not None:
for lr_scheduler in self.lr_schedulers:
lr_scheduler.step()
if self.reduce_lr_on_plateau_scheduler is not None:
val_loss = self.callback_metrics.get('val_loss')
if val_loss is None:
avail_metrics = ','.join(list(self.callback_metrics.keys()))
m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
f'which is not available. Available metrics are: {avail_metrics}'
raise MisconfigurationException(m)
self.reduce_lr_on_plateau_scheduler.step(val_loss)
self.update_learning_rates(interval='epoch')

if self.max_steps and self.max_steps == self.global_step:
self.run_training_teardown()
@@ -444,6 +434,9 @@ def run_training_epoch(self):
# when returning -1 from train_step, we end epoch early
early_stop_epoch = batch_result == -1

# update lr
self.update_learning_rates(interval='step')

# ---------------
# RUN VAL STEP
# ---------------
@@ -716,6 +709,34 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):

return output

def update_learning_rates(self, interval):
''' Update learning rates
Args:
interval (str): either 'epoch' or 'step'.
'''
if not self.lr_schedulers:
return

for lr_scheduler in self.lr_schedulers:
current_idx = self.batch_idx if interval == 'step' else self.current_epoch
current_idx += 1 # account for both batch_idx and current_epoch starting from 0
# Take a step if the call to update_learning_rates matches the interval key and
# the current step modulo the scheduler's frequency is zero
if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency'] == 0:
# If instance of ReduceLROnPlateau, we need to pass validation loss
if lr_scheduler['reduce_on_plateau']:
monitor_key = lr_scheduler['monitor']
monitor_val = self.callback_metrics.get(monitor_key)
if monitor_val is None:
avail_metrics = ','.join(list(self.callback_metrics.keys()))
m = f'ReduceLROnPlateau conditioned on metric {monitor_key} ' \
f'which is not available. Available metrics are: {avail_metrics}. ' \
'Condition can be set using `monitor` key in lr scheduler dict'
raise MisconfigurationException(m)
lr_scheduler['scheduler'].step(monitor_val)
else:
lr_scheduler['scheduler'].step()

def call_checkpoint_callback(self):
if self.checkpoint_callback is not None:
self.checkpoint_callback.on_validation_end(self, self.get_model())
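The gating in update_learning_rates reduces to a modulo test against the configured frequency. A small hypothetical trace of that logic outside the Trainer (assuming interval='step' and frequency=4):

import torch
from torch import optim

opt = optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
sched = {'scheduler': optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5),
         'interval': 'step', 'frequency': 4, 'reduce_on_plateau': False}

for batch_idx in range(8):
    current_idx = batch_idx + 1  # batch_idx starts from 0
    if sched['interval'] == 'step' and current_idx % sched['frequency'] == 0:
        sched['scheduler'].step()  # fires after batches 4 and 8 only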
3 changes: 3 additions & 0 deletions tests/models/__init__.py
@@ -19,6 +19,9 @@
LightValStepFitMultipleDataloadersMixin,
LightTrainDataloader,
LightTestDataloader,
LightTestOptimizerWithSchedulingMixin,
LightTestMultipleOptimizersWithSchedulingMixin,
LightTestOptimizersWithMixedSchedulingMixin
)


2 changes: 1 addition & 1 deletion tests/models/base.py
@@ -130,7 +130,7 @@ def loss(self, labels, logits):
nll = F.nll_loss(logits, labels)
return nll

def training_step(self, batch, batch_idx):
def training_step(self, batch, batch_idx, optimizer_idx=None):
"""
Lightning calls this inside the training loop
:param batch:
41 changes: 40 additions & 1 deletion tests/models/mixins.py
@@ -1,7 +1,7 @@
from collections import OrderedDict

import torch

from torch import optim
from pytorch_lightning.core.decorators import data_loader


@@ -598,6 +598,45 @@ def test_end(self, outputs):
return result


class LightTestOptimizerWithSchedulingMixin:
def configure_optimizers(self):
if self.hparams.optimizer_name == 'lbfgs':
optimizer = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
else:
optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
return [optimizer], [lr_scheduler]


class LightTestMultipleOptimizersWithSchedulingMixin:
def configure_optimizers(self):
if self.hparams.optimizer_name == 'lbfgs':
optimizer1 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
optimizer2 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
else:
optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1)
lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)

return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2]


class LightTestOptimizersWithMixedSchedulingMixin:
def configure_optimizers(self):
if self.hparams.optimizer_name == 'lbfgs':
optimizer1 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
optimizer2 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
else:
optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 4, gamma=0.1)
lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)

return [optimizer1, optimizer2], \
[{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2]


def _get_output_metric(output, name):
if isinstance(output, dict):
val = output[name]
2 changes: 1 addition & 1 deletion tests/models/utils.py
@@ -82,7 +82,7 @@ def run_model_test(trainer_options, model, on_gpu=True):
if trainer.use_ddp or trainer.use_ddp2:
# on hpc this would work fine... but need to hack it for the purpose of the test
trainer.model = pretrained_model
trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
trainer.optimizers, trainer.lr_schedulers = trainer.init_optimizers(pretrained_model.configure_optimizers())

# test HPC loading / saving
trainer.hpc_save(save_dir, logger)
8 changes: 6 additions & 2 deletions tests/test_gpu_models.py
@@ -116,10 +116,14 @@ def test_optimizer_return_options():
assert len(lr_sched) == 0

# opt tuple of lists
opts = ([opt_a], ['lr_scheduler'])
scheduler = torch.optim.lr_scheduler.StepLR(opt_a, 10)
opts = ([opt_a], [scheduler])
optim, lr_sched = trainer.init_optimizers(opts)
assert len(optim) == 1 and len(lr_sched) == 1
assert optim[0] == opts[0][0] and lr_sched[0] == 'lr_scheduler'
assert optim[0] == opts[0][0] and \
lr_sched[0] == dict(scheduler=scheduler, interval='epoch',
frequency=1, reduce_on_plateau=False,
monitor='val_loss')


def test_cpu_slurm_save_load(tmpdir):
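Finally, the end-user pattern the new monitor key enables: conditioning ReduceLROnPlateau on any logged metric instead of the previously hard-coded val_loss. A hedged sketch (the model and metric names are hypothetical):

from pytorch_lightning import LightningModule
from torch import nn, optim

class LitModel(LightningModule):  # hypothetical
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def configure_optimizers(self):
        opt = optim.Adam(self.parameters(), lr=1e-3)
        sched = {'scheduler': optim.lr_scheduler.ReduceLROnPlateau(opt, mode='max'),
                 'monitor': 'val_acc'}  # any key present in callback_metrics
        return [opt], [sched]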