All gather with grads #5012

Merged: 28 commits, Dec 8, 2020
Changes from 11 commits
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## Unreleased

### Added

- Added `all_gather` method to `LightningModule`, which allows gradient-based tensor synchronization for use cases such as negative sampling. ([#5012](https://github.com/PyTorchLightning/pytorch-lightning/pull/5012))

### Fixed

- Fixed `LoggerConnector` to have logged metrics on root device in DP ([#4138](https://github.com/PyTorchLightning/pytorch-lightning/pull/4138))
14 changes: 14 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -175,6 +175,20 @@ def sync_tensor(self,
"""
raise NotImplementedError()

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Function to gather a tensor from several distributed processes

Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op

Return:
A tensor of shape (world_size, batch, ...)
"""
raise NotImplementedError()

def optimizer_state(self, optimizer: Optimizer) -> dict:
"""
Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
16 changes: 15 additions & 1 deletion pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -24,7 +24,7 @@
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available

if HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
@@ -217,5 +217,19 @@ def sync_tensor(self,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Function to gather a tensor from several distributed processes

Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op

Return:
A tensor of shape (world_size, batch, ...)
"""
return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)
15 changes: 15 additions & 0 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -28,6 +28,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
@@ -315,5 +316,19 @@ def sync_tensor(self,
"""
return sync_ddp_if_available(tensor, group, reduce_op)

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Function to gather a tensor from several distributed processes

Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op

Return:
A tensor of shape (world_size, batch, ...)
"""
return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)
15 changes: 15 additions & 0 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -30,6 +30,7 @@
rank_zero_only,
rank_zero_warn,
sync_ddp_if_available,
all_gather_ddp_if_available,
)

if HYDRA_AVAILABLE:
@@ -249,5 +250,19 @@ def sync_tensor(self,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Function to gather a tensor from several distributed processes

Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op

Return:
A tensor of shape (world_size, batch, ...)
"""
return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)
16 changes: 15 additions & 1 deletion pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -24,7 +24,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available

if HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
@@ -211,5 +211,19 @@ def sync_tensor(self,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Function to gather a tensor from several distributed processes

Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op

Return:
A tensor of shape (world_size, batch, ...)
"""
return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)
15 changes: 15 additions & 0 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -33,6 +33,7 @@
rank_zero_only,
rank_zero_warn,
sync_ddp_if_available,
all_gather_ddp_if_available,
)
from pytorch_lightning.utilities.seed import seed_everything

@@ -276,5 +277,19 @@ def sync_tensor(self,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Function to gather a tensor from several distributed processes

Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op

Return:
A tensor of shape (world_size, batch, ...)
"""
return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)
18 changes: 18 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -365,6 +365,24 @@ def __auto_choose_log_on_epoch(self, on_epoch):

return on_epoch

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
r"""
Allows users to call ``self.all_gather()`` from the LightningModule, thus making
the ``all_gather`` operation accelerator agnostic.

``all_gather`` is a function provided by accelerators to gather a tensor from several
distributed processes.

Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op

Return:
A tensor of shape (world_size, batch, ...)
"""
return self.trainer.accelerator_backend.all_gather(tensor, group=group, sync_grads=sync_grads)
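
As a quick illustration (not part of the diff), here is a minimal sketch of how ``self.all_gather`` might be used for a negative-sampling-style objective. The module, encoder, shapes, and loss below are made up for the example; only the ``self.all_gather(..., sync_grads=True)`` call is the API added by this PR, and the use of ``self.trainer.global_rank`` to index the global batch is an assumption about how you would label positives.

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class ContrastiveModel(pl.LightningModule):
    """Hypothetical module; only the self.all_gather call comes from this PR."""

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(32, 16)

    def training_step(self, batch, batch_idx):
        x, _ = batch
        emb = self.encoder(x)  # (batch, 16) on this rank
        # Gather embeddings from all processes; sync_grads=True keeps the result
        # in the autograd graph so the loss backpropagates on every rank.
        all_emb = self.all_gather(emb, sync_grads=True)  # (world_size, batch, 16)
        all_emb = all_emb.flatten(0, 1)                  # (world_size * batch, 16)
        # Treat every other sample in the global batch as a negative.
        logits = emb @ all_emb.t()
        offset = self.trainer.global_rank * emb.size(0)
        labels = torch.arange(emb.size(0), device=emb.device) + offset
        return F.cross_entropy(logits, labels)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```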

def forward(self, *args, **kwargs):
r"""
Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -21,6 +21,7 @@

from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.distributed import AllGatherGrad
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils

61 changes: 61 additions & 0 deletions pytorch_lightning/utilities/distributed.py
@@ -22,10 +22,14 @@

if torch.distributed.is_available():
from torch.distributed import ReduceOp
from torch.distributed import group
else:
class ReduceOp:
SUM = None

class group:
WORLD = None


def rank_zero_only(fn):

@@ -155,3 +159,60 @@ def sync_ddp(
result = result / torch.distributed.get_world_size(group)

return result


class AllGatherGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, group=group.WORLD):
ctx.group = group

gathered_tensor = [
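# all_gather fills pre-allocated buffers: one zeros tensor per rank, same shape and dtype as the local tensor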
torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
]

torch.distributed.all_gather(gathered_tensor, tensor, group=group)
gathered_tensor = torch.stack(gathered_tensor, dim=0)

return gathered_tensor

@staticmethod
def backward(ctx, *grad_output):
grad_output = torch.cat(grad_output)

torch.distributed.all_reduce(
    grad_output,
    op=torch.distributed.ReduceOp.SUM,
    async_op=False,
    group=ctx.group
)

return grad_output[torch.distributed.get_rank()]

Review comment: Can you just use reduce_scatter? Using a full all_reduce is unnecessary.
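
For reference, a rough sketch of what the suggested ``reduce_scatter`` variant of ``backward`` might look like; this is an assumption, not part of the diff at this commit, and it presumes every rank contributes a tensor of the same shape.

```python
@staticmethod
def backward(ctx, grad_output):
    # grad_output has shape (world_size, batch, ...): one gradient slice per rank.
    grad_output = grad_output.contiguous()
    grad_input = torch.zeros_like(grad_output[0])
    # reduce_scatter sums the per-rank slices and leaves each rank holding only
    # its own slice, avoiding the full all_reduce of the gathered gradient.
    torch.distributed.reduce_scatter(
        grad_input,
        list(grad_output.unbind(dim=0)),
        op=torch.distributed.ReduceOp.SUM,
        group=ctx.group,
    )
    return grad_input, None  # None: no gradient for the `group` argument
```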


def all_gather_ddp_if_available(
tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False
) -> torch.Tensor:
"""
Function to gather a tensor from several distributed processes

Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op

Return:
A tensor of shape (world_size, batch, ...)
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
if sync_grads:
return AllGatherGrad.apply(tensor, group)
else:
gathered_tensor = [
torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
]

torch.distributed.all_gather(gathered_tensor, tensor)
gathered_tensor = torch.stack(gathered_tensor, dim=0)

return gathered_tensor
return tensor
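
A quick illustration of the non-distributed fallback at the end of this helper, assuming a plain single-process interpreter with no process group initialized: the tensor is returned unchanged, so code written against ``all_gather`` still runs on a single device.

```python
import torch

from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available

t = torch.randn(4, 8, requires_grad=True)
out = all_gather_ddp_if_available(t, sync_grads=True)

# With no initialized process group the helper falls through to `return tensor`,
# so the same tensor comes back and no world_size dimension is added.
assert out is t
```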
44 changes: 44 additions & 0 deletions tests/utilities/test_all_gather_grad.py
@@ -0,0 +1,44 @@
import os
import pytest
import sys
import torch
import torch.nn as nn

from pytorch_lightning.utilities import AllGatherGrad


def setup_ddp(rank, world_size):
""" Setup ddp enviroment """
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "8088"

if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)


def _test_all_gather_ddp(rank, world_size):
setup_ddp(rank, world_size)

tensor1 = torch.ones(8, requires_grad=True)
tensor2 = torch.ones((8, 16, 32), requires_grad=True)

tensor1_gathered = AllGatherGrad.apply(tensor1)
tensor2_gathered = AllGatherGrad.apply(tensor2)

tensor1_gathered = tensor1_gathered * rank
tensor2_gathered = tensor2_gathered * rank

tensor1_gathered.sum().backward()
tensor2_gathered.sum().backward()
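# Each rank scales the gathered tensor by its rank and AllGatherGrad.backward
# all-reduces the gradients, so every input's grad is filled with 0 + 1 + ... + (world_size - 1).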

grad1 = torch.zeros_like(tensor1.grad).fill_(torch.arange(world_size).sum().float())
grad2 = torch.zeros_like(tensor2.grad).fill_(torch.arange(world_size).sum().float())

assert torch.allclose(grad1, tensor1.grad)
assert torch.allclose(grad2, tensor2.grad)


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_all_gather_ddp():
world_size = 3
torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)