Skip to content

Commit

Permalink
Additional hooks (#598)
Browse files Browse the repository at this point in the history
* Renamed `on_sanity_check_start` to `on_train_start` and added `on_train_end` to `ModelHooks`

* changed tests to use `on_train_start` instead of `on_sanity_check_start`
  • Loading branch information
schwobr authored and williamFalcon committed Dec 7, 2019
1 parent 1051c18 commit 2f01c03
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 2 deletions.
16 changes: 16 additions & 0 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,26 @@ class ModelHooks(torch.nn.Module):
def on_sanity_check_start(self):
"""
Called before starting evaluate
.. warning:: will be deprecated.
:return:
"""
pass

def on_train_start(self):
"""Called at the beginning of training before sanity check
:return:
"""
# do something at the start of training
pass

def on_train_end(self):
"""
Called at the end of training before logger experiment is closed
:return:
"""
# do something at the end of training
pass

def on_batch_start(self, batch):
"""Called in the training loop before anything happens for that batch.
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ def run_pretrain_routine(self, model):
# run tiny validation (if validation defined)
# to make sure program won't crash during val
ref_model.on_sanity_check_start()
ref_model.on_train_start()
if self.get_val_dataloaders() is not None and self.num_sanity_val_steps > 0:
# init progress bars for validation sanity check
pbar = tqdm.tqdm(desc='Validation sanity check',
Expand Down
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,8 @@ def train(self):

self.main_progress_bar.close()

model.on_train_end()

if self.logger is not None:
self.logger.finalize("success")

Expand Down
4 changes: 2 additions & 2 deletions tests/test_restore_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def assert_good_acc():

# new model
model = LightningTestModel(hparams)
model.on_sanity_check_start = assert_good_acc
model.on_train_start = assert_good_acc

# fit new model which should load hpc weights
new_trainer.fit(model)
Expand Down Expand Up @@ -311,7 +311,7 @@ def assert_good_acc():
for dataloader in trainer.get_val_dataloaders():
tutils.run_prediction(dataloader, trainer.model)

model.on_sanity_check_start = assert_good_acc
model.on_train_start = assert_good_acc

# by calling fit again, we trigger training, loading weights from the cluster
# and our hook to predict using current model before any more weight updates
Expand Down

0 comments on commit 2f01c03

Please sign in to comment.