Re-design call_hook interface #10575

Merged 43 commits on Dec 4, 2021
Changes from 3 commits
Commits (43)
dc8e838
first draft
daniellepintz Nov 16, 2021
3667a4d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 16, 2021
36078f2
doc fix
daniellepintz Nov 16, 2021
09949ee
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Nov 22, 2021
eabeba4
separate call_hooks
daniellepintz Nov 22, 2021
ec445a9
update call_hook refs
daniellepintz Nov 22, 2021
0c59dd8
fix more refs
daniellepintz Nov 22, 2021
6513caa
fix log
daniellepintz Nov 22, 2021
28701a0
cover edge case hooks
daniellepintz Nov 24, 2021
d3bdb46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2021
319ec5f
small fix
daniellepintz Nov 24, 2021
6ba14e4
Merge branch 'call_hook' of github.com:daniellepintz/pytorch-lightnin…
daniellepintz Nov 24, 2021
fda4d49
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Nov 24, 2021
088c441
only profile hook_name
daniellepintz Nov 24, 2021
800a11d
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Nov 24, 2021
460028a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2021
282a86b
flake8
daniellepintz Nov 24, 2021
14163ea
Merge branch 'call_hook' of github.com:daniellepintz/pytorch-lightnin…
daniellepintz Nov 24, 2021
fe2a76f
Fix failing tests. A ttp call was missed
carmocca Nov 25, 2021
d64d524
address comments
daniellepintz Nov 25, 2021
2617675
Merge branch 'call_hook' of github.com:daniellepintz/pytorch-lightnin…
daniellepintz Nov 25, 2021
a5fe96d
fix
daniellepintz Nov 25, 2021
63323bf
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Nov 26, 2021
a3d14ed
fix mypy
daniellepintz Nov 26, 2021
6eeec86
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Nov 26, 2021
53d6b53
remove setting of _current_fx_name
daniellepintz Nov 29, 2021
e5a3e3a
fix flake8
daniellepintz Nov 29, 2021
d4f66ae
fix
daniellepintz Nov 29, 2021
41dfc44
add asserts and optimizations
daniellepintz Dec 2, 2021
848a511
addr comments
daniellepintz Dec 2, 2021
d0b59c1
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Dec 3, 2021
ba89d1c
fix hook not callable
daniellepintz Dec 3, 2021
13e7b5d
fix ttp trainer ref
daniellepintz Dec 3, 2021
890e2e7
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Dec 3, 2021
9261c99
fix bad merge
daniellepintz Dec 3, 2021
3c5fe0c
fix on_train_batch_start
daniellepintz Dec 3, 2021
813c24a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 3, 2021
c48cdb2
fix
daniellepintz Dec 3, 2021
7b4f005
fix
daniellepintz Dec 3, 2021
7146f26
fix broken test
daniellepintz Dec 3, 2021
a163136
fix test
daniellepintz Dec 4, 2021
7e8ed03
fix _call_callback_hooks
daniellepintz Dec 4, 2021
492fc62
TypeError and other fix
daniellepintz Dec 4, 2021
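This PR replaces the single trainer.call_hook(name, ...) entry point with an owner-explicit trainer._call_hook(owner, name, ...) call: each call site now states whether it is invoking the Callback hooks (owner is the Trainer) or the LightningModule hook (owner is the module), and issues one call per owner. A rough before/after sketch of a typical call site, using only the names visible in the diff below (anything beyond those call sites, such as the body of _call_hook, is an assumption):

# Before this PR: one call, fan-out to every owner handled internally.
self.trainer.call_hook("on_train_epoch_start")

# After this PR: the caller names each owner explicitly, one call per owner.
self.trainer._call_hook(self.trainer, "on_train_epoch_start")                   # Callback hooks
self.trainer._call_hook(self.trainer.lightning_module, "on_train_epoch_start")  # LightningModule hook
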
4 changes: 3 additions & 1 deletion pl_examples/loop_examples/yielding_training_step.py
@@ -88,7 +88,9 @@ def _training_step(self, generator):
training_step_output = next(generator)
self.trainer.accelerator.post_training_step()

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
training_step_output = self.trainer._call_hook(
self.trainer.lightning_module, "training_step_end", training_step_output
)

# The closure result takes care of properly detaching the loss for logging and performs
# some additional checks that the output format is correct.
30 changes: 19 additions & 11 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -174,16 +174,18 @@ def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
self._results.to(device=self.trainer.lightning_module.device)

if self.trainer.testing:
self.trainer.call_hook("on_test_start", *args, **kwargs)
self.trainer._call_hook(self.trainer, "on_test_start", *args, **kwargs)
self.trainer._call_hook(self.trainer.lightning_module, "on_test_start", *args, **kwargs)
else:
self.trainer.call_hook("on_validation_start", *args, **kwargs)
self.trainer._call_hook(self.trainer, "on_validation_start", *args, **kwargs)
self.trainer._call_hook(self.trainer.lightning_module, "on_validation_start", *args, **kwargs)

def _on_evaluation_model_eval(self) -> None:
"""Sets model to eval mode."""
if self.trainer.testing:
self.trainer.call_hook("on_test_model_eval")
self.trainer._call_hook(self.trainer.lightning_module, "on_test_model_eval")
else:
self.trainer.call_hook("on_validation_model_eval")
self.trainer._call_hook(self.trainer.lightning_module, "on_validation_model_eval")

def _on_evaluation_model_train(self) -> None:
"""Sets model to train mode."""
@@ -196,22 +198,26 @@ def _on_evaluation_model_train(self) -> None:
def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_{validation/test}_end`` hook."""
if self.trainer.testing:
self.trainer.call_hook("on_test_end", *args, **kwargs)
self.trainer._call_hook(self.trainer, "on_test_end", *args, **kwargs)
self.trainer._call_hook(self.trainer.lightning_module, "on_test_end", *args, **kwargs)
else:
self.trainer.call_hook("on_validation_end", *args, **kwargs)
self.trainer._call_hook(self.trainer.lightning_module, "on_validation_end", *args, **kwargs)

# reset the logger connector state
self.trainer.logger_connector.reset_results()

def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
self.trainer.logger_connector.on_epoch_start()
self.trainer.call_hook("on_epoch_start", *args, **kwargs)
self.trainer._call_hook(self.trainer, "on_epoch_start", *args, **kwargs)
self.trainer._call_hook(self.trainer.lightning_module, "on_epoch_start", *args, **kwargs)

if self.trainer.testing:
self.trainer.call_hook("on_test_epoch_start", *args, **kwargs)
self.trainer._call_hook(self.trainer, "on_test_epoch_start", *args, **kwargs)
self.trainer._call_hook(self.trainer.lightning_module, "on_test_epoch_start", *args, **kwargs)
else:
self.trainer.call_hook("on_validation_epoch_start", *args, **kwargs)
self.trainer._call_hook(self.trainer, "on_validation_epoch_start", *args, **kwargs)
self.trainer._call_hook(self.trainer.lightning_module, "on_validation_epoch_start", *args, **kwargs)

def _evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
"""Runs ``{validation/test}_epoch_end``"""
@@ -237,6 +243,8 @@ def _evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
def _on_evaluation_epoch_end(self) -> None:
"""Runs ``on_{validation/test}_epoch_end`` hook."""
hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
self.trainer.call_hook(hook_name)
self.trainer.call_hook("on_epoch_end")
self.trainer._call_hook(self.trainer, hook_name)
self.trainer._call_hook(self.trainer.lightning_module, hook_name)
self.trainer._call_hook(self.trainer, "on_epoch_end")
self.trainer._call_hook(self.trainer.lightning_module, "on_epoch_end")
self.trainer.logger_connector.on_epoch_end()
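The three commits in this view only change the call sites; the body of _call_hook itself is not shown here. As a minimal, self-contained sketch of the idea behind an owner-explicit dispatcher (the real method presumably also integrates the profiler, tracks the currently running hook name, and treats Callback owners specially; none of that is taken from this diff):

from typing import Any, Optional


def call_hook_sketch(owner: Any, hook_name: str, *args: Any, **kwargs: Any) -> Optional[Any]:
    # Look the hook up on the owner the caller named explicitly,
    # skip it when the owner does not define it, and return its output.
    fn = getattr(owner, hook_name, None)
    if not callable(fn):
        return None
    return fn(*args, **kwargs)


# Toy owners mirroring the pattern in the diff: one invocation per owner.
class ToyCallbackSide:  # stands in for the Trainer and its callbacks
    def on_validation_start(self) -> None:
        print("callback side")


class ToyModuleSide:  # stands in for the LightningModule
    def on_validation_start(self) -> None:
        print("module side")


call_hook_sketch(ToyCallbackSide(), "on_validation_start")
call_hook_sketch(ToyModuleSide(), "on_validation_start")
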
12 changes: 8 additions & 4 deletions pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -110,8 +110,10 @@ def _on_predict_start(self) -> None:
self.trainer.lightning_module.zero_grad()

# hook
self.trainer.call_hook("on_predict_start")
self.trainer.call_hook("on_predict_epoch_start")
self.trainer._call_hook(self.trainer, "on_predict_start")
self.trainer._call_hook(self.trainer.lightning_module, "on_predict_start")
self.trainer._call_hook(self.trainer, "on_predict_epoch_start")
self.trainer._call_hook(self.trainer.lightning_module, "on_predict_epoch_start")

def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
"""Calls ``on_predict_epoch_end`` hook.
@@ -121,7 +123,8 @@ def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
"""
results = self.predictions

self.trainer.call_hook("on_predict_epoch_end", results)
self.trainer._call_hook(self.trainer, "on_predict_epoch_end", results)
self.trainer._call_hook(self.trainer.lightning_module, "on_predict_epoch_end", results)

if self.return_predictions:
return results[0] if self.num_dataloaders == 1 else results
@@ -133,7 +136,8 @@ def _on_predict_end(self) -> None:
self.epoch_batch_indices = []

# hook
self.trainer.call_hook("on_predict_end")
self.trainer._call_hook(self.trainer, "on_predict_end")
self.trainer._call_hook(self.trainer.lightning_module, "on_predict_end")

def _on_predict_model_eval(self):
"""Calls ``on_predict_model_eval`` hook."""
15 changes: 11 additions & 4 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -219,7 +219,7 @@ def _evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> O
def _evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
"""Calls the `{validation/test}_step_end` hook."""
hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
output = self.trainer.call_hook(hook_name, *args, **kwargs)
output = self.trainer._call_hook(self.trainer.lightning_module, hook_name, *args, **kwargs)
return output

def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
@@ -239,9 +239,15 @@ def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
self.trainer.logger_connector.on_evaluation_batch_start(dataloader_idx, self._num_dataloaders)

if self.trainer.testing:
self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx)
self.trainer._call_hook(self.trainer, "on_test_batch_start", batch, batch_idx, dataloader_idx)
self.trainer._call_hook(
self.trainer.lightning_module, "on_test_batch_start", batch, batch_idx, dataloader_idx
)
else:
self.trainer.call_hook("on_validation_batch_start", batch, batch_idx, dataloader_idx)
self.trainer._call_hook(self.trainer, "on_validation_batch_start", batch, batch_idx, dataloader_idx)
self.trainer._call_hook(
self.trainer.lightning_module, "on_validation_batch_start", batch, batch_idx, dataloader_idx
)

def _on_evaluation_batch_end(
self, output: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
@@ -255,7 +261,8 @@ def _on_evaluation_batch_end(
dataloader_idx: Index of the dataloader producing the current batch
"""
hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
self.trainer.call_hook(hook_name, output, batch, batch_idx, dataloader_idx)
self.trainer._call_hook(self.trainer, hook_name, output, batch, batch_idx, dataloader_idx)
self.trainer._call_hook(self.trainer.lightning_module, hook_name, output, batch, batch_idx, dataloader_idx)

self.trainer.logger_connector.on_batch_end()

10 changes: 8 additions & 2 deletions pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -125,7 +125,10 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None

model_ref = self.trainer.lightning_module

self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)
self.trainer._call_hook(self.trainer, "on_predict_batch_start", batch, batch_idx, dataloader_idx)
self.trainer._call_hook(
self.trainer.lightning_module, "on_predict_batch_start", batch, batch_idx, dataloader_idx
)

self.batch_progress.increment_started()

@@ -137,7 +140,10 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None
if predictions is None:
self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")

self.trainer.call_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)
self.trainer._call_hook(self.trainer, "on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)
self.trainer._call_hook(
self.trainer.lightning_module, "on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx
)

self.batch_progress.increment_completed()

23 changes: 15 additions & 8 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -133,8 +133,10 @@ def reset(self) -> None:
def on_run_start(self, data_fetcher: AbstractDataFetcher, **kwargs: Any) -> None:
# hook
self.trainer.logger_connector.on_epoch_start()
self.trainer.call_hook("on_epoch_start")
self.trainer.call_hook("on_train_epoch_start")
self.trainer._call_hook(self.trainer, "on_epoch_start")
self.trainer._call_hook(self.trainer.lightning_module, "on_epoch_start")
self.trainer._call_hook(self.trainer, "on_train_epoch_start")
self.trainer._call_hook(self.trainer.lightning_module, "on_train_epoch_start")
self.trainer.fit_loop.epoch_progress.increment_started()

self._reload_dataloader_state_dict(data_fetcher)
@@ -170,7 +172,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
batch_output = []
else:
# hook
response = self.trainer.call_hook("on_batch_start")
response = self.trainer._call_hook(self.trainer, "on_batch_start")
if response == -1:
self.batch_progress.increment_processed()
raise StopIteration
@@ -184,7 +186,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
)

# hook
response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
response = self.trainer._call_hook(self.trainer, "on_train_batch_start", batch, batch_idx, **extra_kwargs)
if response == -1:
self.batch_progress.increment_processed()
raise StopIteration
@@ -217,8 +219,11 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
else {}
)
self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
self.trainer.call_hook("on_batch_end")
self.trainer._call_hook(self.trainer, "on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
self.trainer._call_hook(
self.trainer.lightning_module, "on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs
)
self.trainer._call_hook(self.trainer, "on_batch_end")
self.trainer.logger_connector.on_batch_end()

self.batch_progress.increment_completed()
@@ -299,8 +304,10 @@ def on_run_end(self) -> None:
self.trainer.fit_loop.epoch_progress.increment_processed()

# call train epoch end hooks
self.trainer.call_hook("on_train_epoch_end")
self.trainer.call_hook("on_epoch_end")
self.trainer._call_hook(self.trainer, "on_train_epoch_end")
self.trainer._call_hook(self.trainer.lightning_module, "on_train_epoch_end")
self.trainer._call_hook(self.trainer, "on_epoch_end")
self.trainer._call_hook(self.trainer.lightning_module, "on_epoch_end")
self.trainer.logger_connector.on_epoch_end()

if self._num_ready_batches_reached():
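The extra_kwargs logic above forwards dataloader_idx to on_train_batch_start and on_train_batch_end only when the user's override declares it, using is_param_in_hook_signature. A rough, self-contained illustration of that kind of check with the standard inspect module (the real utility lives in Lightning's internals and also handles *args/**kwargs and other edge cases this sketch ignores):

import inspect


def accepts_param(fn, name: str) -> bool:
    # True only when `name` appears explicitly in the function's signature.
    return name in inspect.signature(fn).parameters


def on_train_batch_end(outputs, batch, batch_idx):  # override without the newer argument
    ...


def on_train_batch_end_with_idx(outputs, batch, batch_idx, dataloader_idx):  # declares dataloader_idx
    ...


print(accepts_param(on_train_batch_end, "dataloader_idx"))           # False -> extra kwarg is not passed
print(accepts_param(on_train_batch_end_with_idx, "dataloader_idx"))  # True  -> dataloader_idx is forwarded
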
7 changes: 5 additions & 2 deletions pytorch_lightning/loops/fit_loop.py
@@ -197,7 +197,9 @@ def on_run_start(self) -> None:
self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
self._is_fresh_start_epoch = True
self._results.to(device=self.trainer.lightning_module.device)
self.trainer.call_hook("on_train_start")
self.trainer._call_hook(self.trainer, "on_train_start")
self.trainer._call_hook(self.trainer.lightning_module, "on_train_start")
self.trainer._call_hook(self.trainer.accelerator, "on_train_start")

def on_advance_start(self) -> None:
"""Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and
@@ -254,7 +256,8 @@ def on_run_end(self) -> None:
self.current_epoch = max(self.current_epoch - 1, 0)

# hook
self.trainer.call_hook("on_train_end")
self.trainer._call_hook(self.trainer, "on_train_end")
self.trainer._call_hook(self.trainer.lightning_module, "on_train_end")

# give accelerators a chance to finish
self.trainer.training_type_plugin.on_train_end()
4 changes: 3 additions & 1 deletion pytorch_lightning/loops/optimization/manual_loop.py
@@ -109,7 +109,9 @@ def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]

del step_kwargs

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
training_step_output = self.trainer._call_hook(
self.trainer.lightning_module, "training_step_end", training_step_output
)

self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

7 changes: 5 additions & 2 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -392,7 +392,8 @@ def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
optimizer: the current optimizer
"""
self.optim_progress.optimizer.zero_grad.increment_ready()
self.trainer.call_hook("on_before_zero_grad", optimizer)
self.trainer._call_hook(self.trainer, "on_before_zero_grad", optimizer)
self.trainer._call_hook(self.trainer.lightning_module, "on_before_zero_grad", optimizer)
self.optim_progress.optimizer.zero_grad.increment_started()

def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
@@ -434,7 +435,9 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos

del step_kwargs

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
training_step_output = self.trainer._call_hook(
self.trainer.lightning_module, "training_step_end", training_step_output
)

self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

9 changes: 6 additions & 3 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -55,7 +55,8 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Ten
model: the model to be optimized
closure_loss: the loss value obtained from the closure
"""
model.trainer.call_hook("on_before_backward", closure_loss)
model.trainer._call_hook(model.trainer, "on_before_backward", closure_loss)
model.trainer._call_hook(model, "on_before_backward", closure_loss)
return closure_loss

def backward(
@@ -88,7 +89,8 @@ def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Te
"""
# once backward has been applied, release graph
closure_loss = closure_loss.detach()
model.trainer.call_hook("on_after_backward")
model.trainer._call_hook(model.trainer, "on_after_backward")
model.trainer._call_hook(model, "on_after_backward")
return closure_loss

def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
@@ -107,7 +109,8 @@ def _after_closure(
return
trainer = model.trainer
assert trainer is not None
trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
trainer._call_hook(trainer, "on_before_optimizer_step", optimizer, optimizer_idx)
trainer._call_hook(model, "on_before_optimizer_step", optimizer, optimizer_idx)
# TODO: this is done for the entire model but should be changed to per-optimizer
if optimizer_idx == 0:
self._track_grad_norm(trainer)
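Code that lives outside the Trainer, like the precision plugin above, reaches the trainer through the module reference and now has to name the owner for each path itself. A hypothetical subclass, written only to illustrate the call pattern from plugin code under the new interface (the subclass is not part of this PR, and it assumes _call_hook is available on the trainer exactly as the call sites above show):

import pytorch_lightning as pl
from torch import Tensor
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin


class TracingPrecisionPlugin(PrecisionPlugin):
    """Hypothetical plugin that keeps the stock pre_backward behaviour but logs the calls."""

    def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor:
        print("on_before_backward: callback path, then module path")
        # Owner is passed explicitly: the Trainer for the Callback hooks,
        # the LightningModule for the module hook.
        model.trainer._call_hook(model.trainer, "on_before_backward", closure_loss)
        model.trainer._call_hook(model, "on_before_backward", closure_loss)
        return closure_loss
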
@@ -262,7 +262,7 @@ def _attach_model_callbacks(self) -> None:
In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
will be pushed to the end of the list, ensuring they run last.
"""
model_callbacks = self.trainer.call_hook("configure_callbacks")
model_callbacks = self.trainer._call_hook(self.trainer.lightning_module, "configure_callbacks")
if not model_callbacks:
return
model_callback_types = {type(c) for c in model_callbacks}
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -165,7 +165,7 @@ def prepare_data(self) -> None:
" Move `prepare_data_per_node` setting to LightningModule property."
)
if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
self.trainer.call_hook("prepare_data")
self.trainer._call_hook(self.trainer.lightning_module, "prepare_data")
self.trainer._is_data_prepared = True

def attach_data(
@@ -293,7 +293,7 @@ def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
return self.instance

if isinstance(self.instance, LightningModule):
return self.instance.trainer.call_hook(self.name, pl_module=self.instance)
return self.instance.trainer._call_hook(self.instance, self.name, pl_module=self.instance)

if isinstance(self.instance, LightningDataModule):
method = getattr(self.instance, self.name)
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -574,7 +574,7 @@ def request_dataloader(
source = getattr(self._data_connector, f"_{stage.dataloader_prefix}_dataloader_source")

hook = f"{stage.dataloader_prefix}_dataloader"
self.call_hook("on_" + hook, pl_module=model)
self._call_hook(model, "on_" + hook, pl_module=model)
dataloader = source.dataloader()
if isinstance(dataloader, tuple):
dataloader = list(dataloader)
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/optimizers.py
@@ -32,7 +32,7 @@ class TrainerOptimizersMixin(ABC):
def init_optimizers(self, model: Optional["pl.LightningModule"]) -> Tuple[List, List, List]:
pl_module = self.lightning_module or model
self._lightning_optimizers = None
optim_conf = self.call_hook("configure_optimizers", pl_module=pl_module)
optim_conf = self._call_hook(model, "configure_optimizers", pl_module=pl_module)
if optim_conf is None:
rank_zero_warn(
"`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",