From 0d1fbd276306c7957ff6906f6dd85e578481162b Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 7 Dec 2020 13:34:59 -0500 Subject: [PATCH 01/22] all_gather --- pytorch_lightning/accelerators/accelerator.py | 14 +++ pytorch_lightning/core/lightning.py | 18 ++++ tests/core/test_lightning_all_gather.py | 85 +++++++++++++++++++ 3 files changed, 117 insertions(+) create mode 100644 tests/core/test_lightning_all_gather.py diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index c5b744c3384ec..3d03b79c16459 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -175,6 +175,20 @@ def sync_tensor(self, """ raise NotImplementedError() + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + raise NotImplementedError() + def optimizer_state(self, optimizer: Optimizer) -> dict: """ Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f1a0c725e2b12..249075792191c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -365,6 +365,24 @@ def __auto_choose_log_on_epoch(self, on_epoch): return on_epoch + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + r""" + Allows users to call ``self.all_gather()`` from the LightningModule, thus making + the ```all_gather``` operation accelerator agnostic. + + ```all_gather``` is a function provided by accelerators to gather a tensor from several + distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) 
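+
+        Example (sketch; assumes a multi-process run and a hypothetical
+        ``self.encoder`` submodule used only for illustration)::
+
+            # inside any LightningModule hook, e.g. validation_step
+            z = self.encoder(batch)                     # (batch, dim) on this process
+            z_all = self.all_gather(z)                  # (world_size, batch, dim)
+            z_all = z_all.reshape(-1, z_all.size(-1))   # (world_size * batch, dim)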
+ """ + return self.trainer.accelerator_backend.all_gather(tensor, group=group, sync_grads=sync_grads) + def forward(self, *args, **kwargs): r""" Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define diff --git a/tests/core/test_lightning_all_gather.py b/tests/core/test_lightning_all_gather.py new file mode 100644 index 0000000000000..53d7297bac9ce --- /dev/null +++ b/tests/core/test_lightning_all_gather.py @@ -0,0 +1,85 @@ +import os +import pytest +import torch +import torch.nn as nn + +from tests.base.boring_model import BoringModel +from tests.base.develop_utils import set_random_master_port +from pytorch_lightning import Trainer, seed_everything + + +class AllGatherModel(BoringModel): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + self.layer1 = torch.nn.Linear(32, 2) + self.layer2 = torch.nn.Linear(32, 2) + self.layer3 = torch.nn.Linear(32, 2) + self.layer4 = torch.nn.Linear(32, 2) + + def forward(self, x): + tensor1 = self.layer1(x) + tensor2 = self.layer2(x) + + tensor1_gathered = self.all_gather(tensor1) + tensor2_gathered = self.all_gather(tensor2) + + assert torch.sum(tensor1_gathered[self.global_rank] - tensor1) == 0 + assert torch.sum(tensor2_gathered[self.global_rank] - tensor2) == 0 + + # with grads + tensor3 = self.layer3(x) + tensor4 = self.layer4(x) + + tensor3_gathered = self.all_gather(tensor3) + tensor4_gathered = self.all_gather(tensor4) + + assert torch.sum(tensor3_gathered[self.global_rank] - tensor3) == 0 + assert torch.sum(tensor4_gathered[self.global_rank] - tensor4) == 0 + + # test for grads + + return self.layer(x) + + +# test for ddp backends +def setup_ddp(rank, world_size): + """ Setup ddp enviroment """ + os.environ["MASTER_ADDR"] = 'localhost' + os.environ['MASTER_PORT'] = '8088' + + if torch.distributed.is_available() and sys.platform not in ('win32', 'cygwin'): + torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) + + +def _test_all_gather_ddp(rank, world_size): + setup_ddp(rank, world_size) + +# test horovod + + +# test tpu + + +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +@pytest.mark.parametrize("accelerator", ['ddp', 'horovod', 'tpu']) +def test_all_gather(accelerator): + gpus = 2 + + seed_everything(234) + set_random_master_port() + + model = AllGatherModel() + train_dataloader = model.train_dataloader() + + trainer = Trainer( + gpus=gpus, + accelerator=accelerator, + max_epochs=1, + max_steps=3, + num_sanity_val_steps=0, + ) + + result = trainer.fit(model, train_dataloader) + assert result == 1, "All gather op fails in Lightning Module" From 5bb4b5484cf522ca9da1fefe136343765c8400e6 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 7 Dec 2020 14:22:08 -0500 Subject: [PATCH 02/22] ddp --- .../accelerators/ddp2_accelerator.py | 16 +++- .../accelerators/ddp_accelerator.py | 15 +++ .../accelerators/ddp_cpu_spawn_accelerator.py | 15 +++ .../accelerators/ddp_hpc_accelerator.py | 16 +++- .../accelerators/ddp_spawn_accelerator.py | 15 +++ .../accelerators/horovod_accelerator.py | 14 +++ .../accelerators/tpu_accelerator.py | 15 +++ pytorch_lightning/utilities/distributed.py | 93 +++++++++++++++++++ tests/core/test_lightning_all_gather.py | 25 +---- 9 files changed, 201 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index f43866881cabb..e5cdecd2442dd 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ 
b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -24,7 +24,7 @@ from pytorch_lightning.core.step_result import Result from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType -from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available +from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available if HYDRA_AVAILABLE: from hydra.core.hydra_config import HydraConfig @@ -217,5 +217,19 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 687b5c21874fb..24f662963ee3f 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -28,6 +28,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType +from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything @@ -315,5 +316,19 @@ def sync_tensor(self, """ return sync_ddp_if_available(tensor, group, reduce_op) + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) 
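+
+        Example (sketch; assumes the process group is initialized and that
+        ``model_output`` is a hypothetical activation already on this device)::
+
+            gathered = self.all_gather(model_output, sync_grads=True)
+            # gathered has shape (world_size, *model_output.shape); because
+            # sync_grads=True, gradients flow back to model_output on backward()
+            gathered.sum().backward()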
+ """ + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 982da2f53216b..9cd68cd827334 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -30,6 +30,7 @@ rank_zero_only, rank_zero_warn, sync_ddp_if_available, + all_gather_ddp_if_available, ) if HYDRA_AVAILABLE: @@ -249,5 +250,19 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index 28817c6845f5b..d13a758cbb838 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -24,7 +24,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType -from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available +from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available if HYDRA_AVAILABLE: from hydra.core.hydra_config import HydraConfig @@ -211,5 +211,19 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) 
+ """ + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index a06d0b82d6d15..18e90b473e01b 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -33,6 +33,7 @@ rank_zero_only, rank_zero_warn, sync_ddp_if_available, + all_gather_ddp_if_available, ) from pytorch_lightning.utilities.seed import seed_everything @@ -276,5 +277,19 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index 460f5a83d2582..8353cf89bd278 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -206,3 +206,17 @@ def sync_tensor(self, # sync all processes before reduction hvd.join() return hvd.allreduce(tensor, op=reduce_op) + + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + return all_gather_horovod_if_available(tensor, group=group, sync_grads=sync_grads) diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index cd6b99fa64eef..fab1c8018abd8 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -33,6 +33,7 @@ ) from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.distributed import all_gather_tpu_if_available if TPU_AVAILABLE: import torch_xla @@ -353,6 +354,20 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return tensor + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) 
+ """ + return all_gather_tpu_if_available(tensor, group=group, sync_grads=sync_grads) + @property def norm_clipping_epsilon(self): return 1e-6 diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index ffa1be87cd3ca..e2cae5dcd20ae 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -155,3 +155,96 @@ def sync_ddp( result = result / torch.distributed.get_world_size(group) return result + + +class AllGatherGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor): + ctx.batch_size = tensor.shape[0] + + gathered_tensor = [ + torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + + torch.distributed.all_gather(gathered_tensor, tensor) + gathered_tensor = torch.stack(gathered_tensor, dim=0) + + return gathered_tensor + + @staticmethod + def backward(ctx, *grad_output): + #grad_input = grad_output.clone() + print(grad_output.shape) + exit(-1) + torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False) + + return grad_input[ + torch.distributed.get_rank() * ctx.batch_size:(torch.distributed.get_rank() + 1) * ctx.batch_size + ] + + +def all_gather_ddp_if_available( + tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False +) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if sync_grads: + return AllGatherGrad.apply(tensor, group) + else: + gathered_tensor = [ + torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + + torch.distributed.all_gather(gathered_tensor, tensor) + gathered_tensor = torch.cat(gathered_tensor, 0) + + return gathered_tensor + return tensor + + +def all_gather_tpu_if_available( + tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False +) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return sync_ddp(result, group=group, reduce_op=reduce_op) + return result + + +def all_gather_horovod_if_available( + tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False +) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) 
+ """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return sync_ddp(result, group=group, reduce_op=reduce_op) + return result \ No newline at end of file diff --git a/tests/core/test_lightning_all_gather.py b/tests/core/test_lightning_all_gather.py index 53d7297bac9ce..8cc8484feed9e 100644 --- a/tests/core/test_lightning_all_gather.py +++ b/tests/core/test_lightning_all_gather.py @@ -19,6 +19,7 @@ def __init__(self): self.layer4 = torch.nn.Linear(32, 2) def forward(self, x): + # no grad cases tensor1 = self.layer1(x) tensor2 = self.layer2(x) @@ -28,7 +29,7 @@ def forward(self, x): assert torch.sum(tensor1_gathered[self.global_rank] - tensor1) == 0 assert torch.sum(tensor2_gathered[self.global_rank] - tensor2) == 0 - # with grads + # with grad cases tensor3 = self.layer3(x) tensor4 = self.layer4(x) @@ -43,27 +44,9 @@ def forward(self, x): return self.layer(x) -# test for ddp backends -def setup_ddp(rank, world_size): - """ Setup ddp enviroment """ - os.environ["MASTER_ADDR"] = 'localhost' - os.environ['MASTER_PORT'] = '8088' - - if torch.distributed.is_available() and sys.platform not in ('win32', 'cygwin'): - torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) - - -def _test_all_gather_ddp(rank, world_size): - setup_ddp(rank, world_size) - -# test horovod - - -# test tpu - - +# TODO: horovod and TPU @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") -@pytest.mark.parametrize("accelerator", ['ddp', 'horovod', 'tpu']) +@pytest.mark.parametrize("accelerator", ['ddp', 'ddp_cpu', 'ddp_spawn']) def test_all_gather(accelerator): gpus = 2 From db652cfb4d4e69c1b88c97307cd23e45497dfd69 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 7 Dec 2020 14:36:21 -0500 Subject: [PATCH 03/22] horovod --- pytorch_lightning/accelerators/horovod_accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index 8353cf89bd278..b440eb915c61a 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -19,7 +19,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.utilities import AMPType, HOROVOD_AVAILABLE -from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import rank_zero_only, all_gather_horovod_if_available if HOROVOD_AVAILABLE: import horovod.torch as hvd From 0a821f414b0b01e339ba4431170ca472d1cea03c Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 7 Dec 2020 15:15:26 -0500 Subject: [PATCH 04/22] grad tests --- pytorch_lightning/utilities/__init__.py | 1 + pytorch_lightning/utilities/distributed.py | 12 +++- tests/core/test_lightning_all_gather.py | 68 ---------------------- tests/utilities/test_all_gather_grad.py | 33 +++++++++++ 4 files changed, 43 insertions(+), 71 deletions(-) delete mode 100644 tests/core/test_lightning_all_gather.py create mode 100644 tests/utilities/test_all_gather_grad.py diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 1e2eeea9f456c..c28139b6f4c84 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -21,6 +21,7 @@ from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, 
rank_zero_warn +from pytorch_lightning.utilities.distributed import AllGatherGrad from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index e2cae5dcd20ae..d87863daf5120 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -159,14 +159,15 @@ def sync_ddp( class AllGatherGrad(torch.autograd.Function): @staticmethod - def forward(ctx, tensor): + def forward(ctx, tensor, group=None): ctx.batch_size = tensor.shape[0] + ctx.group = group gathered_tensor = [ torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size()) ] - torch.distributed.all_gather(gathered_tensor, tensor) + torch.distributed.all_gather(gathered_tensor, tensor, group=group) gathered_tensor = torch.stack(gathered_tensor, dim=0) return gathered_tensor @@ -176,7 +177,12 @@ def backward(ctx, *grad_output): #grad_input = grad_output.clone() print(grad_output.shape) exit(-1) - torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False) + torch.distributed.all_reduce( + grad_input, + op=torch.distributed.ReduceOp.SUM, + async_op=False, + group=ctx.group + ) return grad_input[ torch.distributed.get_rank() * ctx.batch_size:(torch.distributed.get_rank() + 1) * ctx.batch_size diff --git a/tests/core/test_lightning_all_gather.py b/tests/core/test_lightning_all_gather.py deleted file mode 100644 index 8cc8484feed9e..0000000000000 --- a/tests/core/test_lightning_all_gather.py +++ /dev/null @@ -1,68 +0,0 @@ -import os -import pytest -import torch -import torch.nn as nn - -from tests.base.boring_model import BoringModel -from tests.base.develop_utils import set_random_master_port -from pytorch_lightning import Trainer, seed_everything - - -class AllGatherModel(BoringModel): - def __init__(self): - super().__init__() - self.layer = torch.nn.Linear(32, 2) - - self.layer1 = torch.nn.Linear(32, 2) - self.layer2 = torch.nn.Linear(32, 2) - self.layer3 = torch.nn.Linear(32, 2) - self.layer4 = torch.nn.Linear(32, 2) - - def forward(self, x): - # no grad cases - tensor1 = self.layer1(x) - tensor2 = self.layer2(x) - - tensor1_gathered = self.all_gather(tensor1) - tensor2_gathered = self.all_gather(tensor2) - - assert torch.sum(tensor1_gathered[self.global_rank] - tensor1) == 0 - assert torch.sum(tensor2_gathered[self.global_rank] - tensor2) == 0 - - # with grad cases - tensor3 = self.layer3(x) - tensor4 = self.layer4(x) - - tensor3_gathered = self.all_gather(tensor3) - tensor4_gathered = self.all_gather(tensor4) - - assert torch.sum(tensor3_gathered[self.global_rank] - tensor3) == 0 - assert torch.sum(tensor4_gathered[self.global_rank] - tensor4) == 0 - - # test for grads - - return self.layer(x) - - -# TODO: horovod and TPU -@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") -@pytest.mark.parametrize("accelerator", ['ddp', 'ddp_cpu', 'ddp_spawn']) -def test_all_gather(accelerator): - gpus = 2 - - seed_everything(234) - set_random_master_port() - - model = AllGatherModel() - train_dataloader = model.train_dataloader() - - trainer = Trainer( - gpus=gpus, - accelerator=accelerator, - max_epochs=1, - max_steps=3, - num_sanity_val_steps=0, - ) - - result = trainer.fit(model, train_dataloader) - assert result == 1, "All gather op fails in Lightning Module" diff --git 
a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py new file mode 100644 index 0000000000000..e9c409506c645 --- /dev/null +++ b/tests/utilities/test_all_gather_grad.py @@ -0,0 +1,33 @@ +import os +import pytest +import torch +import torch.nn as nn + +from pytorch_lightning.utilities import AllGatherGrad + + +def setup_ddp(rank, world_size): + """ Setup ddp enviroment """ + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8088" + + if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): + torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) + + +def _test_all_gather(rank, world_size): + setup_ddp(rank, world_size) + + tensor1 = torch.randn(8) + tensor2 = torch.randn(8, 16, 32) + + tensor1_gathered = AllGatherGrad.apply(tensor1) + tensor2_gathered = AllGatherGrad.apply(tensor2) + + assert torch.sum(tensor1_gathered[rank] - tensor1) == 0 + assert torch.sum(tensor2_gathered[rank] - tensor2) == 0 + + +def test_all_gather(): + world_size = 3 + torch.multiprocessing.spawn(_test_all_gather, args=(world_size,), nprocs=world_size) From 2dd7680c06b6bde6e7fa2fa3bb7c18a586c5d67b Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 7 Dec 2020 17:20:25 -0500 Subject: [PATCH 05/22] fixed ddp --- pytorch_lightning/utilities/distributed.py | 17 +++++++---------- tests/utilities/test_all_gather_grad.py | 19 +++++++++++++++---- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index d87863daf5120..1d25549822667 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -159,7 +159,7 @@ def sync_ddp( class AllGatherGrad(torch.autograd.Function): @staticmethod - def forward(ctx, tensor, group=None): + def forward(ctx, tensor, group=torch.distributed.group.WORLD): ctx.batch_size = tensor.shape[0] ctx.group = group @@ -174,19 +174,16 @@ def forward(ctx, tensor, group=None): @staticmethod def backward(ctx, *grad_output): - #grad_input = grad_output.clone() - print(grad_output.shape) - exit(-1) + grad_output = torch.cat(grad_output) + torch.distributed.all_reduce( - grad_input, + grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group ) - return grad_input[ - torch.distributed.get_rank() * ctx.batch_size:(torch.distributed.get_rank() + 1) * ctx.batch_size - ] + return grad_output[torch.distributed.get_rank()] def all_gather_ddp_if_available( @@ -212,7 +209,7 @@ def all_gather_ddp_if_available( ] torch.distributed.all_gather(gathered_tensor, tensor) - gathered_tensor = torch.cat(gathered_tensor, 0) + gathered_tensor = torch.stack(gathered_tensor, dim=0) return gathered_tensor return tensor @@ -253,4 +250,4 @@ def all_gather_horovod_if_available( """ if torch.distributed.is_available() and torch.distributed.is_initialized(): return sync_ddp(result, group=group, reduce_op=reduce_op) - return result \ No newline at end of file + return result diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index e9c409506c645..762f635dbd969 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -1,5 +1,6 @@ import os import pytest +import sys import torch import torch.nn as nn @@ -18,16 +19,26 @@ def setup_ddp(rank, world_size): def _test_all_gather(rank, world_size): setup_ddp(rank, world_size) - tensor1 = torch.randn(8) - tensor2 = torch.randn(8, 16, 32) + tensor1 
= torch.ones(8, requires_grad=True) + tensor2 = torch.ones((8, 16, 32), requires_grad=True) tensor1_gathered = AllGatherGrad.apply(tensor1) tensor2_gathered = AllGatherGrad.apply(tensor2) - assert torch.sum(tensor1_gathered[rank] - tensor1) == 0 - assert torch.sum(tensor2_gathered[rank] - tensor2) == 0 + tensor1_gathered = tensor1_gathered * rank + tensor2_gathered = tensor2_gathered * rank + tensor1_gathered.sum().backward() + tensor2_gathered.sum().backward() + grad1 = torch.zeros_like(tensor1.grad).fill_(torch.arange(world_size).sum().float()) + grad2 = torch.zeros_like(tensor2.grad).fill_(torch.arange(world_size).sum().float()) + + assert torch.allclose(grad1, tensor1.grad) + assert torch.allclose(grad2, tensor2.grad) + + +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") def test_all_gather(): world_size = 3 torch.multiprocessing.spawn(_test_all_gather, args=(world_size,), nprocs=world_size) From 9f031a0af7baf354a8fd92c7b06955c1d1014a53 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 7 Dec 2020 23:53:41 -0500 Subject: [PATCH 06/22] ddp fixed, removed tpu, horovod for now --- .../accelerators/horovod_accelerator.py | 16 +------- .../accelerators/tpu_accelerator.py | 15 -------- pytorch_lightning/utilities/distributed.py | 38 ------------------- tests/utilities/test_all_gather_grad.py | 6 +-- 4 files changed, 4 insertions(+), 71 deletions(-) diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index b440eb915c61a..460f5a83d2582 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -19,7 +19,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.utilities import AMPType, HOROVOD_AVAILABLE -from pytorch_lightning.utilities.distributed import rank_zero_only, all_gather_horovod_if_available +from pytorch_lightning.utilities.distributed import rank_zero_only if HOROVOD_AVAILABLE: import horovod.torch as hvd @@ -206,17 +206,3 @@ def sync_tensor(self, # sync all processes before reduction hvd.join() return hvd.allreduce(tensor, op=reduce_op) - - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) 
- """ - return all_gather_horovod_if_available(tensor, group=group, sync_grads=sync_grads) diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index fab1c8018abd8..cd6b99fa64eef 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -33,7 +33,6 @@ ) from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.distributed import all_gather_tpu_if_available if TPU_AVAILABLE: import torch_xla @@ -354,20 +353,6 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return tensor - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) - """ - return all_gather_tpu_if_available(tensor, group=group, sync_grads=sync_grads) - @property def norm_clipping_epsilon(self): return 1e-6 diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 1d25549822667..68a79e6676e84 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -213,41 +213,3 @@ def all_gather_ddp_if_available( return gathered_tensor return tensor - - -def all_gather_tpu_if_available( - tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False -) -> torch.Tensor: - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) - """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): - return sync_ddp(result, group=group, reduce_op=reduce_op) - return result - - -def all_gather_horovod_if_available( - tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False -) -> torch.Tensor: - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) 
- """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): - return sync_ddp(result, group=group, reduce_op=reduce_op) - return result diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 762f635dbd969..66e50776edd3f 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -16,7 +16,7 @@ def setup_ddp(rank, world_size): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) -def _test_all_gather(rank, world_size): +def _test_all_gather_ddp(rank, world_size): setup_ddp(rank, world_size) tensor1 = torch.ones(8, requires_grad=True) @@ -39,6 +39,6 @@ def _test_all_gather(rank, world_size): @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") -def test_all_gather(): +def test_all_gather_ddp(): world_size = 3 - torch.multiprocessing.spawn(_test_all_gather, args=(world_size,), nprocs=world_size) + torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size) From eec2ff3030a806bec0b814c07a20fb2c1a179d52 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 8 Dec 2020 00:02:09 -0500 Subject: [PATCH 07/22] changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ba46ebdc8520..c412f3f859e56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## Unreleased +### Added + +- Added `all_gather` method to `LightningModule` which allows gradient based tensor synchronizations for use-cases such as negative sampling. ([#5012](https://github.com/PyTorchLightning/pytorch-lightning/pull/5012)) + ### Fixed - Fixed `LoggerConnector` to have logged metrics on root device in DP ([#4138](https://github.com/PyTorchLightning/pytorch-lightning/pull/4138)) From fbfdb43b457fc138b3fc7fc669a7df755fae6b2b Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 8 Dec 2020 01:26:02 -0500 Subject: [PATCH 08/22] windows fix --- pytorch_lightning/utilities/distributed.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 68a79e6676e84..345a44922be53 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -22,11 +22,16 @@ if torch.distributed.is_available(): from torch.distributed import ReduceOp + from torch.distributed import group else: class ReduceOp: SUM = None + class group: + WORLD = None + + def rank_zero_only(fn): @wraps(fn) @@ -159,7 +164,7 @@ def sync_ddp( class AllGatherGrad(torch.autograd.Function): @staticmethod - def forward(ctx, tensor, group=torch.distributed.group.WORLD): + def forward(ctx, tensor, group=group.WORLD): ctx.batch_size = tensor.shape[0] ctx.group = group From ea1d3d82b9eacf09c07f774cecbaac039dce4899 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 8 Dec 2020 01:27:03 -0500 Subject: [PATCH 09/22] windows fix --- pytorch_lightning/utilities/distributed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 345a44922be53..60e6bddff29fd 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -27,7 +27,6 @@ class ReduceOp: SUM = None - class group: WORLD = None From f666020c769cf7a43f7664d124f9712ada0d7db9 Mon Sep 17 00:00:00 2001 From: ananyahjha93 
Date: Tue, 8 Dec 2020 04:05:59 -0500 Subject: [PATCH 10/22] removed batch from ctx --- pytorch_lightning/utilities/distributed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 60e6bddff29fd..d0534d6764bd2 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -164,7 +164,6 @@ def sync_ddp( class AllGatherGrad(torch.autograd.Function): @staticmethod def forward(ctx, tensor, group=group.WORLD): - ctx.batch_size = tensor.shape[0] ctx.group = group gathered_tensor = [ From ab1a8649d0f1c4dc548dd44a11d2db5a356fff9f Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 7 Dec 2020 13:34:59 -0500 Subject: [PATCH 11/22] all_gather --- pytorch_lightning/accelerators/accelerator.py | 14 +++ pytorch_lightning/core/lightning.py | 18 ++++ tests/core/test_lightning_all_gather.py | 85 +++++++++++++++++++ 3 files changed, 117 insertions(+) create mode 100644 tests/core/test_lightning_all_gather.py diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index c5b744c3384ec..3d03b79c16459 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -175,6 +175,20 @@ def sync_tensor(self, """ raise NotImplementedError() + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + raise NotImplementedError() + def optimizer_state(self, optimizer: Optimizer) -> dict: """ Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 671084cb2fac7..5acec2b86722c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -365,6 +365,24 @@ def __auto_choose_log_on_epoch(self, on_epoch): return on_epoch + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + r""" + Allows users to call ``self.all_gather()`` from the LightningModule, thus making + the ```all_gather``` operation accelerator agnostic. + + ```all_gather``` is a function provided by accelerators to gather a tensor from several + distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) 
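+
+        Example (sketch of the negative-sampling use case; ``embed`` is a
+        hypothetical encoder and a multi-GPU run is assumed)::
+
+            # enlarge the in-batch negative pool with embeddings from every process
+            z = embed(batch)                                # (batch, dim), local only
+            z_all = self.all_gather(z, sync_grads=True)     # (world_size, batch, dim)
+            negatives = z_all.reshape(-1, z_all.size(-1))   # (world_size * batch, dim)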
+ """ + return self.trainer.accelerator_backend.all_gather(tensor, group=group, sync_grads=sync_grads) + def forward(self, *args, **kwargs): r""" Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define diff --git a/tests/core/test_lightning_all_gather.py b/tests/core/test_lightning_all_gather.py new file mode 100644 index 0000000000000..53d7297bac9ce --- /dev/null +++ b/tests/core/test_lightning_all_gather.py @@ -0,0 +1,85 @@ +import os +import pytest +import torch +import torch.nn as nn + +from tests.base.boring_model import BoringModel +from tests.base.develop_utils import set_random_master_port +from pytorch_lightning import Trainer, seed_everything + + +class AllGatherModel(BoringModel): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + self.layer1 = torch.nn.Linear(32, 2) + self.layer2 = torch.nn.Linear(32, 2) + self.layer3 = torch.nn.Linear(32, 2) + self.layer4 = torch.nn.Linear(32, 2) + + def forward(self, x): + tensor1 = self.layer1(x) + tensor2 = self.layer2(x) + + tensor1_gathered = self.all_gather(tensor1) + tensor2_gathered = self.all_gather(tensor2) + + assert torch.sum(tensor1_gathered[self.global_rank] - tensor1) == 0 + assert torch.sum(tensor2_gathered[self.global_rank] - tensor2) == 0 + + # with grads + tensor3 = self.layer3(x) + tensor4 = self.layer4(x) + + tensor3_gathered = self.all_gather(tensor3) + tensor4_gathered = self.all_gather(tensor4) + + assert torch.sum(tensor3_gathered[self.global_rank] - tensor3) == 0 + assert torch.sum(tensor4_gathered[self.global_rank] - tensor4) == 0 + + # test for grads + + return self.layer(x) + + +# test for ddp backends +def setup_ddp(rank, world_size): + """ Setup ddp enviroment """ + os.environ["MASTER_ADDR"] = 'localhost' + os.environ['MASTER_PORT'] = '8088' + + if torch.distributed.is_available() and sys.platform not in ('win32', 'cygwin'): + torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) + + +def _test_all_gather_ddp(rank, world_size): + setup_ddp(rank, world_size) + +# test horovod + + +# test tpu + + +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +@pytest.mark.parametrize("accelerator", ['ddp', 'horovod', 'tpu']) +def test_all_gather(accelerator): + gpus = 2 + + seed_everything(234) + set_random_master_port() + + model = AllGatherModel() + train_dataloader = model.train_dataloader() + + trainer = Trainer( + gpus=gpus, + accelerator=accelerator, + max_epochs=1, + max_steps=3, + num_sanity_val_steps=0, + ) + + result = trainer.fit(model, train_dataloader) + assert result == 1, "All gather op fails in Lightning Module" From 993aa51c9dd0f83e78d8e9812c7ce7a14b7e5597 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 7 Dec 2020 14:22:08 -0500 Subject: [PATCH 12/22] ddp --- .../accelerators/ddp2_accelerator.py | 16 +++- .../accelerators/ddp_accelerator.py | 15 +++ .../accelerators/ddp_cpu_spawn_accelerator.py | 15 +++ .../accelerators/ddp_hpc_accelerator.py | 16 +++- .../accelerators/ddp_spawn_accelerator.py | 15 +++ .../accelerators/horovod_accelerator.py | 14 +++ .../accelerators/tpu_accelerator.py | 15 +++ pytorch_lightning/utilities/distributed.py | 93 +++++++++++++++++++ tests/core/test_lightning_all_gather.py | 25 +---- 9 files changed, 201 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index f43866881cabb..e5cdecd2442dd 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ 
b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -24,7 +24,7 @@ from pytorch_lightning.core.step_result import Result from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType -from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available +from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available if HYDRA_AVAILABLE: from hydra.core.hydra_config import HydraConfig @@ -217,5 +217,19 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 687b5c21874fb..24f662963ee3f 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -28,6 +28,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType +from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything @@ -315,5 +316,19 @@ def sync_tensor(self, """ return sync_ddp_if_available(tensor, group, reduce_op) + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) 
+ """ + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 982da2f53216b..9cd68cd827334 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -30,6 +30,7 @@ rank_zero_only, rank_zero_warn, sync_ddp_if_available, + all_gather_ddp_if_available, ) if HYDRA_AVAILABLE: @@ -249,5 +250,19 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index 28817c6845f5b..d13a758cbb838 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -24,7 +24,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType -from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available +from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available if HYDRA_AVAILABLE: from hydra.core.hydra_config import HydraConfig @@ -211,5 +211,19 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) 
+ """ + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index a06d0b82d6d15..18e90b473e01b 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -33,6 +33,7 @@ rank_zero_only, rank_zero_warn, sync_ddp_if_available, + all_gather_ddp_if_available, ) from pytorch_lightning.utilities.seed import seed_everything @@ -276,5 +277,19 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return sync_ddp_if_available(tensor, group, reduce_op) + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index 460f5a83d2582..8353cf89bd278 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -206,3 +206,17 @@ def sync_tensor(self, # sync all processes before reduction hvd.join() return hvd.allreduce(tensor, op=reduce_op) + + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + return all_gather_horovod_if_available(tensor, group=group, sync_grads=sync_grads) diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index cd6b99fa64eef..fab1c8018abd8 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -33,6 +33,7 @@ ) from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.distributed import all_gather_tpu_if_available if TPU_AVAILABLE: import torch_xla @@ -353,6 +354,20 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return tensor + def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) 
+ """ + return all_gather_tpu_if_available(tensor, group=group, sync_grads=sync_grads) + @property def norm_clipping_epsilon(self): return 1e-6 diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index ffa1be87cd3ca..e2cae5dcd20ae 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -155,3 +155,96 @@ def sync_ddp( result = result / torch.distributed.get_world_size(group) return result + + +class AllGatherGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor): + ctx.batch_size = tensor.shape[0] + + gathered_tensor = [ + torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + + torch.distributed.all_gather(gathered_tensor, tensor) + gathered_tensor = torch.stack(gathered_tensor, dim=0) + + return gathered_tensor + + @staticmethod + def backward(ctx, *grad_output): + #grad_input = grad_output.clone() + print(grad_output.shape) + exit(-1) + torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False) + + return grad_input[ + torch.distributed.get_rank() * ctx.batch_size:(torch.distributed.get_rank() + 1) * ctx.batch_size + ] + + +def all_gather_ddp_if_available( + tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False +) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if sync_grads: + return AllGatherGrad.apply(tensor, group) + else: + gathered_tensor = [ + torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + + torch.distributed.all_gather(gathered_tensor, tensor) + gathered_tensor = torch.cat(gathered_tensor, 0) + + return gathered_tensor + return tensor + + +def all_gather_tpu_if_available( + tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False +) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return sync_ddp(result, group=group, reduce_op=reduce_op) + return result + + +def all_gather_horovod_if_available( + tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False +) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) 
+ """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return sync_ddp(result, group=group, reduce_op=reduce_op) + return result \ No newline at end of file diff --git a/tests/core/test_lightning_all_gather.py b/tests/core/test_lightning_all_gather.py index 53d7297bac9ce..8cc8484feed9e 100644 --- a/tests/core/test_lightning_all_gather.py +++ b/tests/core/test_lightning_all_gather.py @@ -19,6 +19,7 @@ def __init__(self): self.layer4 = torch.nn.Linear(32, 2) def forward(self, x): + # no grad cases tensor1 = self.layer1(x) tensor2 = self.layer2(x) @@ -28,7 +29,7 @@ def forward(self, x): assert torch.sum(tensor1_gathered[self.global_rank] - tensor1) == 0 assert torch.sum(tensor2_gathered[self.global_rank] - tensor2) == 0 - # with grads + # with grad cases tensor3 = self.layer3(x) tensor4 = self.layer4(x) @@ -43,27 +44,9 @@ def forward(self, x): return self.layer(x) -# test for ddp backends -def setup_ddp(rank, world_size): - """ Setup ddp enviroment """ - os.environ["MASTER_ADDR"] = 'localhost' - os.environ['MASTER_PORT'] = '8088' - - if torch.distributed.is_available() and sys.platform not in ('win32', 'cygwin'): - torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) - - -def _test_all_gather_ddp(rank, world_size): - setup_ddp(rank, world_size) - -# test horovod - - -# test tpu - - +# TODO: horovod and TPU @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") -@pytest.mark.parametrize("accelerator", ['ddp', 'horovod', 'tpu']) +@pytest.mark.parametrize("accelerator", ['ddp', 'ddp_cpu', 'ddp_spawn']) def test_all_gather(accelerator): gpus = 2 From 6fc03bfb5c0e86344bc2c7848401d2634a1920f6 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 7 Dec 2020 14:36:21 -0500 Subject: [PATCH 13/22] horovod --- pytorch_lightning/accelerators/horovod_accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index 8353cf89bd278..b440eb915c61a 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -19,7 +19,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.utilities import AMPType, HOROVOD_AVAILABLE -from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import rank_zero_only, all_gather_horovod_if_available if HOROVOD_AVAILABLE: import horovod.torch as hvd From 309a7e0d8242618ea9ec02fdaf9e55d4a4135bb7 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 7 Dec 2020 15:15:26 -0500 Subject: [PATCH 14/22] grad tests --- pytorch_lightning/utilities/__init__.py | 1 + pytorch_lightning/utilities/distributed.py | 12 +++- tests/core/test_lightning_all_gather.py | 68 ---------------------- tests/utilities/test_all_gather_grad.py | 33 +++++++++++ 4 files changed, 43 insertions(+), 71 deletions(-) delete mode 100644 tests/core/test_lightning_all_gather.py create mode 100644 tests/utilities/test_all_gather_grad.py diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 1e2eeea9f456c..c28139b6f4c84 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -21,6 +21,7 @@ from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, 
rank_zero_warn +from pytorch_lightning.utilities.distributed import AllGatherGrad from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index e2cae5dcd20ae..d87863daf5120 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -159,14 +159,15 @@ def sync_ddp( class AllGatherGrad(torch.autograd.Function): @staticmethod - def forward(ctx, tensor): + def forward(ctx, tensor, group=None): ctx.batch_size = tensor.shape[0] + ctx.group = group gathered_tensor = [ torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size()) ] - torch.distributed.all_gather(gathered_tensor, tensor) + torch.distributed.all_gather(gathered_tensor, tensor, group=group) gathered_tensor = torch.stack(gathered_tensor, dim=0) return gathered_tensor @@ -176,7 +177,12 @@ def backward(ctx, *grad_output): #grad_input = grad_output.clone() print(grad_output.shape) exit(-1) - torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False) + torch.distributed.all_reduce( + grad_input, + op=torch.distributed.ReduceOp.SUM, + async_op=False, + group=ctx.group + ) return grad_input[ torch.distributed.get_rank() * ctx.batch_size:(torch.distributed.get_rank() + 1) * ctx.batch_size diff --git a/tests/core/test_lightning_all_gather.py b/tests/core/test_lightning_all_gather.py deleted file mode 100644 index 8cc8484feed9e..0000000000000 --- a/tests/core/test_lightning_all_gather.py +++ /dev/null @@ -1,68 +0,0 @@ -import os -import pytest -import torch -import torch.nn as nn - -from tests.base.boring_model import BoringModel -from tests.base.develop_utils import set_random_master_port -from pytorch_lightning import Trainer, seed_everything - - -class AllGatherModel(BoringModel): - def __init__(self): - super().__init__() - self.layer = torch.nn.Linear(32, 2) - - self.layer1 = torch.nn.Linear(32, 2) - self.layer2 = torch.nn.Linear(32, 2) - self.layer3 = torch.nn.Linear(32, 2) - self.layer4 = torch.nn.Linear(32, 2) - - def forward(self, x): - # no grad cases - tensor1 = self.layer1(x) - tensor2 = self.layer2(x) - - tensor1_gathered = self.all_gather(tensor1) - tensor2_gathered = self.all_gather(tensor2) - - assert torch.sum(tensor1_gathered[self.global_rank] - tensor1) == 0 - assert torch.sum(tensor2_gathered[self.global_rank] - tensor2) == 0 - - # with grad cases - tensor3 = self.layer3(x) - tensor4 = self.layer4(x) - - tensor3_gathered = self.all_gather(tensor3) - tensor4_gathered = self.all_gather(tensor4) - - assert torch.sum(tensor3_gathered[self.global_rank] - tensor3) == 0 - assert torch.sum(tensor4_gathered[self.global_rank] - tensor4) == 0 - - # test for grads - - return self.layer(x) - - -# TODO: horovod and TPU -@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") -@pytest.mark.parametrize("accelerator", ['ddp', 'ddp_cpu', 'ddp_spawn']) -def test_all_gather(accelerator): - gpus = 2 - - seed_everything(234) - set_random_master_port() - - model = AllGatherModel() - train_dataloader = model.train_dataloader() - - trainer = Trainer( - gpus=gpus, - accelerator=accelerator, - max_epochs=1, - max_steps=3, - num_sanity_val_steps=0, - ) - - result = trainer.fit(model, train_dataloader) - assert result == 1, "All gather op fails in Lightning Module" diff --git 
a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py new file mode 100644 index 0000000000000..e9c409506c645 --- /dev/null +++ b/tests/utilities/test_all_gather_grad.py @@ -0,0 +1,33 @@ +import os +import pytest +import torch +import torch.nn as nn + +from pytorch_lightning.utilities import AllGatherGrad + + +def setup_ddp(rank, world_size): + """ Setup ddp enviroment """ + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8088" + + if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): + torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) + + +def _test_all_gather(rank, world_size): + setup_ddp(rank, world_size) + + tensor1 = torch.randn(8) + tensor2 = torch.randn(8, 16, 32) + + tensor1_gathered = AllGatherGrad.apply(tensor1) + tensor2_gathered = AllGatherGrad.apply(tensor2) + + assert torch.sum(tensor1_gathered[rank] - tensor1) == 0 + assert torch.sum(tensor2_gathered[rank] - tensor2) == 0 + + +def test_all_gather(): + world_size = 3 + torch.multiprocessing.spawn(_test_all_gather, args=(world_size,), nprocs=world_size) From 519711da5077ed433ca85a2a5d64e06f37bc7cea Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 7 Dec 2020 17:20:25 -0500 Subject: [PATCH 15/22] fixed ddp --- pytorch_lightning/utilities/distributed.py | 17 +++++++---------- tests/utilities/test_all_gather_grad.py | 19 +++++++++++++++---- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index d87863daf5120..1d25549822667 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -159,7 +159,7 @@ def sync_ddp( class AllGatherGrad(torch.autograd.Function): @staticmethod - def forward(ctx, tensor, group=None): + def forward(ctx, tensor, group=torch.distributed.group.WORLD): ctx.batch_size = tensor.shape[0] ctx.group = group @@ -174,19 +174,16 @@ def forward(ctx, tensor, group=None): @staticmethod def backward(ctx, *grad_output): - #grad_input = grad_output.clone() - print(grad_output.shape) - exit(-1) + grad_output = torch.cat(grad_output) + torch.distributed.all_reduce( - grad_input, + grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group ) - return grad_input[ - torch.distributed.get_rank() * ctx.batch_size:(torch.distributed.get_rank() + 1) * ctx.batch_size - ] + return grad_output[torch.distributed.get_rank()] def all_gather_ddp_if_available( @@ -212,7 +209,7 @@ def all_gather_ddp_if_available( ] torch.distributed.all_gather(gathered_tensor, tensor) - gathered_tensor = torch.cat(gathered_tensor, 0) + gathered_tensor = torch.stack(gathered_tensor, dim=0) return gathered_tensor return tensor @@ -253,4 +250,4 @@ def all_gather_horovod_if_available( """ if torch.distributed.is_available() and torch.distributed.is_initialized(): return sync_ddp(result, group=group, reduce_op=reduce_op) - return result \ No newline at end of file + return result diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index e9c409506c645..762f635dbd969 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -1,5 +1,6 @@ import os import pytest +import sys import torch import torch.nn as nn @@ -18,16 +19,26 @@ def setup_ddp(rank, world_size): def _test_all_gather(rank, world_size): setup_ddp(rank, world_size) - tensor1 = torch.randn(8) - tensor2 = torch.randn(8, 16, 32) + tensor1 
= torch.ones(8, requires_grad=True) + tensor2 = torch.ones((8, 16, 32), requires_grad=True) tensor1_gathered = AllGatherGrad.apply(tensor1) tensor2_gathered = AllGatherGrad.apply(tensor2) - assert torch.sum(tensor1_gathered[rank] - tensor1) == 0 - assert torch.sum(tensor2_gathered[rank] - tensor2) == 0 + tensor1_gathered = tensor1_gathered * rank + tensor2_gathered = tensor2_gathered * rank + tensor1_gathered.sum().backward() + tensor2_gathered.sum().backward() + grad1 = torch.zeros_like(tensor1.grad).fill_(torch.arange(world_size).sum().float()) + grad2 = torch.zeros_like(tensor2.grad).fill_(torch.arange(world_size).sum().float()) + + assert torch.allclose(grad1, tensor1.grad) + assert torch.allclose(grad2, tensor2.grad) + + +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") def test_all_gather(): world_size = 3 torch.multiprocessing.spawn(_test_all_gather, args=(world_size,), nprocs=world_size) From 2ab162dacac67175028f7d185b4bace24ba34749 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 7 Dec 2020 23:53:41 -0500 Subject: [PATCH 16/22] ddp fixed, removed tpu, horovod for now --- .../accelerators/horovod_accelerator.py | 16 +------- .../accelerators/tpu_accelerator.py | 15 -------- pytorch_lightning/utilities/distributed.py | 38 ------------------- tests/utilities/test_all_gather_grad.py | 6 +-- 4 files changed, 4 insertions(+), 71 deletions(-) diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index b440eb915c61a..460f5a83d2582 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -19,7 +19,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.utilities import AMPType, HOROVOD_AVAILABLE -from pytorch_lightning.utilities.distributed import rank_zero_only, all_gather_horovod_if_available +from pytorch_lightning.utilities.distributed import rank_zero_only if HOROVOD_AVAILABLE: import horovod.torch as hvd @@ -206,17 +206,3 @@ def sync_tensor(self, # sync all processes before reduction hvd.join() return hvd.allreduce(tensor, op=reduce_op) - - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) 
- """ - return all_gather_horovod_if_available(tensor, group=group, sync_grads=sync_grads) diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index fab1c8018abd8..cd6b99fa64eef 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -33,7 +33,6 @@ ) from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.distributed import all_gather_tpu_if_available if TPU_AVAILABLE: import torch_xla @@ -354,20 +353,6 @@ def sync_tensor(self, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return tensor - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) - """ - return all_gather_tpu_if_available(tensor, group=group, sync_grads=sync_grads) - @property def norm_clipping_epsilon(self): return 1e-6 diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 1d25549822667..68a79e6676e84 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -213,41 +213,3 @@ def all_gather_ddp_if_available( return gathered_tensor return tensor - - -def all_gather_tpu_if_available( - tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False -) -> torch.Tensor: - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) - """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): - return sync_ddp(result, group=group, reduce_op=reduce_op) - return result - - -def all_gather_horovod_if_available( - tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False -) -> torch.Tensor: - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) 
- """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): - return sync_ddp(result, group=group, reduce_op=reduce_op) - return result diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 762f635dbd969..66e50776edd3f 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -16,7 +16,7 @@ def setup_ddp(rank, world_size): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) -def _test_all_gather(rank, world_size): +def _test_all_gather_ddp(rank, world_size): setup_ddp(rank, world_size) tensor1 = torch.ones(8, requires_grad=True) @@ -39,6 +39,6 @@ def _test_all_gather(rank, world_size): @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") -def test_all_gather(): +def test_all_gather_ddp(): world_size = 3 - torch.multiprocessing.spawn(_test_all_gather, args=(world_size,), nprocs=world_size) + torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size) From 23a477970ccadc3ea88b2c1e677de87bcfcfaa2c Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 8 Dec 2020 00:02:09 -0500 Subject: [PATCH 17/22] changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ba46ebdc8520..c412f3f859e56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## Unreleased +### Added + +- Added `all_gather` method to `LightningModule` which allows gradient based tensor synchronizations for use-cases such as negative sampling. ([#5012](https://github.com/PyTorchLightning/pytorch-lightning/pull/5012)) + ### Fixed - Fixed `LoggerConnector` to have logged metrics on root device in DP ([#4138](https://github.com/PyTorchLightning/pytorch-lightning/pull/4138)) From 3ca405430a5cd84b13c30b4a96f1653512c7cc79 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 8 Dec 2020 01:26:02 -0500 Subject: [PATCH 18/22] windows fix --- pytorch_lightning/utilities/distributed.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 68a79e6676e84..345a44922be53 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -22,11 +22,16 @@ if torch.distributed.is_available(): from torch.distributed import ReduceOp + from torch.distributed import group else: class ReduceOp: SUM = None + class group: + WORLD = None + + def rank_zero_only(fn): @wraps(fn) @@ -159,7 +164,7 @@ def sync_ddp( class AllGatherGrad(torch.autograd.Function): @staticmethod - def forward(ctx, tensor, group=torch.distributed.group.WORLD): + def forward(ctx, tensor, group=group.WORLD): ctx.batch_size = tensor.shape[0] ctx.group = group From f01d800e3da954cdd5c6772787a86edc7cdc79ed Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 8 Dec 2020 01:27:03 -0500 Subject: [PATCH 19/22] windows fix --- pytorch_lightning/utilities/distributed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 345a44922be53..60e6bddff29fd 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -27,7 +27,6 @@ class ReduceOp: SUM = None - class group: WORLD = None From 586bb5062db1b038adfcd69f190bd8ac5f324fbc Mon Sep 17 00:00:00 2001 From: ananyahjha93 
Date: Tue, 8 Dec 2020 04:05:59 -0500
Subject: [PATCH 20/22] removed batch from ctx

---
 pytorch_lightning/utilities/distributed.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 60e6bddff29fd..d0534d6764bd2 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -164,7 +164,6 @@ def sync_ddp(
 class AllGatherGrad(torch.autograd.Function):
     @staticmethod
     def forward(ctx, tensor, group=group.WORLD):
-        ctx.batch_size = tensor.shape[0]
         ctx.group = group

         gathered_tensor = [
             torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
         ]

From b9d182d60a0c60f72746c69f4f533a5d8e44ffe0 Mon Sep 17 00:00:00 2001
From: ananyahjha93
Date: Tue, 8 Dec 2020 11:06:59 -0500
Subject: [PATCH 21/22] removed code duplication

---
 pytorch_lightning/utilities/distributed.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index d0534d6764bd2..9724f05247c00 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -207,12 +207,6 @@ def all_gather_ddp_if_available(
         if sync_grads:
             return AllGatherGrad.apply(tensor, group)
         else:
-            gathered_tensor = [
-                torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
-            ]
-
-            torch.distributed.all_gather(gathered_tensor, tensor)
-            gathered_tensor = torch.stack(gathered_tensor, dim=0)
-
-            return gathered_tensor
+            with torch.no_grad():
+                return AllGatherGrad.apply(tensor, group)
     return tensor

From f2161eeae30023edac91d5729a71084207362590 Mon Sep 17 00:00:00 2001
From: ananyahjha93
Date: Tue, 8 Dec 2020 11:10:08 -0500
Subject: [PATCH 22/22] merge

---
 pytorch_lightning/utilities/distributed.py | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 36073fcba1f46..9724f05247c00 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -207,17 +207,6 @@ def all_gather_ddp_if_available(
         if sync_grads:
             return AllGatherGrad.apply(tensor, group)
         else:
-<<<<<<< HEAD
             with torch.no_grad():
                 return AllGatherGrad.apply(tensor, group)
-=======
-            gathered_tensor = [
-                torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
-            ]
-
-            torch.distributed.all_gather(gathered_tensor, tensor)
-            gathered_tensor = torch.stack(gathered_tensor, dim=0)
-
-            return gathered_tensor
->>>>>>> b3a50c3df0c04143fa35052946a70bd17984a77d
     return tensor
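
Usage illustration (not part of the patches above): a minimal sketch of a LightningModule that gathers embeddings across DDP ranks for in-batch negative sampling, the use case named in the changelog entry of this series. The encoder, toy dataset, loss, and launch arguments below are assumptions made for the example; the only interfaces taken from the patches are self.all_gather(tensor, group=None, sync_grads=False), its (world_size, batch, ...) return shape, and the sync_grads gradient path.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer, seed_everything


class NegativeSamplingModel(LightningModule):
    """Hypothetical example model: each rank scores its own samples against
    the embeddings gathered from all ranks."""

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(32, 16)

    def training_step(self, batch, batch_idx):
        x, _ = batch
        local_emb = self.encoder(x)  # (batch, 16)

        # sync_grads=True keeps the gather differentiable: backward all-reduces
        # the incoming gradient and returns this rank's slice (see AllGatherGrad)
        gathered = self.all_gather(local_emb, sync_grads=True)  # (world_size, batch, 16)
        all_emb = gathered.reshape(-1, gathered.size(-1))        # (world_size * batch, 16)

        # each local sample's positive is its own gathered copy; every other
        # gathered embedding acts as a negative
        logits = local_emb @ all_emb.t()
        targets = torch.arange(x.size(0), device=logits.device) + self.global_rank * x.size(0)
        return F.cross_entropy(logits, targets)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32), torch.zeros(64)), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    seed_everything(234)
    # assumed launch: 2 GPUs with the ddp accelerator, mirroring the tests in this series
    trainer = Trainer(gpus=2, accelerator="ddp", max_epochs=1, max_steps=3)
    trainer.fit(NegativeSamplingModel())

With the default sync_grads=False the gathered tensor is produced under torch.no_grad() and carries no autograd history, matching the non-grad branch of all_gather_ddp_if_available; only sync_grads=True routes gradients back to the per-rank inputs.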