4/n Move Accelerator into strategy - remove X_step() from accelerator #10890

Merged · 11 commits · Dec 6, 2021
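This PR moves step dispatch out of `Accelerator`: `training_step`, `validation_step`, `test_step` and `predict_step` (together with the precision-plugin step contexts) now live on the `TrainingTypePlugin` subclasses, and callers go through `trainer.training_type_plugin` instead of `trainer.accelerator`. A minimal sketch of the call-site migration, with `step_kwargs` standing in for whatever arguments a loop builds (illustrative, not code from this PR):

```python
# Before: the Accelerator entered the precision context and delegated to the plugin.
output = trainer.accelerator.training_step(*step_kwargs.values())

# After: call the training type plugin directly; it now enters
# `precision_plugin.train_step_context()` itself before running the step.
output = trainer.training_type_plugin.training_step(*step_kwargs.values())
```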
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -224,6 +224,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed methods `pre_dispatch`, `dispatch` and `post_dispatch` from the `Accelerator` ([#10885](https://github.com/PyTorchLightning/pytorch-lightning/pull/10885))


- Removed methods `training_step`, `test_step`, `validation_step` and `predict_step` from the `Accelerator` ([#10890](https://github.com/PyTorchLightning/pytorch-lightning/pull/10890))


### Fixed

- Fixed an issue with `SignalConnector` not restoring the default signal handlers on teardown when running on SLURM or with fault-tolerant training enabled ([#10611](https://github.com/PyTorchLightning/pytorch-lightning/pull/10611))
2 changes: 1 addition & 1 deletion pl_examples/loop_examples/yielding_training_step.py
@@ -77,7 +77,7 @@ def _get_generator(self, split_batch, batch_idx, opt_idx):
# Here we are basically calling `lightning_module.training_step()`
# and this returns a generator! The `training_step` is handled by the
# accelerator to enable distributed training.
return self.trainer.accelerator.training_step(*step_kwargs.values())
return self.trainer.training_type_plugin.training_step(*step_kwargs.values())

def _training_step(self, generator):
# required for logging
33 changes: 0 additions & 33 deletions pytorch_lightning/accelerators/accelerator.py
@@ -20,7 +20,6 @@
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.utilities.types import STEP_OUTPUT


class Accelerator:
@@ -103,38 +102,6 @@ def teardown(self) -> None:
"""
self.training_type_plugin.teardown()

def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
"""The actual training step.

See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
"""
with self.training_type_plugin.precision_plugin.train_step_context():
return self.training_type_plugin.training_step(*args, **kwargs)

def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
"""The actual validation step.

See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details
"""
with self.training_type_plugin.precision_plugin.val_step_context():
return self.training_type_plugin.validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
"""The actual test step.

See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details
"""
with self.training_type_plugin.precision_plugin.test_step_context():
return self.training_type_plugin.test_step(*args, **kwargs)

def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
"""The actual predict step.

See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details
"""
with self.training_type_plugin.precision_plugin.predict_step_context():
return self.training_type_plugin.predict_step(*args, **kwargs)

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""Gets stats for a given device.

26 changes: 15 additions & 11 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -399,22 +399,26 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor

def training_step(self, *args, **kwargs) -> Optional[Any]:
return self.model(*args, **kwargs)
def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
with self.precision_plugin.train_step_context():
return self.model(*args, **kwargs)

def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
if isinstance(self.model, DistributedDataParallel):
# used when calling `trainer.fit`
return self.model(*args, **kwargs)
else:
# used when calling `trainer.validate`
return self.lightning_module.validation_step(*args, **kwargs)
with self.precision_plugin.val_step_context():
if isinstance(self.model, DistributedDataParallel):
# used when calling `trainer.fit`
return self.model(*args, **kwargs)
else:
# used when calling `trainer.validate`
return self.lightning_module.validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
return self.lightning_module.test_step(*args, **kwargs)
with self.precision_plugin.test_step_context():
return self.lightning_module.test_step(*args, **kwargs)

def predict_step(self, *args, **kwargs) -> Any:
return self.lightning_module.predict_step(*args, **kwargs)
def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
with self.precision_plugin.predict_step_context():
return self.lightning_module.predict_step(*args, **kwargs)

def post_training_step(self):
if not self.lightning_module.automatic_optimization:
26 changes: 15 additions & 11 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -335,22 +335,26 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor

def training_step(self, *args, **kwargs) -> Optional[Any]:
return self.model(*args, **kwargs)
def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
with self.precision_plugin.train_step_context():
return self.model(*args, **kwargs)

def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
if isinstance(self.model, DistributedDataParallel):
# used when calling `trainer.fit`
return self.model(*args, **kwargs)
else:
# used when calling `trainer.validate`
return self.lightning_module.validation_step(*args, **kwargs)
with self.precision_plugin.val_step_context():
if isinstance(self.model, DistributedDataParallel):
# used when calling `trainer.fit`
return self.model(*args, **kwargs)
else:
# used when calling `trainer.validate`
return self.lightning_module.validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
return self.lightning_module.test_step(*args, **kwargs)
with self.precision_plugin.test_step_context():
return self.lightning_module.test_step(*args, **kwargs)

def predict_step(self, *args, **kwargs) -> Any:
return self.lightning_module.predict_step(*args, **kwargs)
def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
with self.precision_plugin.predict_step_context():
return self.lightning_module.predict_step(*args, **kwargs)

def post_training_step(self):
if not self.lightning_module.automatic_optimization:
17 changes: 10 additions & 7 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -42,7 +42,7 @@
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple
from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache

warning_cache = WarningCache()
@@ -879,11 +879,14 @@ def checkpoint_io(self) -> CheckpointIO:
def checkpoint_io(self, plugin: CheckpointIO) -> None:
raise MisconfigurationException("DeepSpeed currently does not support custom checkpoint plugins.")

def validation_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.val_step_context():
return self.model(*args, **kwargs)

def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.test_step_context():
return self.model(*args, **kwargs)

def predict_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
with self.precision_plugin.predict_step_context():
return self.model(*args, **kwargs)
22 changes: 13 additions & 9 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -24,7 +24,7 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _METRIC_COLLECTION
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, STEP_OUTPUT


class DataParallelPlugin(ParallelPlugin):
@@ -118,17 +118,21 @@ def broadcast(self, obj: object, src: int = 0) -> object:
def reduce_boolean_decision(self, decision: bool) -> bool:
return decision

def training_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
with self.precision_plugin.train_step_context():
return self.model(*args, **kwargs)

def validation_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.val_step_context():
return self.model(*args, **kwargs)

def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.test_step_context():
return self.model(*args, **kwargs)

def predict_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
with self.precision_plugin.predict_step_context():
return self.model(*args, **kwargs)

def training_step_end(self, output):
if not is_overridden("training_step_end", self.lightning_module):
21 changes: 13 additions & 8 deletions pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -24,6 +24,7 @@
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
from fairscale.nn import default_auto_wrap_policy, enable_wrap
@@ -172,17 +173,21 @@ def model_to_device(self) -> None:
# ensure we update the device type in the lightning module
self.lightning_module.to(self.root_device)

def training_step(self, *args, **kwargs):
return self.model.training_step(*args, **kwargs)
def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
with self.precision_plugin.train_step_context():
return self.model.training_step(*args, **kwargs)

def validation_step(self, *args, **kwargs):
return self.model.validation_step(*args, **kwargs)
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.val_step_context():
return self.model.validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs):
return self.model.test_step(*args, **kwargs)
def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.test_step_context():
return self.model.test_step(*args, **kwargs)

def predict_step(self, *args, **kwargs):
return self.model.predict_step(*args, **kwargs)
def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
with self.precision_plugin.predict_step_context():
return self.model.predict_step(*args, **kwargs)

def post_training_step(self):
pass
21 changes: 13 additions & 8 deletions pytorch_lightning/plugins/training_type/ipu.py
@@ -31,6 +31,7 @@
from pytorch_lightning.utilities.data import _get_dataloader_init_kwargs
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _POPTORCH_AVAILABLE:
import poptorch
@@ -258,17 +259,21 @@ def _step(self, stage: RunningStage, *args: Any, **kwargs: Any):
self.lightning_module._running_torchscript = False
return out

def training_step(self, *args, **kwargs):
return self._step(RunningStage.TRAINING, *args, **kwargs)
def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
with self.precision_plugin.train_step_context():
return self._step(RunningStage.TRAINING, *args, **kwargs)

def validation_step(self, *args, **kwargs):
return self._step(RunningStage.VALIDATING, *args, **kwargs)
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.val_step_context():
return self._step(RunningStage.VALIDATING, *args, **kwargs)

def test_step(self, *args, **kwargs):
return self._step(RunningStage.TESTING, *args, **kwargs)
def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.test_step_context():
return self._step(RunningStage.TESTING, *args, **kwargs)

def predict_step(self, *args, **kwargs):
return self._step(RunningStage.PREDICTING, *args, **kwargs)
def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
with self.precision_plugin.predict_step_context():
return self._step(RunningStage.PREDICTING, *args, **kwargs)

def teardown(self) -> None:
# undo dataloader patching
18 changes: 9 additions & 9 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -302,17 +302,17 @@ def start_predicting(self, trainer: "pl.Trainer") -> Any:
self._clean_logger(trainer)
return super().start_predicting(trainer)

def training_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

def validation_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.val_step_context():
return self.model(*args, **kwargs)

def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.test_step_context():
return self.model(*args, **kwargs)

def predict_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
with self.precision_plugin.predict_step_context():
return self.model(*args, **kwargs)

def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
self._pod_progress_bar_force_stdout()
38 changes: 29 additions & 9 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -30,7 +30,7 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.types import _PATH
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

TBroadcast = TypeVar("TBroadcast")

@@ -313,20 +313,40 @@ def start_predicting(self, trainer: "pl.Trainer") -> Any:
# double dispatch to initiate the predicting loop
return trainer.run_stage()

def training_step(self, *args, **kwargs):
return self.model.training_step(*args, **kwargs)
def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
"""The actual training step.

See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
"""
with self.precision_plugin.train_step_context():
return self.model.training_step(*args, **kwargs)

def post_training_step(self):
pass

def validation_step(self, *args, **kwargs):
return self.model.validation_step(*args, **kwargs)
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
"""The actual validation step.

See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details
"""
with self.precision_plugin.val_step_context():
return self.model.validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs):
return self.model.test_step(*args, **kwargs)
def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
"""The actual test step.

See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details
"""
with self.precision_plugin.test_step_context():
return self.model.test_step(*args, **kwargs)

def predict_step(self, *args, **kwargs):
return self.model.predict_step(*args, **kwargs)
def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
"""The actual predict step.

See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details
"""
with self.precision_plugin.predict_step_context():
return self.model.predict_step(*args, **kwargs)

def training_step_end(self, output):
return output
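Because the precision context now lives inside `TrainingTypePlugin.training_step` and friends, a custom plugin that overrides one of these hooks should call `super()` (or re-enter the context itself) to keep AMP and other precision behavior. A minimal sketch under that assumption; `LoggingStrategy` is a hypothetical example, not part of this PR:

```python
from typing import Any

from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.utilities.types import STEP_OUTPUT


class LoggingStrategy(TrainingTypePlugin):
    """Hypothetical plugin that wraps the training step with extra logging."""

    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        # super().training_step() enters precision_plugin.train_step_context()
        # and calls self.model.training_step(...), so the override stays small.
        print("running training_step")  # illustrative side effect
        return super().training_step(*args, **kwargs)
```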
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -1538,7 +1538,7 @@ def _call_accelerator_hook(
**kwargs: Any,
) -> Optional[Any]:
self.lightning_module._current_fx_name = hook_name
fn = getattr(self.accelerator, hook_name)
fn = getattr(self.training_type_plugin, hook_name)
if not callable(fn):
return None

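The `trainer.py` change above only swaps the dispatch target of `_call_accelerator_hook` from the accelerator to the training type plugin; the lookup pattern itself is unchanged. A simplified stand-alone sketch of that pattern (the real helper also sets `lightning_module._current_fx_name`, as shown in the hunk):

```python
from typing import Any, Optional


def call_strategy_hook(trainer: Any, hook_name: str, *args: Any, **kwargs: Any) -> Optional[Any]:
    """Illustrative stand-in for Trainer._call_accelerator_hook after this PR."""
    # Resolve the hook (e.g. "training_step") on the training type plugin
    # instead of the accelerator; non-callable attributes return None.
    fn = getattr(trainer.training_type_plugin, hook_name)
    if not callable(fn):
        return None
    return fn(*args, **kwargs)
```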