Commit: Split callbacks in individual files + add a property to Callback for easy trainer instance access
Showing 6 changed files with 450 additions and 432 deletions.
65 changes: 65 additions & 0 deletions
pytorch_lightning/callbacks/callback.py
@@ -0,0 +1,65 @@
"""
Callbacks
=========
Callbacks supported by Lightning
"""

_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"


class Callback(object):
    """Abstract base class used to build new callbacks."""

    def __init__(self):
        self._trainer = None

    @property
    def trainer(self):
        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
        return self._trainer

    def set_trainer(self, trainer):
        """Make a link to the trainer, so different things like `trainer.current_epoch`,
        `trainer.batch_idx`, `trainer.global_step` can be used."""
        self._trainer = trainer

    def on_epoch_begin(self):
        """Called when the epoch begins."""
        pass

    def on_epoch_end(self):
        """Called when the epoch ends."""
        pass

    def on_batch_begin(self):
        """Called when the training batch begins."""
        pass

    def on_batch_end(self):
        """Called when the training batch ends."""
        pass

    def on_train_begin(self):
        """Called when the train begins."""
        pass

    def on_train_end(self):
        """Called when the train ends."""
        pass

    def on_validation_begin(self):
        """Called when the validation loop begins."""
        pass

    def on_validation_end(self):
        """Called when the validation loop ends."""
        pass

    def on_test_begin(self):
        """Called when the test begins."""
        pass

    def on_test_end(self):
        """Called when the test ends."""
        pass
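The new `trainer` property is the point of this commit: every hook can reach the trainer once `set_trainer()` has linked it, and the assertion gives a clear error if that step was skipped. Below is a minimal sketch of a custom callback using the property; the `LoggingCallback` name and the manual `set_trainer(trainer)` call are illustrative only, with `trainer` standing for an already constructed `Trainer` instance.

    from pytorch_lightning.callbacks.callback import Callback


    class LoggingCallback(Callback):
        """Illustrative callback that reads trainer state through the new property."""

        def on_epoch_end(self):
            # `self.trainer` raises an AssertionError if `set_trainer()` was never called
            print(f'finished epoch {self.trainer.current_epoch}'
                  f' at global step {self.trainer.global_step}')


    callback = LoggingCallback()
    callback.set_trainer(trainer)  # `trainer` is an existing Trainer instance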
121 changes: 121 additions & 0 deletions
pytorch_lightning/callbacks/early_stopping.py
@@ -0,0 +1,121 @@
import logging as log
import warnings

import numpy as np

from .callback import Callback


class EarlyStopping(Callback):
    r"""
    Stop training when a monitored quantity has stopped improving.

    Args:
        monitor (str): quantity to be monitored. Default: ``'val_loss'``.
        min_delta (float): minimum change in the monitored quantity
            to qualify as an improvement, i.e. an absolute
            change of less than `min_delta` will count as no
            improvement. Default: ``0``.
        patience (int): number of epochs with no improvement
            after which training will be stopped. Default: ``0``.
        verbose (bool): verbosity mode. Default: ``0``.
        mode (str): one of {auto, min, max}. In `min` mode,
            training will stop when the quantity
            monitored has stopped decreasing; in `max`
            mode it will stop when the quantity
            monitored has stopped increasing; in `auto`
            mode, the direction is automatically inferred
            from the name of the monitored quantity. Default: ``'auto'``.
        strict (bool): whether to crash the training if `monitor` is
            not found in the metrics. Default: ``True``.

    Example::

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import EarlyStopping

        early_stopping = EarlyStopping('val_loss')
        Trainer(early_stop_callback=early_stopping)
    """

    def __init__(self, monitor='val_loss',
                 min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
        super(EarlyStopping, self).__init__()

        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.strict = strict
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_epoch = 0

        if mode not in ['auto', 'min', 'max']:
            if self.verbose > 0:
                log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

        self.on_train_begin()

    def check_metrics(self, logs):
        monitor_val = logs.get(self.monitor)
        error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
                     f' which is not available. Available metrics are:'
                     f' `{"`, `".join(list(logs.keys()))}`')

        if monitor_val is None:
            if self.strict:
                raise RuntimeError(error_msg)
            if self.verbose > 0:
                warnings.warn(error_msg, RuntimeWarning)

            return False

        return True

    def on_train_begin(self):
        # Allow instances to be re-used
        self.wait = 0
        self.stopped_epoch = 0
        self.best = np.Inf if self.monitor_op == np.less else -np.Inf

    def on_epoch_end(self):
        logs = self.trainer.callback_metrics
        stop_training = False
        if not self.check_metrics(logs):
            return stop_training

        current = logs.get(self.monitor)
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = self.trainer.current_epoch
                stop_training = True
                self.on_train_end()

        return stop_training

    def on_train_end(self):
        if self.stopped_epoch > 0 and self.verbose > 0:
            warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
                          ' but will start from "0" in v0.8.0.', DeprecationWarning)
            log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
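The stopping decision lives entirely in `on_epoch_end`: it reads `trainer.callback_metrics`, tracks the best value seen so far, and returns ``True`` once `patience` epochs pass without improvement. A quick sketch of that logic in isolation follows; the `SimpleNamespace` stand-in only mimics the two trainer attributes the hook reads (`callback_metrics` and `current_epoch`) and is not how the real Trainer wires the callback.

    from types import SimpleNamespace

    from pytorch_lightning.callbacks import EarlyStopping

    early_stopping = EarlyStopping(monitor='val_loss', patience=2)

    # stand-in exposing only what on_epoch_end reads
    fake_trainer = SimpleNamespace(callback_metrics={}, current_epoch=0)
    early_stopping.set_trainer(fake_trainer)

    for epoch, val_loss in enumerate([0.9, 0.8, 0.85, 0.84, 0.83]):
        fake_trainer.current_epoch = epoch
        fake_trainer.callback_metrics = {'val_loss': val_loss}
        if early_stopping.on_epoch_end():
            # patience of 2 is exhausted at the third non-improving epoch
            print(f'stop requested at epoch {epoch}')
            break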
55 changes: 55 additions & 0 deletions
pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
@@ -0,0 +1,55 @@
import warnings

from .callback import Callback


class GradientAccumulationScheduler(Callback):
    r"""
    Change gradient accumulation factor according to scheduling.

    Args:
        scheduling (dict): scheduling in format {epoch: accumulation_factor}

    .. warning:: Epochs indexing starts from "1" until v0.6.x, but will start from "0" in v0.8.0.

    Example::

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import GradientAccumulationScheduler

        # at epoch 5 start accumulating every 2 batches
        accumulator = GradientAccumulationScheduler(scheduling={5: 2})
        Trainer(accumulate_grad_batches=accumulator)
    """

    def __init__(self, scheduling: dict):
        super().__init__()

        if scheduling == {}:  # empty dict error
            raise TypeError("Empty dict cannot be interpreted correctly")

        for key in scheduling.keys():
            if not isinstance(key, int) or not isinstance(scheduling[key], int):
                raise TypeError("All epochs and accumulation factors must be integers")

        minimal_epoch = min(scheduling.keys())
        warnings.warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,'
                      ' but will start from "0" in v0.8.0.', DeprecationWarning)
        if minimal_epoch < 1:
            msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correctly"
            raise IndexError(msg)
        if minimal_epoch != 1:  # if user didn't define first epoch accumulation factor
            scheduling.update({1: 1})

        self.scheduling = scheduling
        self.epochs = sorted(scheduling.keys())

    def on_epoch_begin(self):
        trainer = self.trainer
        # indexing epochs from 1 (until v0.6.x)
        # In v0.8.0, ` + 1` should be removed.
        epoch = trainer.current_epoch + 1
        for i in reversed(range(len(self.epochs))):
            if epoch >= self.epochs[i]:
                trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
                break
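To see how `on_epoch_begin` resolves the schedule, here is a short sketch with the docstring's ``{5: 2}`` example; the `SimpleNamespace` stand-in only carries the two attributes the hook touches and is not part of this commit.

    from types import SimpleNamespace

    from pytorch_lightning.callbacks import GradientAccumulationScheduler

    scheduler = GradientAccumulationScheduler(scheduling={5: 2})
    print(scheduler.scheduling)  # {5: 2, 1: 1} -- epoch 1 is filled in automatically

    fake_trainer = SimpleNamespace(current_epoch=0, accumulate_grad_batches=1)
    scheduler.set_trainer(fake_trainer)

    for current_epoch in range(6):  # 0-based, as the trainer counts internally
        fake_trainer.current_epoch = current_epoch
        scheduler.on_epoch_begin()
        # 1-based epochs 1-4 accumulate every batch, epochs 5 and up every 2 batches
        print(current_epoch + 1, fake_trainer.accumulate_grad_batches)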