Skip to content

Commit

Permalink
[bug-fix] Call transfer_batch_to_device in DDPlugin (#5195)
Browse files Browse the repository at this point in the history
* hacking out

* update

* remove useless on_before_forward

* update

* remove overriden

* iremove os

* use on_before_forward

* resolve flake8

* add test

* update

* add single_process_per_device

* resolve flake8

* update

* resolve

* update

* update

* update

* add comment

* resolve bug with sharded

* update

* remove property

* update

* resolve test

* resolve bug

* update on comments

* update doc

* Update pytorch_lightning/core/hooks.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* update on comments

* Update pytorch_lightning/plugins/ddp_plugin.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update pytorch_lightning/plugins/ddp_plugin.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* resolve pep8

* add device_ids to pipe

* update on comments

* update

* resolve

* update

* update

* update

Co-authored-by: Ubuntu <ubuntu@ip-172-31-62-109.ec2.internal>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
  • Loading branch information
4 people authored Jan 8, 2021
1 parent 72525f0 commit d510707
Show file tree
Hide file tree
Showing 15 changed files with 90 additions and 34 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `transfer_batch_to_device` for DDP with `len(devices_ids) == 1` ([#5195](https://github.com/PyTorchLightning/pytorch-lightning/pull/5195))



## [1.1.3] - 2021-01-05

Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/accelerators/ddp2_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available

if HYDRA_AVAILABLE:
Expand Down Expand Up @@ -213,6 +213,7 @@ def ddp_train(self, process_idx, mp_queue, model):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/accelerators/ddp_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
Expand Down Expand Up @@ -314,6 +314,7 @@ def ddp_train(self, process_idx, model):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
Expand Down Expand Up @@ -241,6 +241,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/accelerators/ddp_hpc_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available

if HYDRA_AVAILABLE:
Expand Down Expand Up @@ -205,6 +205,7 @@ def ddp_train(self, process_idx, model):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/accelerators/ddp_spawn_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import (
Expand Down Expand Up @@ -273,6 +273,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

Expand Down
9 changes: 5 additions & 4 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
from typing import Any, Dict, List, Optional, Union

import torch
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn


class ModelHooks:
"""Hooks to be used in LightningModule."""
Expand Down Expand Up @@ -539,9 +540,9 @@ def transfer_batch_to_device(self, batch, device)
any other device than the one passed in as argument (unless you know what you are doing).
Note:
This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
for your custom batch objects, you need to define your custom
:class:`~torch.nn.parallel.DistributedDataParallel` or
This hook only runs on single GPU training and DDP.
If you need multi-GPU support for your custom batch objects in ``dp`` or ``ddp2``,
you need to define your custom :class:`~torch.nn.parallel.DistributedDataParallel` or
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
Expand Down
26 changes: 14 additions & 12 deletions pytorch_lightning/plugins/ddp_plugin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Union

import torch
import torch.distributed as torch_distrib
from torch.optim import Optimizer

Expand Down Expand Up @@ -47,7 +48,7 @@ def configure_ddp(
def configure_ddp(self, model, device_ids):
model = LightningDistributedDataParallel(
model, device_ids=device_ids, find_unused_parameters=True
model, device_ids=device_ids, find_unused_parameters=False
)
return model
Expand All @@ -59,9 +60,9 @@ def configure_ddp(self, model, device_ids):
the model wrapped in LightningDistributedDataParallel
"""
# if unset, default `find_unused_parameters` `True`
# if unset, default `find_unused_parameters` `False`
self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
"find_unused_parameters", True
"find_unused_parameters", False
)
model = LightningDistributedDataParallel(
model,
Expand Down Expand Up @@ -91,22 +92,23 @@ def init_ddp_connection(
torch_backend, rank=global_rank, world_size=world_size
)

@property
def is_running_single_process_per_device(self) -> bool:
# objects do not need to be scattered in single process per device, move objects upfront to device
# This property is used in ``self.on_before_forward`` function.
return self.device_ids is not None and len(self.device_ids) == 1

def on_before_forward(self, model: LightningModule, *args):
"""
Override to handle custom input to device logic. For DDP, no logic is required as this is handled internally
within the DDP wrapper.
Example::
def on_before_forward(self, model, *args):
batch, batch_idx = args
return batch.to(model.device)
Override to handle custom edge case.
Args:
args: Inputs to the model.
model: Model to train.
Returns: args moved to correct device if needed.
"""
if self.is_running_single_process_per_device:
args = model.transfer_batch_to_device(args, model.device)
return args

def optimizer_state(self, optimizer: Optimizer) -> dict:
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/ddp_sequential_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import LightningModule
from pytorch_lightning import _logger as log
from pytorch_lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
Expand Down
7 changes: 2 additions & 5 deletions pytorch_lightning/plugins/sharded_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union, Any
from typing import Any, List, Optional, Union

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, AMPType, rank_zero_only
from pytorch_lightning.utilities import AMPType, FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if FAIRSCALE_AVAILABLE:
Expand All @@ -42,9 +42,6 @@ def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
optimizer.consolidate_state_dict()
return self._optim_state_dict(optimizer)

def on_before_forward(self, model: LightningModule, *args):
return model.transfer_batch_to_device(args, model.trainer.root_gpu)

def _check_fairscale(self):
if not FAIRSCALE_AVAILABLE:
raise MisconfigurationException(
Expand Down
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from copy import copy, deepcopy

Expand Down
8 changes: 5 additions & 3 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,14 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
return function(data, *args, **kwargs)

# Recursively apply to collection items
elif isinstance(data, Mapping):
if isinstance(data, Mapping):
return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs)
for k, v in data.items()})
elif isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple

if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple
return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))
elif isinstance(data, Sequence) and not isinstance(data, str):

if isinstance(data, Sequence) and not isinstance(data, str):
return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])

# data is neither of dtype, nor a collection
Expand Down
48 changes: 46 additions & 2 deletions tests/models/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from unittest.mock import MagicMock

import pytest
import torch
from unittest.mock import MagicMock

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
from tests.base import EvalModelTemplate, BoringModel
from tests.base import BoringModel, EvalModelTemplate, RandomDataset


@pytest.mark.parametrize('max_steps', [1, 2, 3])
Expand Down Expand Up @@ -124,6 +125,49 @@ def transfer_batch_to_device(self, data, device):
assert batch_gpu.samples.device == batch_gpu.targets.device == expected


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
def test_transfer_batch_hook_ddp(tmpdir):
"""
Test custom data are properly moved to the right device using ddp
"""

class CustomBatch:

def __init__(self, data):
self.samples = data[0]

def to(self, device, **kwargs):
self.samples = self.samples.to(device, **kwargs)
return self

def collate_fn(batch):
return CustomBatch(batch)

class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
assert batch.samples.device == self.device
assert isinstance(batch_idx, int)

def train_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64), collate_fn=collate_fn)

model = TestModel()
model.validation_step = None
model.training_epoch_end = None
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=0,
max_epochs=1,
weights_summary=None,
accelerator="ddp",
gpus=2,
)
trainer.fit(model)


@pytest.mark.parametrize(
'max_epochs,batch_idx_',
[(2, 5), (3, 8), (4, 12)]
Expand Down
4 changes: 3 additions & 1 deletion tests/models/test_sync_batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
import torch.nn as nn
import torch.nn.functional as F

from pytorch_lightning import Trainer, seed_everything, LightningModule
from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning.core.step_result import TrainResult
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.utilities import FLOAT16_EPSILON
from tests.base.datamodules import MNISTDataModule
from tests.base.develop_utils import set_random_master_port
Expand Down Expand Up @@ -108,6 +109,7 @@ def test_sync_batchnorm_ddp(tmpdir):
sync_batchnorm=True,
num_sanity_val_steps=0,
replace_sampler_ddp=False,
plugins=[DDPPlugin(find_unused_parameters=True)]
)

result = trainer.fit(model, dm)
Expand Down
1 change: 1 addition & 0 deletions tests/special_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequent
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp

0 comments on commit d510707

Please sign in to comment.