diff --git a/CHANGELOG.md b/CHANGELOG.md index 50d5ef6de2294..5bc8ffcf1d40e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -62,6 +62,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added reset dataloader hooks to Training Plugins and Accelerators ([#7861](https://github.com/PyTorchLightning/pytorch-lightning/pull/7861)) +- Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864)) + + ### Changed - Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 2938feee8339d..d9dacd92dc4d7 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -179,10 +179,6 @@ def batch_to_device( return move_data_to_device(batch, device) - def on_train_start(self) -> None: - """Hook to do something upon the training start""" - pass - def training_step( self, step_kwargs: Dict[str, Union[Any, int]], @@ -348,14 +344,6 @@ def clip_gradients( model=self.model, ) - def on_train_epoch_end(self) -> None: - """Hook to do something on the end of an training epoch.""" - pass - - def on_train_end(self) -> None: - """Hook to do something at the end of the training""" - pass - def setup_optimizers(self, trainer: 'pl.Trainer') -> None: """ Creates optimizers and schedulers @@ -563,3 +551,45 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int: return self.training_type_plugin.update_global_step(total_batch_idx, current_global_step) + + def on_train_epoch_end(self) -> None: + """Hook to do something on the end of an training epoch.""" + pass + + def on_train_start(self) -> None: + """Called when train begins.""" + return self.training_type_plugin.on_train_start() + + def on_validation_start(self) -> None: + """Called when validation begins.""" + return self.training_type_plugin.on_validation_start() + + def on_test_start(self) -> None: + """Called when test begins.""" + return self.training_type_plugin.on_test_start() + + def on_predict_start(self) -> None: + """Called when predict begins.""" + return self.training_type_plugin.on_predict_start() + + def on_validation_end(self) -> None: + """Called when validation ends.""" + return self.training_type_plugin.on_validation_end() + + def on_test_end(self) -> None: + """Called when test end.""" + return self.training_type_plugin.on_test_end() + + def on_predict_end(self) -> None: + """Called when predict ends.""" + return self.training_type_plugin.on_predict_end() + + def on_train_end(self) -> None: + """Called when train ends.""" + return self.training_type_plugin.on_train_end() + + def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + """ + Called in the training loop before anything happens for that batch. + """ + return self.training_type_plugin.on_train_batch_start(batch, batch_idx, dataloader_idx) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index f965a56f23219..b57a417e3d774 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -330,3 +330,41 @@ def register_plugins(cls, plugin_registry): def should_rank_save_checkpoint(self) -> bool: """Returns whether the checkpoint should be saved (rank based)""" return self.is_global_zero + + def on_train_start(self) -> None: + """Called when train begins.""" + pass + + def on_validation_start(self) -> None: + """Called when validation begins.""" + pass + + def on_test_start(self) -> None: + """Called when test begins.""" + pass + + def on_predict_start(self) -> None: + """Called when predict begins.""" + pass + + def on_train_end(self) -> None: + """Called when train ends.""" + pass + + def on_validation_end(self) -> None: + """Called when validation ends.""" + pass + + def on_test_end(self) -> None: + """Called when test end.""" + pass + + def on_predict_end(self): + """Called when predict ends.""" + pass + + def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + """ + Called in the training loop before anything happens for that batch. + """ + pass