Skip to content

Commit

Permalink
Split callbacks in individual files + add a property to Callback for …
Browse files Browse the repository at this point in the history
…easy trainer instance access
  • Loading branch information
hadim committed Feb 16, 2020
1 parent bca600e commit 2401027
Show file tree
Hide file tree
Showing 6 changed files with 450 additions and 432 deletions.
7 changes: 6 additions & 1 deletion pytorch_lightning/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from .pt_callbacks import EarlyStopping, ModelCheckpoint, GradientAccumulationScheduler
from .callback import Callback
from .early_stopping import EarlyStopping
from .model_checkpoint import ModelCheckpoint
from .gradient_accumulation_scheduler import GradientAccumulationScheduler


__all__ = [
'Callback',
'EarlyStopping',
'ModelCheckpoint',
'GradientAccumulationScheduler',
Expand Down
65 changes: 65 additions & 0 deletions pytorch_lightning/callbacks/callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
Callbacks
=========
Callbacks supported by Lightning
"""

_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"


class Callback:
    """Abstract base class used to build new callbacks.

    Subclasses override any of the ``on_*`` hooks below; every hook is a
    no-op by default.  :meth:`set_trainer` must be called before the
    :attr:`trainer` property is read.
    """

    def __init__(self):
        # Set lazily via set_trainer(); access is guarded by the
        # `trainer` property so hooks fail loudly instead of on `None`.
        self._trainer = None

    @property
    def trainer(self):
        """The trainer this callback is attached to.

        Raises:
            AssertionError: if :meth:`set_trainer` was never called.
        """
        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
        return self._trainer

    def set_trainer(self, trainer):
        """Make a link to the trainer, so different things like `trainer.current_epoch`,
        `trainer.batch_idx`, `trainer.global_step` can be used."""
        self._trainer = trainer

    def on_epoch_begin(self):
        """Called when the epoch begins."""

    def on_epoch_end(self):
        """Called when the epoch ends."""

    def on_batch_begin(self):
        """Called when the training batch begins."""

    def on_batch_end(self):
        """Called when the training batch ends."""

    def on_train_begin(self):
        """Called when the train begins."""

    def on_train_end(self):
        """Called when the train ends."""

    def on_validation_begin(self):
        """Called when the validation loop begins."""

    def on_validation_end(self):
        """Called when the validation loop ends."""

    def on_test_begin(self):
        """Called when the test begins."""

    def on_test_end(self):
        """Called when the test ends."""
121 changes: 121 additions & 0 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import logging as log
import warnings

import numpy as np

from .callback import Callback


class EarlyStopping(Callback):
    r"""
    Stop training when a monitored quantity has stopped improving.
    Args:
        monitor (str): quantity to be monitored. Default: ``'val_loss'``.
        min_delta (float): minimum change in the monitored quantity
            to qualify as an improvement, i.e. an absolute
            change of less than `min_delta`, will count as no
            improvement. Default: ``0``.
        patience (int): number of epochs with no improvement
            after which training will be stopped. Default: ``0``.
        verbose (int): verbosity mode, messages are emitted when ``> 0``.
            Default: ``0``.
        mode (str): one of {auto, min, max}. In `min` mode,
            training will stop when the quantity
            monitored has stopped decreasing; in `max`
            mode it will stop when the quantity
            monitored has stopped increasing; in `auto`
            mode, the direction is automatically inferred
            from the name of the monitored quantity. Default: ``'auto'``.
        strict (bool): whether to crash the training if `monitor` is
            not found in the metrics. Default: ``True``.
    Example::
        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import EarlyStopping
        early_stopping = EarlyStopping('val_loss')
        Trainer(early_stop_callback=early_stopping)
    """

    def __init__(self, monitor='val_loss',
                 min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
        super().__init__()

        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.strict = strict
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_epoch = 0

        if mode not in ['auto', 'min', 'max']:
            if self.verbose > 0:
                log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            # 'auto': infer direction from the metric name.
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        # Sign min_delta so that `monitor_op(current - min_delta, best)`
        # always means "improved by at least min_delta".
        if self.monitor_op == np.less:
            self.min_delta *= -1

        self.on_train_begin()

    def check_metrics(self, logs):
        """Return True when the monitored metric is present in `logs`.

        When it is missing: raise RuntimeError if `strict`, otherwise
        optionally warn (verbose > 0) and return False.
        """
        monitor_val = logs.get(self.monitor)
        error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
                     f' which is not available. Available metrics are:'
                     f' `{"`, `".join(list(logs.keys()))}`')

        if monitor_val is None:
            if self.strict:
                raise RuntimeError(error_msg)
            if self.verbose > 0:
                warnings.warn(error_msg, RuntimeWarning)

            return False

        return True

    def on_epoch_end(self):
        """Evaluate the monitored metric; return True when training should stop."""
        logs = self.trainer.callback_metrics
        stop_training = False
        if not self.check_metrics(logs):
            return stop_training

        current = logs.get(self.monitor)
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = self.trainer.current_epoch
                stop_training = True
                self.on_train_end()

        return stop_training

    def on_train_begin(self):
        # Allow instances to be re-used: reset all mutable state.
        # np.inf (not the removed-in-NumPy-2.0 np.Inf alias).
        self.wait = 0
        self.stopped_epoch = 0
        self.best = np.inf if self.monitor_op == np.less else -np.inf

    def on_train_end(self):
        if self.stopped_epoch > 0 and self.verbose > 0:
            warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
                          ' but will start from "0" in v0.8.0.', DeprecationWarning)
            log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
55 changes: 55 additions & 0 deletions pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import warnings

from .callback import Callback


class GradientAccumulationScheduler(Callback):
    r"""
    Change gradient accumulation factor according to scheduling.
    Args:
        scheduling (dict): scheduling in format {epoch: accumulation_factor}
    .. warning:: Epochs indexing starts from "1" until v0.6.x, but will start from "0" in v0.8.0.
    Example::
        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import GradientAccumulationScheduler
        # at epoch 5 start accumulating every 2 batches
        accumulator = GradientAccumulationScheduler(scheduling={5: 2})
        Trainer(accumulate_grad_batches=accumulator)
    """

    def __init__(self, scheduling: dict):
        super().__init__()

        if scheduling == {}:  # empty dict error
            raise TypeError("Empty dict cannot be interpreted correct")

        for key in scheduling.keys():
            if not isinstance(key, int) or not isinstance(scheduling[key], int):
                raise TypeError("All epoches and accumulation factor must be integers")

        minimal_epoch = min(scheduling.keys())
        warnings.warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,'
                      ' but will start from "0" in v0.8.0.', DeprecationWarning)
        if minimal_epoch < 1:
            msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
            raise IndexError(msg)

        # Work on a copy so the caller's dict is never mutated.
        scheduling = dict(scheduling)
        if minimal_epoch != 1:  # if user didnt define first epoch accumulation factor
            scheduling.update({1: 1})

        self.scheduling = scheduling
        self.epochs = sorted(scheduling.keys())

    def on_epoch_begin(self):
        """Set `trainer.accumulate_grad_batches` for the starting epoch."""
        trainer = self.trainer
        # indexing epochs from 1 (until v0.6.x)
        # In v0.8.0, ` + 1` should be removed.
        epoch = trainer.current_epoch + 1
        # Walk milestones from largest to smallest; the first one at or
        # below the current epoch wins.
        for i in reversed(range(len(self.epochs))):
            if epoch >= self.epochs[i]:
                trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
                break
Loading

0 comments on commit 2401027

Please sign in to comment.