Skip to content

Commit

Permalink
Allow training type plugin to delay optimizer creation (FSDP 2/n) (#6331
Browse files Browse the repository at this point in the history
)

* Allow training_type_plugin to delay optimizer configure

* Add missing references to trainer, add a CPU accelerator based test
  • Loading branch information
SeanNaren authored Mar 22, 2021
1 parent 853523e commit 58c9fa7
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 6 deletions.
9 changes: 6 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
model: the LightningModule
"""
self.setup_training_type_plugin(self.training_type_plugin, model)
self.setup_optimizers(trainer)
if not self.training_type_plugin.setup_optimizers_in_pre_dispatch:
self.setup_optimizers(trainer)
self.setup_precision_plugin(self.precision_plugin)

def start_training(self, trainer: 'Trainer') -> None:
Expand All @@ -97,12 +98,14 @@ def start_evaluating(self, trainer: 'Trainer') -> None:
def start_predicting(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_predicting(trainer)

def pre_dispatch(self) -> None:
def pre_dispatch(self, trainer: 'Trainer') -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self.training_type_plugin.pre_dispatch()
# Plugins that report ``setup_optimizers_in_pre_dispatch`` skip optimizer creation in ``setup``,
# so the optimizers are created here, once the plugin's own ``pre_dispatch`` has run.
if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
self.setup_optimizers(trainer)
self.precision_plugin.pre_dispatch()

def post_dispatch(self) -> None:
def post_dispatch(self, trainer: 'Trainer') -> None:
"""Hook to do something after the training/evaluation/prediction finishes."""
self.training_type_plugin.post_dispatch()
self.precision_plugin.post_dispatch()
Expand Down
10 changes: 10 additions & 0 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,13 @@ def init_optimizers(self, trainer: "Trainer", model: LightningModule):

def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
optimizer.step(closure=lambda_closure, **kwargs)

@property
def setup_optimizers_in_pre_dispatch(self) -> bool:
"""
Override to delay setting up optimizers and schedulers until ``pre_dispatch`` instead of ``setup``.
This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
However this may break certain precision plugins such as APEX which require optimizers to be set.
Returns: If True, delay setting up optimizers till pre_dispatch, else call within setup.
"""
return False
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def fit(
return self.accelerator.results or 1

def pre_dispatch(self):
self.accelerator.pre_dispatch()
self.accelerator.pre_dispatch(self)

# log hyper-parameters
if self.logger is not None:
Expand All @@ -505,7 +505,7 @@ def pre_dispatch(self):
self.logger.save()

def post_dispatch(self):
# Run the accelerator's post-dispatch hook first, then tear the accelerator down.
self.accelerator.post_dispatch()
self.accelerator.post_dispatch(self)
self.accelerator.teardown()

def dispatch(self):
Expand Down
35 changes: 34 additions & 1 deletion tests/accelerators/test_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import SingleDevicePlugin
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel


def test_unsupported_precision_plugins():
Expand All @@ -18,3 +19,35 @@ def test_unsupported_precision_plugins():
)
with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
accelerator.setup(trainer=trainer, model=model)


@pytest.mark.parametrize("delay_dispatch", [True, False])
def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch):
"""
Test when using a custom training type plugin that delays setup optimizers,
we do not call setup optimizers till ``pre_dispatch``.

``delay_dispatch`` is the value the custom plugin returns from its
``setup_optimizers_in_pre_dispatch`` property; both settings are exercised.
"""

class TestModel(BoringModel):
def on_fit_start(self):
if delay_dispatch:
# Ensure we haven't setup optimizers if we've delayed dispatch
assert len(self.trainer.optimizers) == 0
else:
# Without the delay, the optimizers are expected to exist by the time this hook runs
assert len(self.trainer.optimizers) > 0

def on_fit_end(self):
# In both cases the optimizers must have been created by the end of fit
assert len(self.trainer.optimizers) > 0

class CustomPlugin(SingleDevicePlugin):
# The delay flag is driven directly by the test parameter
@property
def setup_optimizers_in_pre_dispatch(self) -> bool:
return delay_dispatch

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
plugins=CustomPlugin(device=torch.device("cpu"))
)
trainer.fit(model)

0 comments on commit 58c9fa7

Please sign in to comment.