From 240102705b00e647ef99b28e5123294ad9b47261 Mon Sep 17 00:00:00 2001
From: Hadrien Mary
Date: Sun, 16 Feb 2020 12:27:56 -0500
Subject: [PATCH] Split callbacks in individual files + add a property to
 Callback for easy trainer instance access

---
 pytorch_lightning/callbacks/__init__.py       |   7 +-
 pytorch_lightning/callbacks/callback.py       |  65 +++
 pytorch_lightning/callbacks/early_stopping.py | 121 +++++
 .../gradient_accumulation_scheduler.py        |  55 +++
 .../callbacks/model_checkpoint.py             | 203 +++++++++
 pytorch_lightning/callbacks/pt_callbacks.py   | 431 ------------------
 6 files changed, 450 insertions(+), 432 deletions(-)
 create mode 100644 pytorch_lightning/callbacks/callback.py
 create mode 100644 pytorch_lightning/callbacks/early_stopping.py
 create mode 100644 pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
 create mode 100644 pytorch_lightning/callbacks/model_checkpoint.py
 delete mode 100644 pytorch_lightning/callbacks/pt_callbacks.py

diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py
index 9538036563d79f..3d30ebb78849a3 100644
--- a/pytorch_lightning/callbacks/__init__.py
+++ b/pytorch_lightning/callbacks/__init__.py
@@ -1,6 +1,11 @@
-from .pt_callbacks import EarlyStopping, ModelCheckpoint, GradientAccumulationScheduler
+from .callback import Callback
+from .early_stopping import EarlyStopping
+from .model_checkpoint import ModelCheckpoint
+from .gradient_accumulation_scheduler import GradientAccumulationScheduler
+
 
 __all__ = [
+    'Callback',
     'EarlyStopping',
     'ModelCheckpoint',
     'GradientAccumulationScheduler',
diff --git a/pytorch_lightning/callbacks/callback.py b/pytorch_lightning/callbacks/callback.py
new file mode 100644
index 00000000000000..e17887b3d3d05e
--- /dev/null
+++ b/pytorch_lightning/callbacks/callback.py
@@ -0,0 +1,65 @@
+"""
+Callbacks
+=========
+
+Callbacks supported by Lightning
+"""
+
+_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"
+
+
+class Callback(object):
+    """Abstract base class used to build new callbacks."""
+
+    def __init__(self):
+        self._trainer = None
+
+    @property
+    def trainer(self):
+        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
+        return self._trainer
+
+    def set_trainer(self, trainer):
+        """Make a link to the trainer, so different things like `trainer.current_epoch`,
+        `trainer.batch_idx`, `trainer.global_step` can be used."""
+        self._trainer = trainer
+
+    def on_epoch_begin(self):
+        """Called when the epoch begins."""
+        pass
+
+    def on_epoch_end(self):
+        """Called when the epoch ends."""
+        pass
+
+    def on_batch_begin(self):
+        """Called when the training batch begins."""
+        pass
+
+    def on_batch_end(self):
+        """Called when the training batch ends."""
+        pass
+
+    def on_train_begin(self):
+        """Called when the train begins."""
+        pass
+
+    def on_train_end(self):
+        """Called when the train ends."""
+        pass
+
+    def on_validation_begin(self):
+        """Called when the validation loop begins."""
+        pass
+
+    def on_validation_end(self):
+        """Called when the validation loop ends."""
+        pass
+
+    def on_test_begin(self):
+        """Called when the test begins."""
+        pass
+
+    def on_test_end(self):
+        """Called when the test ends."""
+        pass
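For reviewers who want to try the new `trainer` property, here is a minimal sketch (not part of the patch) of a user-defined callback; `DummyTrainer` is a hypothetical stand-in used only to show that the property guards against use before `set_trainer()` is called:

```python
from pytorch_lightning.callbacks import Callback


class PrintingCallback(Callback):
    """Toy callback that reads trainer state through the new `trainer` property."""

    def on_epoch_end(self):
        # `self.trainer` asserts that `set_trainer()` was called first,
        # then returns the stored trainer instance
        print(f"finished epoch {self.trainer.current_epoch}")


class DummyTrainer:
    """Hypothetical stand-in exposing the one attribute the callback reads."""
    current_epoch = 0


cb = PrintingCallback()
# cb.trainer  # would raise AssertionError: ".set_trainer() should be called after the callback initialization"
cb.set_trainer(DummyTrainer())
cb.on_epoch_end()  # prints "finished epoch 0"
```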
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
new file mode 100644
index 00000000000000..118c5530cde5d5
--- /dev/null
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -0,0 +1,121 @@
+import logging as log
+import warnings
+
+import numpy as np
+
+from .callback import Callback
+
+
+class EarlyStopping(Callback):
+    r"""
+    Stop training when a monitored quantity has stopped improving.
+
+    Args:
+        monitor (str): quantity to be monitored. Default: ``'val_loss'``.
+        min_delta (float): minimum change in the monitored quantity
+            to qualify as an improvement, i.e. an absolute
+            change of less than `min_delta`, will count as no
+            improvement. Default: ``0``.
+        patience (int): number of epochs with no improvement
+            after which training will be stopped. Default: ``0``.
+        verbose (bool): verbosity mode. Default: ``0``.
+        mode (str): one of {auto, min, max}. In `min` mode,
+            training will stop when the quantity
+            monitored has stopped decreasing; in `max`
+            mode it will stop when the quantity
+            monitored has stopped increasing; in `auto`
+            mode, the direction is automatically inferred
+            from the name of the monitored quantity. Default: ``'auto'``.
+        strict (bool): whether to crash the training if `monitor` is
+            not found in the metrics. Default: ``True``.
+
+    Example::
+
+        from pytorch_lightning import Trainer
+        from pytorch_lightning.callbacks import EarlyStopping
+
+        early_stopping = EarlyStopping('val_loss')
+        Trainer(early_stop_callback=early_stopping)
+    """
+
+    def __init__(self, monitor='val_loss',
+                 min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
+        super(EarlyStopping, self).__init__()
+
+        self.monitor = monitor
+        self.patience = patience
+        self.verbose = verbose
+        self.strict = strict
+        self.min_delta = min_delta
+        self.wait = 0
+        self.stopped_epoch = 0
+
+        if mode not in ['auto', 'min', 'max']:
+            if self.verbose > 0:
+                log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
+            mode = 'auto'
+
+        if mode == 'min':
+            self.monitor_op = np.less
+        elif mode == 'max':
+            self.monitor_op = np.greater
+        else:
+            if 'acc' in self.monitor:
+                self.monitor_op = np.greater
+            else:
+                self.monitor_op = np.less
+
+        if self.monitor_op == np.greater:
+            self.min_delta *= 1
+        else:
+            self.min_delta *= -1
+
+        self.on_train_begin()
+
+    def check_metrics(self, logs):
+        monitor_val = logs.get(self.monitor)
+        error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
+                     f' which is not available. Available metrics are:'
+                     f' `{"`, `".join(list(logs.keys()))}`')
+
+        if monitor_val is None:
+            if self.strict:
+                raise RuntimeError(error_msg)
+            if self.verbose > 0:
+                warnings.warn(error_msg, RuntimeWarning)
+
+            return False
+
+        return True
+
+    def on_train_begin(self):
+        # Allow instances to be re-used
+        self.wait = 0
+        self.stopped_epoch = 0
+        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+
+    def on_epoch_end(self):
+
+        logs = self.trainer.callback_metrics
+        stop_training = False
+        if not self.check_metrics(logs):
+            return stop_training
+
+        current = logs.get(self.monitor)
+        if self.monitor_op(current - self.min_delta, self.best):
+            self.best = current
+            self.wait = 0
+        else:
+            self.wait += 1
+            if self.wait >= self.patience:
+                self.stopped_epoch = self.trainer.current_epoch
+                stop_training = True
+                self.on_train_end()
+
+        return stop_training
+
+    def on_train_end(self):
+        if self.stopped_epoch > 0 and self.verbose > 0:
+            warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
+                          ' but will start from "0" in v0.8.0.', DeprecationWarning)
+            log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
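Beyond the docstring example, a fuller usage sketch (argument values here are illustrative, not mandated by the patch):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# stop when `val_loss` has not improved by at least 0.01 for 3 consecutive epochs
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0.01, patience=3,
                               verbose=1, mode='min', strict=True)
trainer = Trainer(early_stop_callback=early_stopping)
```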
diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
new file mode 100644
index 00000000000000..d511718b0ac225
--- /dev/null
+++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
@@ -0,0 +1,55 @@
+import warnings
+
+from .callback import Callback
+
+
+class GradientAccumulationScheduler(Callback):
+    r"""
+    Change gradient accumulation factor according to scheduling.
+
+    Args:
+        scheduling (dict): scheduling in format {epoch: accumulation_factor}
+        .. warning:: Epochs indexing starts from "1" until v0.6.x, but will start from "0" in v0.8.0.
+
+    Example::
+
+        from pytorch_lightning import Trainer
+        from pytorch_lightning.callbacks import GradientAccumulationScheduler
+
+        # at epoch 5 start accumulating every 2 batches
+        accumulator = GradientAccumulationScheduler(scheduling={5: 2})
+        Trainer(accumulate_grad_batches=accumulator)
+    """
+
+    def __init__(self, scheduling: dict):
+        super().__init__()
+
+        if scheduling == {}:  # empty dict error
+            raise TypeError("Empty dict cannot be interpreted correctly")
+
+        for key in scheduling.keys():
+            if not isinstance(key, int) or not isinstance(scheduling[key], int):
+                raise TypeError("All epochs and accumulation factors must be integers")
+
+        minimal_epoch = min(scheduling.keys())
+        warnings.warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,'
+                      ' but will start from "0" in v0.8.0.', DeprecationWarning)
+        if minimal_epoch < 1:
+            msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correctly"
+            raise IndexError(msg)
+        if minimal_epoch != 1:  # if user didn't define first epoch accumulation factor
+            scheduling.update({1: 1})
+
+        self.scheduling = scheduling
+        self.epochs = sorted(scheduling.keys())
+
+    def on_epoch_begin(self):
+
+        trainer = self.trainer
+        # indexing epochs from 1 (until v0.6.x)
+        # In v0.8.0, ` + 1` should be removed.
+        epoch = trainer.current_epoch + 1
+        for i in reversed(range(len(self.epochs))):
+            if epoch >= self.epochs[i]:
+                trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
+                break
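A self-contained sketch of how the schedule is resolved each epoch; `FakeTrainer` is hypothetical and only mimics the two attributes the callback touches:

```python
from pytorch_lightning.callbacks import GradientAccumulationScheduler


class FakeTrainer:
    """Hypothetical stand-in with the attributes the callback reads and writes."""
    current_epoch = 0          # 0-based, as in the real Trainer
    accumulate_grad_batches = 1


scheduler = GradientAccumulationScheduler(scheduling={5: 2})  # {1: 1} is added implicitly
trainer = FakeTrainer()
scheduler.set_trainer(trainer)

for epoch in range(7):
    trainer.current_epoch = epoch
    scheduler.on_epoch_begin()
    # epochs 0-3 (displayed as 1-4) keep factor 1; from epoch 4 (displayed as 5) the factor becomes 2
    print(epoch, trainer.accumulate_grad_batches)
```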
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
new file mode 100644
index 00000000000000..9432ff68405445
--- /dev/null
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -0,0 +1,203 @@
+import os
+import shutil
+import logging as log
+import warnings
+
+import numpy as np
+
+from .callback import Callback
+
+
+class ModelCheckpoint(Callback):
+    r"""
+
+    Save the model after every epoch.
+
+    Args:
+        filepath (str): path to save the model file.
+            Can contain named formatting options to be auto-filled.
+
+            Example::
+
+                # save epoch and val_loss in name
+                ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
+                # saves file like: /path/epoch_2-val_loss_0.2.hdf5
+        monitor (str): quantity to monitor.
+        verbose (bool): verbosity mode, 0 or 1.
+        save_top_k (int): if `save_top_k == k`,
+            the best k models according to
+            the quantity monitored will be saved.
+            if `save_top_k == 0`, no models are saved.
+            if `save_top_k == -1`, all models are saved.
+            Please note that the monitors are checked every `period` epochs.
+            if `save_top_k >= 2` and the callback is called multiple
+            times inside an epoch, the name of the saved file will be
+            appended with a version count starting with `v0`.
+        mode (str): one of {auto, min, max}.
+            If `save_top_k != 0`, the decision
+            to overwrite the current save file is made
+            based on either the maximization or the
+            minimization of the monitored quantity. For `val_acc`,
+            this should be `max`, for `val_loss` this should
+            be `min`, etc. In `auto` mode, the direction is
+            automatically inferred from the name of the monitored quantity.
+        save_weights_only (bool): if True, then only the model's weights will be
+            saved (`model.save_weights(filepath)`), else the full model
+            is saved (`model.save(filepath)`).
+        period (int): Interval (number of epochs) between checkpoints.
+
+    Example::
+
+        from pytorch_lightning import Trainer
+        from pytorch_lightning.callbacks import ModelCheckpoint
+
+        checkpoint_callback = ModelCheckpoint(filepath='my_path')
+        Trainer(checkpoint_callback=checkpoint_callback)
+
+        # saves checkpoints to my_path whenever 'val_loss' has a new min
+    """
+
+    def __init__(self, filepath, monitor='val_loss', verbose=0,
+                 save_top_k=1, save_weights_only=False,
+                 mode='auto', period=1, prefix=''):
+        super(ModelCheckpoint, self).__init__()
+        if (
+            save_top_k
+            and os.path.isdir(filepath)
+            and len(os.listdir(filepath)) > 0
+        ):
+            warnings.warn(
+                f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
+                " All files in this directory will be deleted when a checkpoint is saved!"
+            )
+
+        self.monitor = monitor
+        self.verbose = verbose
+        self.filepath = filepath
+        os.makedirs(filepath, exist_ok=True)
+        self.save_top_k = save_top_k
+        self.save_weights_only = save_weights_only
+        self.period = period
+        self.epochs_since_last_check = 0
+        self.prefix = prefix
+        self.best_k_models = {}
+        # {filename: monitor}
+        self.kth_best_model = ''
+        self.best = 0
+        self.save_function = None
+
+        if mode not in ['auto', 'min', 'max']:
+            warnings.warn(
+                f'ModelCheckpoint mode {mode} is unknown, '
+                'fallback to auto mode.', RuntimeWarning)
+            mode = 'auto'
+
+        if mode == 'min':
+            self.monitor_op = np.less
+            self.kth_value = np.Inf
+            self.mode = 'min'
+        elif mode == 'max':
+            self.monitor_op = np.greater
+            self.kth_value = -np.Inf
+            self.mode = 'max'
+        else:
+            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
+                self.monitor_op = np.greater
+                self.kth_value = -np.Inf
+                self.mode = 'max'
+            else:
+                self.monitor_op = np.less
+                self.kth_value = np.Inf
+                self.mode = 'min'
+
+    def _del_model(self, filepath):
+        dirpath = os.path.dirname(filepath)
+
+        # make paths
+        os.makedirs(dirpath, exist_ok=True)
+
+        try:
+            shutil.rmtree(filepath)
+        except OSError:
+            os.remove(filepath)
+
+    def _save_model(self, filepath):
+        dirpath = os.path.dirname(filepath)
+
+        # make paths
+        os.makedirs(dirpath, exist_ok=True)
+
+        # delegate the saving to the model
+        assert self.save_function is not None, ".save_function() not set"
+        self.save_function(filepath)
+
+    def check_monitor_top_k(self, current):
+        less_than_k_models = len(self.best_k_models.keys()) < self.save_top_k
+        if less_than_k_models:
+            return True
+        return self.monitor_op(current, self.best_k_models[self.kth_best_model])
+
+    def on_validation_end(self):
+
+        logs = self.trainer.callback_metrics
+        epoch = self.trainer.current_epoch
+        self.epochs_since_last_check += 1
+
+        if self.save_top_k == 0:
+            # no models are saved
+            return
+        if self.epochs_since_last_check >= self.period:
+            self.epochs_since_last_check = 0
+            filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
+            version_cnt = 0
+            while os.path.isfile(filepath):
+                # this epoch called before
+                filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
+                version_cnt += 1
+
+            if self.save_top_k != -1:
+                current = logs.get(self.monitor)
+
+                if current is None:
+                    warnings.warn(
+                        f'Can save best model only with {self.monitor} available,'
+                        ' skipping.', RuntimeWarning)
+                else:
+                    if self.check_monitor_top_k(current):
+
+                        # remove kth
+                        if len(self.best_k_models.keys()) == self.save_top_k:
+                            delpath = self.kth_best_model
+                            self.best_k_models.pop(self.kth_best_model)
+                            self._del_model(delpath)
+
+                        self.best_k_models[filepath] = current
+                        if len(self.best_k_models.keys()) == self.save_top_k:
+                            # monitor dict has reached k elements
+                            if self.mode == 'min':
+                                self.kth_best_model = max(self.best_k_models, key=self.best_k_models.get)
+                            else:
+                                self.kth_best_model = min(self.best_k_models, key=self.best_k_models.get)
+                            self.kth_value = self.best_k_models[self.kth_best_model]
+
+                        if self.mode == 'min':
+                            self.best = min(self.best_k_models.values())
+                        else:
+                            self.best = max(self.best_k_models.values())
+                        if self.verbose > 0:
+                            log.info(
+                                f'\nEpoch {epoch:05d}: {self.monitor} reached'
+                                f' {current:0.5f} (best {self.best:0.5f}), saving model to'
+                                f' {filepath} as top {self.save_top_k}')
+                        self._save_model(filepath)
+
+                    else:
+                        if self.verbose > 0:
+                            log.info(
+                                f'\nEpoch {epoch:05d}: {self.monitor}'
+                                f' was not in top {self.save_top_k}')
+
+            else:
+                if self.verbose > 0:
+                    log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
+                self._save_model(filepath)
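A usage sketch for the checkpoint callback (path and argument values are illustrative): keep the three best checkpoints according to `val_loss`, checked every epoch:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(filepath='my/checkpoint/dir', monitor='val_loss',
                                      save_top_k=3, mode='min', period=1, verbose=1)
trainer = Trainer(checkpoint_callback=checkpoint_callback)
```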
diff --git a/pytorch_lightning/callbacks/pt_callbacks.py b/pytorch_lightning/callbacks/pt_callbacks.py
deleted file mode 100644
index 8fe557ceb701b3..00000000000000
--- a/pytorch_lightning/callbacks/pt_callbacks.py
+++ /dev/null
@@ -1,431 +0,0 @@
-"""
-Callbacks
-=========
-
-Callbacks supported by Lightning
-"""
-
-import os
-import shutil
-import logging as log
-import warnings
-
-import numpy as np
-
-
-class Callback(object):
-    """Abstract base class used to build new callbacks."""
-
-    def __init__(self):
-        self._trainer = None
-
-    def set_trainer(self, trainer):
-        """Make a link to the trainer, so different things like `trainer.current_epoch`,
-        `trainer.batch_idx`, `trainer.global_step` can be used."""
-        self._trainer = trainer
-
-    def on_epoch_begin(self):
-        """Called when the epoch begins."""
-        pass
-
-    def on_epoch_end(self):
-        """Called when the epoch ends."""
-        pass
-
-    def on_batch_begin(self):
-        """Called when the training batch begins."""
-        pass
-
-    def on_batch_end(self):
-        """Called when the training batch ends."""
-        pass
-
-    def on_train_begin(self):
-        """Called when the train begins."""
-        pass
-
-    def on_train_end(self):
-        """Called when the train ends."""
-        pass
-
-    def on_validation_begin(self):
-        """Called when the validation loop begins."""
-        pass
-
-    def on_validation_end(self):
-        """Called when the validation loop ends."""
-        pass
-
-    def on_test_begin(self):
-        """Called when the test begins."""
-        pass
-
-    def on_test_end(self):
-        """Called when the test ends."""
-        pass
-
-
-_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"
-
-
-class EarlyStopping(Callback):
-    r"""
-    Stop training when a monitored quantity has stopped improving.
-
-    Args:
-        monitor (str): quantity to be monitored. Default: ``'val_loss'``.
-        min_delta (float): minimum change in the monitored quantity
-            to qualify as an improvement, i.e. an absolute
-            change of less than `min_delta`, will count as no
-            improvement. Default: ``0``.
-        patience (int): number of epochs with no improvement
-            after which training will be stopped. Default: ``0``.
-        verbose (bool): verbosity mode. Default: ``0``.
-        mode (str): one of {auto, min, max}. In `min` mode,
-            training will stop when the quantity
-            monitored has stopped decreasing; in `max`
-            mode it will stop when the quantity
-            monitored has stopped increasing; in `auto`
-            mode, the direction is automatically inferred
-            from the name of the monitored quantity. Default: ``'auto'``.
-        strict (bool): whether to crash the training if `monitor` is
-            not found in the metrics. Default: ``True``.
-
-    Example::
-
-        from pytorch_lightning import Trainer
-        from pytorch_lightning.callbacks import EarlyStopping
-
-        early_stopping = EarlyStopping('val_loss')
-        Trainer(early_stop_callback=early_stopping)
-    """
-
-    def __init__(self, monitor='val_loss',
-                 min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
-        super(EarlyStopping, self).__init__()
-
-        self.monitor = monitor
-        self.patience = patience
-        self.verbose = verbose
-        self.strict = strict
-        self.min_delta = min_delta
-        self.wait = 0
-        self.stopped_epoch = 0
-
-        if mode not in ['auto', 'min', 'max']:
-            if self.verbose > 0:
-                log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
-            mode = 'auto'
-
-        if mode == 'min':
-            self.monitor_op = np.less
-        elif mode == 'max':
-            self.monitor_op = np.greater
-        else:
-            if 'acc' in self.monitor:
-                self.monitor_op = np.greater
-            else:
-                self.monitor_op = np.less
-
-        if self.monitor_op == np.greater:
-            self.min_delta *= 1
-        else:
-            self.min_delta *= -1
-
-        self.on_train_begin()
-
-    def check_metrics(self, logs):
-        monitor_val = logs.get(self.monitor)
-        error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
-                     f' which is not available. Available metrics are:'
-                     f' `{"`, `".join(list(logs.keys()))}`')
-
-        if monitor_val is None:
-            if self.strict:
-                raise RuntimeError(error_msg)
-            if self.verbose > 0:
-                warnings.warn(error_msg, RuntimeWarning)
-
-            return False
-
-        return True
-
-    def on_train_begin(self):
-        # Allow instances to be re-used
-        self.wait = 0
-        self.stopped_epoch = 0
-        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
-
-    def on_epoch_end(self):
-        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
-
-        logs = self._trainer.callback_metrics
-        stop_training = False
-        if not self.check_metrics(logs):
-            return stop_training
-
-        current = logs.get(self.monitor)
-        if self.monitor_op(current - self.min_delta, self.best):
-            self.best = current
-            self.wait = 0
-        else:
-            self.wait += 1
-            if self.wait >= self.patience:
-                self.stopped_epoch = self._trainer.current_epoch
-                stop_training = True
-                self.on_train_end()
-
-        return stop_training
-
-    def on_train_end(self):
-        if self.stopped_epoch > 0 and self.verbose > 0:
-            warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
-                          ' but will start from "0" in v0.8.0.', DeprecationWarning)
-            log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
-
-
-class ModelCheckpoint(Callback):
-    r"""
-
-    Save the model after every epoch.
-
-    Args:
-        filepath (str): path to save the model file.
-            Can contain named formatting options to be auto-filled.
-
-            Example::
-
-                # save epoch and val_loss in name
-                ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
-                # saves file like: /path/epoch_2-val_loss_0.2.hdf5
-        monitor (str): quantity to monitor.
-        verbose (bool): verbosity mode, 0 or 1.
-        save_top_k (int): if `save_top_k == k`,
-            the best k models according to
-            the quantity monitored will be saved.
-            if `save_top_k == 0`, no models are saved.
-            if `save_top_k == -1`, all models are saved.
-            Please note that the monitors are checked every `period` epochs.
-            if `save_top_k >= 2` and the callback is called multiple
-            times inside an epoch, the name of the saved file will be
-            appended with a version count starting with `v0`.
-        mode (str): one of {auto, min, max}.
-            If `save_top_k != 0`, the decision
-            to overwrite the current save file is made
-            based on either the maximization or the
-            minimization of the monitored quantity. For `val_acc`,
-            this should be `max`, for `val_loss` this should
-            be `min`, etc. In `auto` mode, the direction is
-            automatically inferred from the name of the monitored quantity.
-        save_weights_only (bool): if True, then only the model's weights will be
-            saved (`model.save_weights(filepath)`), else the full model
-            is saved (`model.save(filepath)`).
-        period (int): Interval (number of epochs) between checkpoints.
-
-    Example::
-
-        from pytorch_lightning import Trainer
-        from pytorch_lightning.callbacks import ModelCheckpoint
-
-        checkpoint_callback = ModelCheckpoint(filepath='my_path')
-        Trainer(checkpoint_callback=checkpoint_callback)
-
-        # saves checkpoints to my_path whenever 'val_loss' has a new min
-    """
-
-    def __init__(self, filepath, monitor='val_loss', verbose=0,
-                 save_top_k=1, save_weights_only=False,
-                 mode='auto', period=1, prefix=''):
-        super(ModelCheckpoint, self).__init__()
-        if (
-            save_top_k and
-            os.path.isdir(filepath) and
-            len(os.listdir(filepath)) > 0
-        ):
-            warnings.warn(
-                f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
-                "All files in this directory will be deleted when a checkpoint is saved!"
-            )
-
-        self.monitor = monitor
-        self.verbose = verbose
-        self.filepath = filepath
-        os.makedirs(filepath, exist_ok=True)
-        self.save_top_k = save_top_k
-        self.save_weights_only = save_weights_only
-        self.period = period
-        self.epochs_since_last_check = 0
-        self.prefix = prefix
-        self.best_k_models = {}
-        # {filename: monitor}
-        self.kth_best_model = ''
-        self.best = 0
-
-        if mode not in ['auto', 'min', 'max']:
-            warnings.warn(
-                f'ModelCheckpoint mode {mode} is unknown, '
-                'fallback to auto mode.', RuntimeWarning)
-            mode = 'auto'
-
-        if mode == 'min':
-            self.monitor_op = np.less
-            self.kth_value = np.Inf
-            self.mode = 'min'
-        elif mode == 'max':
-            self.monitor_op = np.greater
-            self.kth_value = -np.Inf
-            self.mode = 'max'
-        else:
-            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
-                self.monitor_op = np.greater
-                self.kth_value = -np.Inf
-                self.mode = 'max'
-            else:
-                self.monitor_op = np.less
-                self.kth_value = np.Inf
-                self.mode = 'min'
-
-    def _del_model(self, filepath):
-        dirpath = os.path.dirname(filepath)
-
-        # make paths
-        os.makedirs(dirpath, exist_ok=True)
-
-        try:
-            shutil.rmtree(filepath)
-        except OSError:
-            os.remove(filepath)
-
-    def _save_model(self, filepath):
-        dirpath = os.path.dirname(filepath)
-
-        # make paths
-        os.makedirs(dirpath, exist_ok=True)
-
-        # delegate the saving to the model
-        self.save_function(filepath)
-
-    def check_monitor_top_k(self, current):
-        less_than_k_models = len(self.best_k_models.keys()) < self.save_top_k
-        if less_than_k_models:
-            return True
-        return self.monitor_op(current, self.best_k_models[self.kth_best_model])
-
-    def on_validation_end(self):
-        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
-
-        logs = self._trainer.callback_metrics
-        epoch = self._trainer.current_epoch
-        self.epochs_since_last_check += 1
-
-        if self.save_top_k == 0:
-            # no models are saved
-            return
-        if self.epochs_since_last_check >= self.period:
-            self.epochs_since_last_check = 0
-            filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
-            version_cnt = 0
-            while os.path.isfile(filepath):
-                # this epoch called before
-                filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
-                version_cnt += 1
-
-            if self.save_top_k != -1:
-                current = logs.get(self.monitor)
-
-                if current is None:
-                    warnings.warn(
-                        f'Can save best model only with {self.monitor} available,'
-                        ' skipping.', RuntimeWarning)
-                else:
-                    if self.check_monitor_top_k(current):
-
-                        # remove kth
-                        if len(self.best_k_models.keys()) == self.save_top_k:
-                            delpath = self.kth_best_model
-                            self.best_k_models.pop(self.kth_best_model)
-                            self._del_model(delpath)
-
-                        self.best_k_models[filepath] = current
-                        if len(self.best_k_models.keys()) == self.save_top_k:
-                            # monitor dict has reached k elements
-                            if self.mode == 'min':
-                                self.kth_best_model = max(self.best_k_models, key=self.best_k_models.get)
-                            else:
-                                self.kth_best_model = min(self.best_k_models, key=self.best_k_models.get)
-                            self.kth_value = self.best_k_models[self.kth_best_model]
-
-                        if self.mode == 'min':
-                            self.best = min(self.best_k_models.values())
-                        else:
-                            self.best = max(self.best_k_models.values())
-                        if self.verbose > 0:
-                            log.info(
-                                f'\nEpoch {epoch:05d}: {self.monitor} reached'
-                                f' {current:0.5f} (best {self.best:0.5f}), saving model to'
-                                f' {filepath} as top {self.save_top_k}')
-                        self._save_model(filepath)
-
-                    else:
-                        if self.verbose > 0:
-                            log.info(
-                                f'\nEpoch {epoch:05d}: {self.monitor}'
-                                f' was not in top {self.save_top_k}')
-
-            else:
-                if self.verbose > 0:
-                    log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
-                self._save_model(filepath)
-
-
-class GradientAccumulationScheduler(Callback):
-    r"""
-    Change gradient accumulation factor according to scheduling.
-
-    Args:
-        scheduling (dict): scheduling in format {epoch: accumulation_factor}
-        .. warning:: Epochs indexing starts from "1" until v0.6.x, but will start from "0" in v0.8.0.
-
-    Example::
-
-        from pytorch_lightning import Trainer
-        from pytorch_lightning.callbacks import GradientAccumulationScheduler
-
-        # at epoch 5 start accumulating every 2 batches
-        accumulator = GradientAccumulationScheduler(scheduling: {5: 2})
-        Trainer(accumulate_grad_batches=accumulator)
-    """
-
-    def __init__(self, scheduling: dict):
-        super().__init__()
-
-        if scheduling == {}:  # empty dict error
-            raise TypeError("Empty dict cannot be interpreted correct")
-
-        for key in scheduling.keys():
-            if not isinstance(key, int) or not isinstance(scheduling[key], int):
-                raise TypeError("All epoches and accumulation factor must be integers")
-
-        minimal_epoch = min(scheduling.keys())
-        warnings.warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,'
-                      ' but will start from "0" in v0.8.0.', DeprecationWarning)
-        if minimal_epoch < 1:
-            msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
-            raise IndexError(msg)
-        if minimal_epoch != 1:  # if user didnt define first epoch accumulation factor
-            scheduling.update({1: 1})
-
-        self.scheduling = scheduling
-        self.epochs = sorted(scheduling.keys())
-
-    def on_epoch_begin(self):
-        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
-
-        trainer = self._trainer
-        # indexing epochs from 1 (until v0.6.x)
-        # In v0.8.0, ` + 1` should be removed.
-        epoch = trainer.current_epoch + 1
-        for i in reversed(range(len(self.epochs))):
-            if epoch >= self.epochs[i]:
-                trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
-                break
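Taken together, the split-out callbacks are wired into the `Trainer` exactly as the docstrings above describe; the values below are illustrative only:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (EarlyStopping, ModelCheckpoint,
                                         GradientAccumulationScheduler)

early_stopping = EarlyStopping(monitor='val_loss', patience=3)
checkpoint = ModelCheckpoint(filepath='checkpoints', monitor='val_loss', save_top_k=1)
accumulator = GradientAccumulationScheduler(scheduling={5: 2})

trainer = Trainer(early_stop_callback=early_stopping,
                  checkpoint_callback=checkpoint,
                  accumulate_grad_batches=accumulator)
```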