From 692f77b8a756317356eda0f7a933dac0b0cd16ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 31 Jan 2021 12:08:16 +0100 Subject: [PATCH] Refactor LightningDataParallel (#5670) * module * fix model access * scalar conversion * refactor * kwargs * auto unsqueeze * refactor code duplication * clean up * docs * update dp docs * changelog * generalize test * test * rename * warning cache * isort * unsqueezing test * device * device * scalar test * device * device * include coverage of overrides * clear * add deprecation test * docs * improve coverage * increase coverage * fix merge * extend test * rename base class * mention the predict method in docs * combine iteration over collection * remove override * move * line * Apply suggestions from code review * fix running stage * f401 * fix cyclic import Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec --- CHANGELOG.md | 6 + .../accelerators/legacy/dp_accelerator.py | 10 +- pytorch_lightning/overrides/__init__.py | 2 + pytorch_lightning/overrides/base.py | 63 +++ pytorch_lightning/overrides/data_parallel.py | 382 +++--------------- pytorch_lightning/overrides/distributed.py | 77 ++++ .../plugins/legacy/ddp_plugin.py | 2 +- pytorch_lightning/utilities/warnings.py | 3 + setup.cfg | 1 - tests/deprecated_api/test_remove_1-4.py | 20 +- tests/models/test_restore.py | 2 +- tests/overrides/test_data_parallel.py | 130 +++++- 12 files changed, 349 insertions(+), 349 deletions(-) create mode 100644 pytorch_lightning/overrides/base.py create mode 100644 pytorch_lightning/overrides/distributed.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a041defcd0028..f82c6796e808e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -120,6 +120,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Moved accelerators and plugins to its `legacy` pkg ([#5645](https://github.com/PyTorchLightning/pytorch-lightning/pull/5645)) +- Deprecated `LightningDistributedDataParallel` in favor of new wrapper module `LightningDistributedModule` ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185)) + + +- Deprecated `LightningDataParallel` in favor of new wrapper module `LightningParallelModule` ([#5670](https://github.com/PyTorchLightning/pytorch-lightning/pull/5670)) + + ### Removed - Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321)) diff --git a/pytorch_lightning/accelerators/legacy/dp_accelerator.py b/pytorch_lightning/accelerators/legacy/dp_accelerator.py index cbacb82c80dc0..366a466ed0454 100644 --- a/pytorch_lightning/accelerators/legacy/dp_accelerator.py +++ b/pytorch_lightning/accelerators/legacy/dp_accelerator.py @@ -21,7 +21,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.step_result import Result from pytorch_lightning.distributed import LightningDistributed -from pytorch_lightning.overrides.data_parallel import LightningDataParallel +from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -74,7 +74,7 @@ def __init_torch_data_parallel(self, model): # set dp device torch.cuda.set_device(self.trainer.root_gpu) - model = LightningDataParallel(model, device_ids=device_ids) + model = torch.nn.DataParallel(LightningParallelModule(model), device_ids=device_ids) return model def __init_half_precision(self, model): @@ -181,8 +181,10 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list): scheduler.load_state_dict(state) def get_reference_model(self, model) -> LightningModule: - if isinstance(model, LightningDataParallel): - return model.module + if isinstance(model, torch.nn.DataParallel): + model = model.module + if isinstance(model, LightningParallelModule): + model = model.module return model @property diff --git a/pytorch_lightning/overrides/__init__.py b/pytorch_lightning/overrides/__init__.py index e69de29bb2d1d..ca97a63649389 100644 --- a/pytorch_lightning/overrides/__init__.py +++ b/pytorch_lightning/overrides/__init__.py @@ -0,0 +1,2 @@ +from pytorch_lightning.overrides.data_parallel import LightningParallelModule # noqa: F401 +from pytorch_lightning.overrides.distributed import LightningDistributedModule # noqa: F401 diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py new file mode 100644 index 0000000000000..b2ad5b7d710fe --- /dev/null +++ b/pytorch_lightning/overrides/base.py @@ -0,0 +1,63 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any + +import torch + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities.warnings import WarningCache + +warning_cache = WarningCache() + + +class _LightningModuleWrapperBase(torch.nn.Module): + + def __init__(self, pl_module: LightningModule): + """ + Wraps the user's LightningModule and redirects the forward call to the appropriate + method, either ``training_step``, ``validation_step`` or ``test_step``. + If the LightningModule is in none of the states `training`, `testing` or `validation`, + the inputs will be redirected to the + :meth:`~pytorch_lightning.core.lightning.LightningModule.predict` method. + Inheriting classes may also modify the inputs or outputs of forward. + + Args: + pl_module: the model to wrap + """ + super().__init__() + self.module = pl_module + + def forward(self, *inputs, **kwargs): + running_stage = self.module.running_stage + + if running_stage == RunningStage.TRAINING: + output = self.module.training_step(*inputs, **kwargs) + warn_if_output_is_none(output, "training_step") + elif running_stage == RunningStage.TESTING: + output = self.module.test_step(*inputs, **kwargs) + warn_if_output_is_none(output, "test_step") + elif running_stage == RunningStage.EVALUATING: + output = self.module.validation_step(*inputs, **kwargs) + warn_if_output_is_none(output, "validation_step") + else: + output = self.module.predict(*inputs, **kwargs) + + return output + + +def warn_if_output_is_none(output: Any, method_name: str) -> None: + """ Warns user about which method returned None. """ + if output is None: + warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?') diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 8bc70f03d329d..8d1710e471197 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -11,154 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import itertools -import threading +import numbers import warnings -from collections.abc import Iterable, Mapping -from itertools import chain -from typing import Any, Optional +from typing import Any import torch -from torch import Tensor -from torch.cuda._utils import _get_device_index -from torch.nn import DataParallel, Module +from torch.nn import DataParallel from torch.nn.parallel import DistributedDataParallel -from torch.nn.parallel._functions import Gather from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.core.step_result import Result -from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.utilities.warnings import WarningCache - - -def _find_tensors(obj): # pragma: no-cover - r""" - Recursively find all tensors contained in the specified object. 
- """ - if isinstance(obj, torch.Tensor): - return [obj] - if isinstance(obj, (list, tuple)): - return itertools.chain(*map(_find_tensors, obj)) - if isinstance(obj, dict): - return itertools.chain(*map(_find_tensors, obj.values())) - return [] - - -def get_a_var(obj): # pragma: no-cover - if isinstance(obj, torch.Tensor): - return obj - - if isinstance(obj, (list, tuple)): - for result in map(get_a_var, obj): - if isinstance(result, torch.Tensor): - return result - if isinstance(obj, dict): - for result in map(get_a_var, obj.items()): - if isinstance(result, torch.Tensor): - return result - return None - - -warning_cache = WarningCache() +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.overrides.distributed import LightningDistributedModule +from pytorch_lightning.utilities.apply_func import apply_to_collection class LightningDataParallel(DataParallel): - """ - Override the forward call in lightning so it goes to training and validation step respectively - """ - - def forward(self, *inputs, **kwargs): - if not self.device_ids: - return self.module(*inputs, **kwargs) - - for t in chain(self.module.parameters(), self.module.buffers()): - if t.device != self.src_device_obj: - raise RuntimeError( - f"module must have its parameters and buffers on device {self.src_device_obj} (device_ids[0])" - f" but found one of them on device: {t.device}" - ) - - inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) - - if len(self.device_ids) == 1: - - running_stage = self.module.running_stage - - if running_stage == RunningStage.TRAINING: - return self.module.training_step(*inputs[0], **kwargs[0]) - - elif running_stage == RunningStage.TESTING: - return self.module.test_step(*inputs[0], **kwargs[0]) - - elif running_stage == RunningStage.EVALUATING: - return self.module.validation_step(*inputs[0], **kwargs[0]) - - else: - return self.module.predict(*inputs[0], **kwargs[0]) - replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) - outputs = self.parallel_apply(replicas, inputs, kwargs) - - if isinstance(outputs[0], Result): - outputs = self.__gather_structured_result(outputs) - else: - outputs = self.gather(outputs) - return outputs - - def __gather_structured_result(self, outputs): - prototype_output = outputs[0] - original_class = prototype_output.__class__ - outputs = [dict(x) for x in outputs] - - # remove all the meta info - meta = outputs[0]['meta'] - for i, output in enumerate(outputs): - del output['meta'] - - outputs = self.gather(outputs) - - result = original_class() - - result.update(outputs) - result['meta'] = meta - return result - - def gather(self, outputs): - r""" - Override the gather method to support python scalars as well. - """ - - def gather_map(outputs): - elem = outputs[0] - elem_type = type(elem) - - if isinstance(elem, torch.Tensor): - return Gather.apply(self.output_device, self.dim, *outputs) - - if elem is None: - return None - - if isinstance(elem, Mapping): - if not all((len(elem) == len(d) for d in outputs)): - raise ValueError('All dicts must have the same number of keys') - return elem_type(((k, gather_map([d[k] for d in outputs])) for k in elem)) - - if isinstance(elem, Iterable) and not isinstance(elem, str): - return elem_type(map(gather_map, zip(*outputs))) - - return outputs - - # Recursive function calls like this create reference cycles. - # Setting the function to None clears the refcycle. 
- try: - res = gather_map(outputs) - finally: - gather_map = None - return res - - def parallel_apply(self, replicas, inputs, kwargs): - return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) + def __init__(self, module: LightningModule, *args, **kwargs): + warnings.warn( + "The usage of `LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4." + " From now on we recommend to directly subclass `torch.nn.parallel.DataParallel`.", + DeprecationWarning + ) + super().__init__(LightningParallelModule(module), *args, **kwargs) class LightningDistributedDataParallel(DistributedDataParallel): @@ -166,209 +41,60 @@ class LightningDistributedDataParallel(DistributedDataParallel): def __init__(self, module: LightningModule, *args, **kwargs): warnings.warn( "The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4." - " From now on we recommend to directly sublcass `torch.nn.parallel.DistributedDataParallel`.", + " From now on we recommend to directly subclass `torch.nn.parallel.DistributedDataParallel`.", DeprecationWarning ) super().__init__(LightningDistributedModule(module), *args, **kwargs) -class LightningDistributedModule(torch.nn.Module): - - def __init__(self, pl_module: LightningModule): - """ - Wraps the user's LightningModule and redirects the forward call to the appropriate - method, either ``training_step``, ``validation_step`` or ```test_step``. - This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel` as - shown in the example. - - Example: - - ddp_model = DistributedDataParallel( - module=LightningDistributedModule(lightning_module), - device_ids=[local_rank], - ... - ) - - Args: - pl_module: the model to wrap - - """ - super().__init__() - self.module = pl_module - - def forward(self, *inputs, **kwargs): - - running_stage = self.module.running_stage - - if running_stage == RunningStage.TRAINING: - output = self.module.training_step(*inputs, **kwargs) - warn_if_output_is_none(output, "training_step") - - elif running_stage == RunningStage.TESTING: - output = self.module.test_step(*inputs, **kwargs) - warn_if_output_is_none(output, "test_step") - - elif running_stage == RunningStage.EVALUATING: - output = self.module.validation_step(*inputs, **kwargs) - warn_if_output_is_none(output, "validation_step") - - else: - output = self.module.predict(*inputs, **kwargs) - - return output - - -# In manual_optimization, we need to call reducer prepare_for_backward. -# Note: Keep track of Pytorch DDP and update if there is a change -# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638 -def prepare_for_backward(model: DistributedDataParallel, output: Any): - if torch.is_grad_enabled() and model.require_backward_grad_sync: - model.require_forward_param_sync = True - # We'll return the output object verbatim since it is a freeform - # object. We need to find any tensors in this object, though, - # because we need to figure out which parameters were used during - # this forward pass, to ensure we short circuit reduction for any - # unused parameters. Only if `find_unused_parameters` is set. - if model.find_unused_parameters: - model.reducer.prepare_for_backward(list(_find_tensors(output))) - else: - model.reducer.prepare_for_backward([]) - else: - model.require_forward_param_sync = False - - -def warn_if_output_is_none(output: Any, method_name: str) -> None: - if output is None: - warning_cache.warn(f'Your {method_name} returned None. 
Did you forget to return an output?') - - -def warn_missing_output(fx_called): - if fx_called == 'training_step': - warning_cache.warn("Your training_step returned None. Make sure that was your intention!") - - -def parallel_apply( - modules: Module, - inputs: Tensor, - kwargs_tup: Optional[tuple] = None, - devices: Optional[list] = None, -): # pragma: no-cover - r"""Applies each `module` in :attr:`modules` in parallel on arguments - contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) - on each of :attr:`devices`. +class LightningParallelModule(_LightningModuleWrapperBase): + """ + Wraps the user's LightningModule and redirects the forward call to the appropriate + method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``. + This class is used in combination with :class:`~torch.nn.parallel.DataParallel` as + shown in the example. It also takes care of converting Python scalars to Tensors and + un-squeezes 0-dimensional Tensors as it is required by :class:`~torch.nn.parallel.DataParallel`. + + Example: + + dp_model = torch.nn.DataParallel( + module=LightningParallelModule(lightning_module), + device_ids=[3, 4], + ... + ) Args: - modules: modules to be parallelized - inputs: inputs to the modules - devices: CUDA devices + pl_module: the model to wrap - :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and - :attr:`devices` (if given) should all have same length. Moreover, each - element of :attr:`inputs` can either be a single object as the only argument - to a module, or a collection of positional arguments. """ - assert len(modules) == len(inputs) - if kwargs_tup is not None: - assert len(modules) == len(kwargs_tup) - else: - kwargs_tup = ({}, ) * len(modules) - if devices is not None: - assert len(modules) == len(devices) - else: - devices = [None] * len(modules) - devices = list(map(lambda x: _get_device_index(x, True), devices)) - lock = threading.Lock() - results = {} - grad_enabled = torch.is_grad_enabled() - - def _worker(i, module, input, kwargs, device=None): - torch.set_grad_enabled(grad_enabled) - if device is None: - device = get_a_var(input).get_device() - try: - with torch.cuda.device(device): - # this also avoids accidental slicing of `input` if it is a Tensor - if not isinstance(input, (list, tuple)): - input = (input, ) - - module = module.to(device) - - # --------------- - # CHANGE - if module.running_stage == RunningStage.TRAINING: - output = module.training_step(*input, **kwargs) - fx_called = 'training_step' - - elif module.running_stage == RunningStage.TESTING: - output = module.test_step(*input, **kwargs) - fx_called = 'test_step' - - elif module.running_stage == RunningStage.EVALUATING: - output = module.validation_step(*input, **kwargs) - fx_called = 'validation_step' - - else: - output = module.predict(*input, **kwargs) - fx_called = 'predict' - - if output is None: - warn_missing_output(fx_called) - - if output is not None and module._distrib_type in ('dp', 'ddp2'): - auto_squeeze_dim_zeros(output) - # --------------- - - with lock: - results[i] = output - # todo: specify the possible exception - except Exception as ex: - with lock: - results[i] = ex - - # TODO: fix hack (maybe not a hack) - # make sure each module knows what training state it's in... 
- # fixes weird bug where copies are out of sync - root_m = modules[0] - for m in modules[1:]: - m.training = root_m.training - m.testing = root_m.testing + def __init__(self, pl_module: LightningModule): + super().__init__(pl_module) - if len(modules) > 1: - threads = [ - threading.Thread(target=_worker, args=(i, module, input, kwargs, device)) - for i, (module, input, kwargs, device) in enumerate(zip(modules, inputs, kwargs_tup, devices)) - ] + def forward(self, *inputs, **kwargs): + output = super().forward(*inputs, **kwargs) - for thread in threads: - thread.start() - for thread in threads: - thread.join() - else: - _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) + def output_transform(data: Any): + data = python_scalar_to_tensor(data, self.module.device) + data = unsqueeze_scalar_tensor(data) + return data - outputs = [] - for i in range(len(inputs)): - output = results[i] - if isinstance(output, Exception): - raise output - outputs.append(output) - return outputs + output = apply_to_collection( + output, + dtype=(numbers.Number, torch.Tensor), + function=output_transform, + ) + return output -def auto_squeeze_dim_zeros(output): - """ - In DP or DDP2 we need to unsqueeze dim 0 - :param output: - :return: - """ - if isinstance(output, torch.Tensor): - output = output.unsqueeze(0) - return output +def python_scalar_to_tensor(data: Any, device: torch.device = torch.device("cpu")) -> Any: + """ Converts a Python scalar number to a torch tensor and places it on the given device. """ + if isinstance(data, numbers.Number): + data = torch.tensor([data], device=device) + return data - for k, v in output.items(): - if not isinstance(v, torch.Tensor): - continue - is_scalar = v.dim() == 0 - if is_scalar: - output[k] = output[k].unsqueeze(0) +def unsqueeze_scalar_tensor(data: Any) -> Any: + """ Un-squeezes a 0-dim tensor. """ + if isinstance(data, torch.Tensor) and data.dim() == 0: + data = data.unsqueeze(0) + return data diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py new file mode 100644 index 0000000000000..c934e422a4308 --- /dev/null +++ b/pytorch_lightning/overrides/distributed.py @@ -0,0 +1,77 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +from typing import Any + +import torch +from torch.nn.parallel import DistributedDataParallel + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase + + +class LightningDistributedModule(_LightningModuleWrapperBase): + + def __init__(self, pl_module: LightningModule): + """ + Wraps the user's LightningModule and redirects the forward call to the appropriate + method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``. + This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel` as + shown in the example. 
+ + Example: + + ddp_model = torch.nn.parallel.DistributedDataParallel( + module=LightningDistributedModule(lightning_module), + device_ids=[local_rank], + ... + ) + + Args: + pl_module: the model to wrap + + """ + super().__init__(pl_module) + + +def _find_tensors(obj): # pragma: no-cover + r""" + Recursively find all tensors contained in the specified object. + """ + if isinstance(obj, torch.Tensor): + return [obj] + if isinstance(obj, (list, tuple)): + return itertools.chain(*map(_find_tensors, obj)) + if isinstance(obj, dict): + return itertools.chain(*map(_find_tensors, obj.values())) + return [] + + +# In manual_optimization, we need to call reducer prepare_for_backward. +# Note: Keep track of Pytorch DDP and update if there is a change +# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638 +def prepare_for_backward(model: DistributedDataParallel, output: Any): + if torch.is_grad_enabled() and model.require_backward_grad_sync: + model.require_forward_param_sync = True + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if model.find_unused_parameters: + model.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + model.reducer.prepare_for_backward([]) + else: + model.require_forward_param_sync = False diff --git a/pytorch_lightning/plugins/legacy/ddp_plugin.py b/pytorch_lightning/plugins/legacy/ddp_plugin.py index 8da0c34dedfdf..4d7303dd7035f 100644 --- a/pytorch_lightning/plugins/legacy/ddp_plugin.py +++ b/pytorch_lightning/plugins/legacy/ddp_plugin.py @@ -21,7 +21,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.overrides.data_parallel import LightningDistributedModule, prepare_for_backward +from pytorch_lightning.overrides.distributed import LightningDistributedModule, prepare_for_backward from pytorch_lightning.plugins.legacy.plugin import LightningPlugin from pytorch_lightning.utilities import DeviceType diff --git a/pytorch_lightning/utilities/warnings.py b/pytorch_lightning/utilities/warnings.py index a5d5be95ad76f..6727025245a1b 100644 --- a/pytorch_lightning/utilities/warnings.py +++ b/pytorch_lightning/utilities/warnings.py @@ -23,3 +23,6 @@ def warn(self, m): if m not in self.warnings: self.warnings.add(m) rank_zero_warn(m) + + def clear(self): + self.warnings.clear() diff --git a/setup.cfg b/setup.cfg index dc351647a9986..31046d3d8b30c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,7 +49,6 @@ omit = pytorch_lightning/accelerators/dp_*.py pytorch_lightning/accelerators/tpu_*.py pytorch_lightning/cluster_environments/*.py - pytorch_lightning/overrides/data_parallel.py pytorch_lightning/utilities/xla_device_utils.py pytorch_lightning/utilities/distributed.py pytorch_lightning/tuner/auto_gpu_select.py diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index a4120d77676b5..27af0003beb43 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -18,7 +18,12 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel +from pytorch_lightning.overrides.data_parallel import ( + LightningDataParallel, + 
LightningDistributedDataParallel, + LightningParallelModule, +) +from pytorch_lightning.overrides.distributed import LightningDistributedModule from pytorch_lightning.plugins.legacy.ddp_plugin import DDPPlugin from tests.base import BoringModel from tests.deprecated_api import _soft_unimport_module @@ -165,6 +170,8 @@ def configure_ddp(self, model, device_ids): device_ids=device_ids, **self._ddp_kwargs, ) + assert isinstance(model, torch.nn.parallel.DistributedDataParallel) + assert isinstance(model.module, LightningDistributedModule) return model @@ -180,3 +187,14 @@ def test_v1_4_0_deprecated_lightning_distributed_data_parallel(tmpdir): plugins=[CustomDDPPlugin()] ) trainer.fit(model) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_v1_4_0_deprecated_lightning_data_parallel(): + model = BoringModel() + with pytest.deprecated_call( + match="`LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4." + ): + dp_model = LightningDataParallel(model, device_ids=[0]) + assert isinstance(dp_model, torch.nn.DataParallel) + assert isinstance(dp_model.module, LightningParallelModule) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index f34c0e196bf85..44eb0f679f13c 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -393,7 +393,7 @@ def assert_good_acc(): # haven't trained with the new loaded model dp_model = new_trainer.model dp_model.eval() - dp_model.module.running_stage = RunningStage.EVALUATING + dp_model.module.module.running_stage = RunningStage.EVALUATING dataloader = trainer.train_dataloader tpipes.run_prediction(dp_model, dataloader, dp=True) diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index e61b81fd8488e..8a98d51bd58cc 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -2,36 +2,57 @@ import pytest import torch +from torch.nn import DataParallel -from pytorch_lightning.overrides.data_parallel import LightningDistributedModule +from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.overrides.base import warning_cache +from pytorch_lightning.overrides.data_parallel import ( + LightningParallelModule, + python_scalar_to_tensor, + unsqueeze_scalar_tensor, +) from pytorch_lightning.trainer.states import RunningStage +from tests.base import BoringModel -def test_lightning_distributed_module_methods(): - """ Test that the LightningDistributedModule redirects .forward() to the LightningModule methods. """ +@pytest.mark.parametrize("wrapper_class", [ + LightningParallelModule, + LightningDistributedModule, +]) +def test_lightning_wrapper_module_methods(wrapper_class): + """ Test that the LightningWrapper redirects .forward() to the LightningModule methods. 
""" pl_module = MagicMock() - dist_module = LightningDistributedModule(pl_module) + wrapped_module = wrapper_class(pl_module) batch = torch.rand(5) batch_idx = 3 pl_module.running_stage = RunningStage.TRAINING - dist_module(batch, batch_idx) + wrapped_module(batch, batch_idx) pl_module.training_step.assert_called_with(batch, batch_idx) pl_module.running_stage = RunningStage.TESTING - dist_module(batch, batch_idx) + wrapped_module(batch, batch_idx) pl_module.test_step.assert_called_with(batch, batch_idx) pl_module.running_stage = RunningStage.EVALUATING - dist_module(batch, batch_idx) + wrapped_module(batch, batch_idx) pl_module.validation_step.assert_called_with(batch, batch_idx) + pl_module.running_stage = None + wrapped_module(batch) + pl_module.predict.assert_called_with(batch) -def test_lightning_distributed_module_warn_none_output(): - """ Test that the LightningDistributedModule warns about forgotten return statement. """ + +@pytest.mark.parametrize("wrapper_class", [ + LightningParallelModule, + LightningDistributedModule, +]) +def test_lightning_wrapper_module_warn_none_output(wrapper_class): + """ Test that the LightningWrapper module warns about forgotten return statement. """ + warning_cache.clear() pl_module = MagicMock() - dist_module = LightningDistributedModule(pl_module) + wrapped_module = wrapper_class(pl_module) pl_module.training_step.return_value = None pl_module.validation_step.return_value = None @@ -39,12 +60,95 @@ def test_lightning_distributed_module_warn_none_output(): with pytest.warns(UserWarning, match="Your training_step returned None"): pl_module.running_stage = RunningStage.TRAINING - dist_module() + wrapped_module() with pytest.warns(UserWarning, match="Your test_step returned None"): pl_module.running_stage = RunningStage.TESTING - dist_module() + wrapped_module() with pytest.warns(UserWarning, match="Your validation_step returned None"): pl_module.running_stage = RunningStage.EVALUATING - dist_module() + wrapped_module() + + with pytest.warns(None) as record: + pl_module.running_stage = None + wrapped_module() + assert not record + + +@pytest.mark.parametrize("inp,expected", [ + [torch.tensor(1.0), torch.tensor([1.0])], + [torch.tensor([2.0]), torch.tensor([2.0])], + [torch.ones(3, 4, 5), torch.ones(3, 4, 5)], +]) +def test_unsqueeze_scalar_tensor(inp, expected): + """ Test that the utility function unsqueezes only scalar tensors. """ + assert torch.all(unsqueeze_scalar_tensor(inp).eq(expected)) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-gpu machine") +def test_lightning_parallel_module_unsqueeze_scalar(): + """ Test that LightningParallelModule takes care of un-squeezeing 0-dim tensors. 
""" + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + output = super().training_step(batch, batch_idx) + loss = output["loss"] + loss = loss.squeeze() + assert loss.dim() == 0 + # PyTorch usually warns about 0-dim tensors returned in DP + return {"loss": loss} + + model = TestModel() + model.running_stage = RunningStage.TRAINING + batch = torch.rand(2, 32).cuda() + batch_idx = 0 + + wrapped_model = LightningParallelModule(model).cuda() + dp_module = DataParallel(wrapped_model, device_ids=[0, 1]) + + output = wrapped_model(batch, batch_idx) + assert output["loss"].dim() == 1 + + with pytest.warns(None) as record: + output = dp_module(batch, batch_idx) + + assert output["loss"].dim() == 1 + assert not record + + +@pytest.mark.parametrize("inp,expected", [ + [1.0, torch.tensor([1.0])], + [2, torch.tensor([2.0])], + [True, torch.tensor([True])], +]) +def test_python_scalar_to_tensor(inp, expected): + assert torch.all(python_scalar_to_tensor(inp).eq(expected)) + + +@pytest.mark.parametrize("device", [ + torch.device("cpu"), + torch.device("cuda", 0) +]) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_lightning_parallel_module_python_scalar_conversion(device): + """ Test that LightningParallelModule can convert Python scalars to tensors. """ + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + output = super().training_step(batch, batch_idx) + # PyTorch DP does not support Python scalars, Lightning converts them to tensors + output.update({"python scalar": 12.3}) + return output + + model = TestModel() + model.to(device) + model.running_stage = RunningStage.TRAINING + batch = torch.rand(2, 32).to(device) + batch_idx = 0 + + wrapped_model = LightningParallelModule(model) + output = wrapped_model(batch, batch_idx) + assert output["python scalar"] == torch.tensor([12.3], device=device)