Skip to content

Commit

Permalink
fix setup and on fit calls (#2252)
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon authored Jun 19, 2020
1 parent b7fc092 commit b5a2f1e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/model_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

class TrainerModelHooksMixin(ABC):

def is_function_implemented(self, f_name):
model = self.get_model()
def is_function_implemented(self, f_name, model=None):
if model is None:
model = self.get_model()
f_op = getattr(model, f_name, None)
return callable(f_op)

Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,7 @@ def fit(

# callbacks
self.on_fit_start()
if self.is_function_implemented('on_fit_start'):
if self.is_function_implemented('on_fit_start', model):
model.on_fit_start()

# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
Expand All @@ -860,7 +860,7 @@ def fit(
self.barrier('fit_prepare_data')

self.setup('fit')
if self.is_function_implemented('setup'):
if self.is_function_implemented('setup', model):
model.setup('fit')

# Run auto batch size scaling
Expand Down Expand Up @@ -1149,8 +1149,8 @@ def test(
trainer.test(model, test_dataloaders=test)
"""
self.setup('test')
if self.is_function_implemented('setup'):
model_ref = self.model if model is None else model
model_ref = self.model if model is None else model
if self.is_function_implemented('setup', model_ref):
model_ref.setup('test')

self.barrier('test_setup')
Expand Down

0 comments on commit b5a2f1e

Please sign in to comment.