Skip to content

Commit

Permalink
[IPU] Add hooks for IPU lifecycle 4/5 (#7864)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Naren authored Jun 7, 2021
1 parent ea71cf4 commit 41be61c
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 12 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added reset dataloader hooks to Training Plugins and Accelerators ([#7861](https://github.com/PyTorchLightning/pytorch-lightning/pull/7861))


- Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864))


### Changed

- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563)
Expand Down
54 changes: 42 additions & 12 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,6 @@ def batch_to_device(

return move_data_to_device(batch, device)

def on_train_start(self) -> None:
"""Hook to do something upon the training start"""
pass

def training_step(
self,
step_kwargs: Dict[str, Union[Any, int]],
Expand Down Expand Up @@ -348,14 +344,6 @@ def clip_gradients(
model=self.model,
)

def on_train_epoch_end(self) -> None:
"""Hook to do something on the end of an training epoch."""
pass

def on_train_end(self) -> None:
"""Hook to do something at the end of the training"""
pass

def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
"""
Creates optimizers and schedulers
Expand Down Expand Up @@ -563,3 +551,45 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:

def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
return self.training_type_plugin.update_global_step(total_batch_idx, current_global_step)

def on_train_epoch_end(self) -> None:
"""Hook to do something on the end of an training epoch."""
pass

def on_train_start(self) -> None:
"""Called when train begins."""
return self.training_type_plugin.on_train_start()

def on_validation_start(self) -> None:
"""Called when validation begins."""
return self.training_type_plugin.on_validation_start()

def on_test_start(self) -> None:
"""Called when test begins."""
return self.training_type_plugin.on_test_start()

def on_predict_start(self) -> None:
"""Called when predict begins."""
return self.training_type_plugin.on_predict_start()

def on_validation_end(self) -> None:
"""Called when validation ends."""
return self.training_type_plugin.on_validation_end()

def on_test_end(self) -> None:
"""Called when test end."""
return self.training_type_plugin.on_test_end()

def on_predict_end(self) -> None:
"""Called when predict ends."""
return self.training_type_plugin.on_predict_end()

def on_train_end(self) -> None:
"""Called when train ends."""
return self.training_type_plugin.on_train_end()

def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
"""
Called in the training loop before anything happens for that batch.
"""
return self.training_type_plugin.on_train_batch_start(batch, batch_idx, dataloader_idx)
38 changes: 38 additions & 0 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,41 @@ def register_plugins(cls, plugin_registry):
def should_rank_save_checkpoint(self) -> bool:
"""Returns whether the checkpoint should be saved (rank based)"""
return self.is_global_zero

def on_train_start(self) -> None:
"""Called when train begins."""
pass

def on_validation_start(self) -> None:
"""Called when validation begins."""
pass

def on_test_start(self) -> None:
"""Called when test begins."""
pass

def on_predict_start(self) -> None:
"""Called when predict begins."""
pass

def on_train_end(self) -> None:
"""Called when train ends."""
pass

def on_validation_end(self) -> None:
"""Called when validation ends."""
pass

def on_test_end(self) -> None:
"""Called when test end."""
pass

def on_predict_end(self):
"""Called when predict ends."""
pass

def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
"""
Called in the training loop before anything happens for that batch.
"""
pass

0 comments on commit 41be61c

Please sign in to comment.