diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d93ef9373af1..8ebe1efa73b31 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -135,6 +135,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed `transfer_batch_to_device` for DDP with `len(devices_ids) == 1` ([#5195](https://github.com/PyTorchLightning/pytorch-lightning/pull/5195)) + + ## [1.1.3] - 2021-01-05 diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index a5e8d720ce186..a61db38ea772c 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -210,6 +210,7 @@ def ddp_train(self, process_idx, mp_queue, model): def configure_ddp( self, model: LightningModule, device_ids: List[int] ) -> DistributedDataParallel: + self.ddp_plugin.device_ids = device_ids model = self.ddp_plugin.configure_ddp(model, device_ids) return model diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 56f6eaa2223a3..081d66c79eeeb 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -315,6 +315,7 @@ def ddp_train(self, process_idx, model): def configure_ddp( self, model: LightningModule, device_ids: List[int] ) -> DistributedDataParallel: + self.ddp_plugin.device_ids = device_ids model = self.ddp_plugin.configure_ddp(model, device_ids) return model diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index b15b9e8062257..afbdeed2b3046 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -239,6 +239,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): def configure_ddp( self, model: LightningModule, device_ids: List[int] ) -> DistributedDataParallel: + self.ddp_plugin.device_ids = device_ids model = self.ddp_plugin.configure_ddp(model, device_ids) return model diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index cf6aad9999223..c708c5e106930 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -199,6 +199,7 @@ def ddp_train(self, process_idx, model): def configure_ddp( self, model: LightningModule, device_ids: List[int] ) -> DistributedDataParallel: + self.ddp_plugin.device_ids = device_ids model = self.ddp_plugin.configure_ddp(model, device_ids) return model diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index e23943e9262f8..3cd79700efd91 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -271,6 +271,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): def configure_ddp( self, model: LightningModule, device_ids: List[int] ) -> DistributedDataParallel: + self.ddp_plugin.device_ids = device_ids model = self.ddp_plugin.configure_ddp(model, device_ids) return model diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index f27c18513831f..46bf5398c997f 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -562,9 +562,9 @@ def transfer_batch_to_device(self, batch, device) any other device than the one passed in as argument (unless you know what you are doing). Note: - This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support - for your custom batch objects, you need to define your custom - :class:`~torch.nn.parallel.DistributedDataParallel` and + This hook only runs on single GPU training and DDP (no data-parallel). + If you need multi-GPU support for your custom batch objects in ``dp`` or ``ddp2``, + you need to define your custom :class:`~torch.nn.parallel.DistributedDataParallel` or override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`. See Also: diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index f0da9e5ff1a2d..6fa19069937b6 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -110,22 +110,23 @@ def init_ddp_connection( torch_backend, rank=global_rank, world_size=world_size ) + @property + def is_running_single_process_per_device(self) -> bool: + # objects do not need to be scattered in single process per device, move objects upfront to device + # This property is used in ``self.on_before_forward`` function. + return self.device_ids is not None and len(self.device_ids) == 1 + def on_before_forward(self, model: LightningModule, *args): """ - Override to handle custom input to device logic. For DDP, no logic is required as this is handled internally - within the DDP wrapper. - - Example:: - - def on_before_forward(self, model, *args): - batch, batch_idx = args - return batch.to(model.device) + Override to handle custom edge case. Args: args: Inputs to the model. model: Model to train. Returns: args moved to correct device if needed. """ + if self.is_running_single_process_per_device: + args = model.transfer_batch_to_device(args, model.device) return args def optimizer_state(self, optimizer: Optimizer) -> dict: diff --git a/pytorch_lightning/plugins/ddp_sequential_plugin.py b/pytorch_lightning/plugins/ddp_sequential_plugin.py index 82250d1ed9fdd..2ad1949bc2f7c 100644 --- a/pytorch_lightning/plugins/ddp_sequential_plugin.py +++ b/pytorch_lightning/plugins/ddp_sequential_plugin.py @@ -19,8 +19,8 @@ from torch import nn from torch.nn.parallel import DistributedDataParallel -from pytorch_lightning import LightningModule from pytorch_lightning import _logger as log +from pytorch_lightning import LightningModule from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index ec1500ca7abf4..fbba43a050a57 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -42,9 +42,6 @@ def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]: optimizer.consolidate_state_dict() return self._optim_state_dict(optimizer) - def on_before_forward(self, model: LightningModule, *args): - return model.transfer_batch_to_device(args, model.trainer.root_gpu) - def _check_fairscale(self): if not _FAIRSCALE_AVAILABLE: raise MisconfigurationException( diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index fd610ebcb0c8d..df2375336c5cd 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -77,12 +77,14 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable return function(data, *args, **kwargs) # Recursively apply to collection items - elif isinstance(data, Mapping): + if isinstance(data, Mapping): return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()}) - elif isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple + + if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data)) - elif isinstance(data, Sequence) and not isinstance(data, str): + + if isinstance(data, Sequence) and not isinstance(data, str): return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data]) # data is neither of dtype, nor a collection diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 1f25d46f82944..a25a8181e763a 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +import os from unittest.mock import MagicMock import pytest @@ -20,7 +21,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator from pytorch_lightning.trainer.states import TrainerState -from tests.base import BoringModel, EvalModelTemplate +from tests.base import BoringModel, EvalModelTemplate, RandomDataset @pytest.mark.parametrize('max_steps', [1, 2, 3]) @@ -125,6 +126,49 @@ def transfer_batch_to_device(self, data, device): assert batch_gpu.samples.device == batch_gpu.targets.device == expected +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', + reason="test should be run outside of pytest") +def test_transfer_batch_hook_ddp(tmpdir): + """ + Test custom data are properly moved to the right device using ddp + """ + + class CustomBatch: + + def __init__(self, data): + self.samples = data[0] + + def to(self, device, **kwargs): + self.samples = self.samples.to(device, **kwargs) + return self + + def collate_fn(batch): + return CustomBatch(batch) + + class TestModel(BoringModel): + def training_step(self, batch, batch_idx): + assert batch.samples.device == self.device + assert isinstance(batch_idx, int) + + def train_dataloader(self): + return torch.utils.data.DataLoader(RandomDataset(32, 64), collate_fn=collate_fn) + + model = TestModel() + model.validation_step = None + model.training_epoch_end = None + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=0, + max_epochs=1, + weights_summary=None, + accelerator="ddp", + gpus=2, + ) + trainer.fit(model) + + @pytest.mark.parametrize( 'max_epochs,batch_idx_', [(2, 5), (3, 8), (4, 12)] diff --git a/tests/models/test_sync_batchnorm.py b/tests/models/test_sync_batchnorm.py index fe00acff62624..05ffded86699c 100644 --- a/tests/models/test_sync_batchnorm.py +++ b/tests/models/test_sync_batchnorm.py @@ -17,8 +17,8 @@ import torch.nn.functional as F from pytorch_lightning import LightningModule, seed_everything, Trainer -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.utilities import FLOAT16_EPSILON from tests.base.datamodules import MNISTDataModule from tests.base.develop_utils import set_random_master_port diff --git a/tests/special_tests.sh b/tests/special_tests.sh index ea14841c74bad..9d27a5786160f 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -22,4 +22,5 @@ python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequent python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection # python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance -python ${DEFAULTS} tests/trainer/logging_process/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp +python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp +python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp