[bug] Update broadcast + reduce decision [ModelCheckpoint] #6410

Merged
merged 70 commits into from
Mar 14, 2021
597ae27
resolve bug
tchaton Mar 4, 2021
ef11927
update
tchaton Mar 4, 2021
85b327d
update changelog
tchaton Mar 4, 2021
47f0b2c
update PR
tchaton Mar 4, 2021
bbe4255
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 4, 2021
1c33b48
Update pytorch_lightning/trainer/connectors/logger_connector/epoch_re…
tchaton Mar 4, 2021
6cd4713
add todo
tchaton Mar 4, 2021
45d7239
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 4, 2021
b58d7fb
resolve issues
tchaton Mar 4, 2021
e3a084a
resolve flake8
tchaton Mar 4, 2021
77edbed
update
tchaton Mar 4, 2021
6bcc88d
add coverage for reduce
tchaton Mar 4, 2021
c63bca5
wip
tchaton Mar 4, 2021
e26d301
restore back to brodbact
tchaton Mar 4, 2021
ce239fd
remove test.py
tchaton Mar 4, 2021
d8f1dc9
resolve flake8
tchaton Mar 4, 2021
237bbd2
update
tchaton Mar 4, 2021
f546ae4
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 4, 2021
6fbe70d
check world size
tchaton Mar 4, 2021
5f25fc5
resolve test
tchaton Mar 4, 2021
46cf2c6
update
tchaton Mar 4, 2021
7029b31
Merge branch 'master' into bugfix/5604_ddp_model_checkpoint
tchaton Mar 5, 2021
8523167
use pytorch version when defined
tchaton Mar 5, 2021
f28f950
update on comments
tchaton Mar 5, 2021
6eae79d
update on comments
tchaton Mar 5, 2021
1cd9431
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 5, 2021
9448964
flake8
tchaton Mar 5, 2021
1b5c90a
resolve bugs
tchaton Mar 5, 2021
a1264d9
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 5, 2021
9f3eb41
Merge branch 'master' into bugfix/5604_ddp_model_checkpoint
tchaton Mar 5, 2021
e88ef07
Merge branch 'master' into bugfix/5604_ddp_model_checkpoint
tchaton Mar 5, 2021
c21f148
Update CHANGELOG.md
tchaton Mar 5, 2021
94e9aa9
update
tchaton Mar 5, 2021
4626310
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 5, 2021
b260bf6
update
tchaton Mar 5, 2021
dd60ed1
update
tchaton Mar 5, 2021
45b65f1
update
tchaton Mar 6, 2021
dcd6884
remove test
tchaton Mar 6, 2021
2e046e8
update
tchaton Mar 6, 2021
68ffb5b
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 6, 2021
23b2c10
resolve flake8
tchaton Mar 6, 2021
b4c663b
update
tchaton Mar 6, 2021
aa89d5d
Merge branch 'master' into bugfix/5604_ddp_model_checkpoint
tchaton Mar 6, 2021
73e83f7
update
tchaton Mar 6, 2021
c060444
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton Mar 6, 2021
2eb6db4
update
tchaton Mar 6, 2021
060992b
proxy
tchaton Mar 6, 2021
5bad135
update
tchaton Mar 6, 2021
4579842
update
tchaton Mar 6, 2021
5276cd0
Merge branch 'master' into bugfix/broadcast_2
tchaton Mar 8, 2021
8027838
resolve typo
tchaton Mar 8, 2021
aa9a6ca
prune
tchaton Mar 8, 2021
4b6a6c5
update parallel
tchaton Mar 8, 2021
4b55c52
update
tchaton Mar 8, 2021
cbacf48
update changelog
tchaton Mar 8, 2021
057fbf3
update
tchaton Mar 9, 2021
7f515ea
Merge branch 'master' into bugfix/broadcast_2
tchaton Mar 9, 2021
7cbf38b
try running pipe
tchaton Mar 9, 2021
928cf2c
Merge branch 'master' into bugfix/broadcast_2
carmocca Mar 10, 2021
690b61f
Update pytorch_lightning/callbacks/model_checkpoint.py
tchaton Mar 10, 2021
300a632
update on comments
tchaton Mar 10, 2021
5e30377
Update pytorch_lightning/callbacks/model_checkpoint.py
tchaton Mar 11, 2021
015fbac
update on comennts
tchaton Mar 12, 2021
c213716
Merge branch 'bugfix/broadcast_2' of https://github.com/PyTorchLightn…
tchaton Mar 12, 2021
f668c3a
update
tchaton Mar 12, 2021
30feb40
update
tchaton Mar 12, 2021
a4bf623
update
tchaton Mar 12, 2021
b482589
fix
tchaton Mar 12, 2021
1ad9c62
remove comments
tchaton Mar 12, 2021
b64e105
resolve bugs
tchaton Mar 14, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -116,6 +116,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688))


- Fixed broadcast to use PyTorch `broadcast_object_list` and add `reduce_decision` ([#6410](https://github.com/PyTorchLightning/pytorch-lightning/pull/6410))


- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380))


3 changes: 1 addition & 2 deletions pytorch_lightning/accelerators/accelerator.py
@@ -22,7 +22,6 @@
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.enums import AMPType, LightningEnum

if TYPE_CHECKING:
@@ -405,7 +404,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
Return:
A tensor of shape (world_size, batch, ...)
"""
return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads)

def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Wraps the dataloader if necessary
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
@@ -172,4 +172,4 @@ def _run_early_stopping_check(self, trainer):
trainer.should_stop = True

# stop every ddp process if any world process decides to stop
trainer.should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(trainer.should_stop)
trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop)
26 changes: 10 additions & 16 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -326,7 +326,7 @@ def _do_save(self, trainer, filepath: str):
else:
raise ValueError(".save_function() not set")

def check_monitor_top_k(self, current: torch.Tensor) -> bool:
def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) -> bool:
if current is None:
return False

@@ -346,7 +346,12 @@ def check_monitor_top_k(self, current: torch.Tensor) -> bool:
current = torch.tensor(current)

monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
return monitor_op(current, self.best_k_models[self.kth_best_model_path]).item()
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])

# If using multiple devices, make sure all processes are unanimous on the decision.
should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save)

return should_update_best_and_save

@classmethod
def _format_checkpoint_name(
@@ -528,15 +533,7 @@ def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
epoch = monitor_candidates.get("epoch")
step = monitor_candidates.get("step")

# when `val_loss` is being logged and no ModelCheckpoint is being provided
# `val_loss` will be selected for monitor and need to be reduced to
# prevent processes divergence
# TODO: Move this logic to logger_connector. This also needs to be fixed for any
# other monitor logged value which aren't produced from a Metric.
if self.monitor == "val_loss":
current = trainer.training_type_plugin.reduce(current, reduce_op="mean")

if self.check_monitor_top_k(current):
if self.check_monitor_top_k(trainer, current):
self._update_best_and_save(current, epoch, step, trainer, monitor_candidates)
elif self.verbose:
rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")
@@ -554,9 +551,7 @@ def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, A
self._save_model(trainer, filepath)

if (
self.save_top_k is None
and self.best_model_path
and self.best_model_path != filepath
self.save_top_k is None and self.best_model_path and self.best_model_path != filepath
and trainer.is_global_zero
):
self._del_model(self.best_model_path)
@@ -623,5 +618,4 @@ def file_exists(self, filepath: Union[str, Path], trainer) -> bool:
the internal state to diverge between ranks.
"""
exists = self._fs.exists(filepath)
exists = trainer.training_type_plugin.broadcast(exists)
return exists
return trainer.training_type_plugin.broadcast(exists)
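The `check_monitor_top_k` change above exists because each process compares its own value of the monitored metric against the k-th best score, so ranks can disagree. Below is a minimal sketch of that divergence in plain Python. The helper name `local_decision` and the sample numbers are hypothetical, and `operator.lt`/`operator.gt` stand in for `torch.lt`/`torch.gt`.

```python
import operator


def local_decision(current: float, kth_best: float, mode: str) -> bool:
    """Hypothetical stand-in for the per-rank comparison in check_monitor_top_k."""
    monitor_op = {"min": operator.lt, "max": operator.gt}[mode]
    return monitor_op(current, kth_best)


# Assumed example: two ranks log a different, unreduced val_loss.
per_rank_val_loss = [0.40, 0.55]
kth_best = 0.50

decisions = [local_decision(v, kth_best, "min") for v in per_rank_val_loss]

# The ranks disagree: one wants to save, the other does not.
ranks_disagree = any(decisions) and not all(decisions)
```

Without a reduction step, one rank would enter `_update_best_and_save` while the other would not, which is the kind of state divergence the added `reduce_boolean_decision` call is meant to prevent.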
51 changes: 12 additions & 39 deletions pytorch_lightning/distributed/dist.py
@@ -11,18 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from typing import Any

import torch
from torch import distributed as torch_distrib

from pytorch_lightning.utilities import _GROUP_AVAILABLE

WORLD = None
if _GROUP_AVAILABLE:
from torch.distributed import group
WORLD = group.WORLD
from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
from pytorch_lightning.utilities.distributed import group as _group


class LightningDistributed:
@@ -31,32 +23,13 @@ def __init__(self, rank=None, device=None):
self.rank = rank
self.device = device

def broadcast(self, obj: Any, group=WORLD):
if self.rank == 0:
self._emit(obj, group)
else:
obj = self._receive(group)
return obj

def _broadcast(self, tensor, src=0, group=WORLD):
if group is None:
return torch_distrib.broadcast(tensor, src=src)
return torch_distrib.broadcast(tensor, src=0, group=group)

def _emit(self, obj: Any, group=WORLD):
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
length_tensor = torch.tensor([len(data)]).long().to(self.device)
self._broadcast(length_tensor, src=0, group=group)
data_tensor = torch.ByteTensor(data).to(self.device)
self._broadcast(data_tensor, src=0, group=group)

def _receive(self, group=WORLD):
length_tensor = torch.tensor([0]).long().to(self.device)
self._broadcast(length_tensor, src=0, group=group)
data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device)
self._broadcast(data_tensor, src=0, group=group)
buffer = io.BytesIO(data_tensor.cpu().numpy())
obj = torch.load(buffer)
return obj
def broadcast(self, obj: Any, group=_group.WORLD):
# always wrap into a list so the object can be broadcast.
obj = [obj]

if self.rank != 0:
obj = [None] * len(obj)

broadcast_object_list(obj, 0, group=group or _group.WORLD)

return obj[0]
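The new `LightningDistributed.broadcast` relies on `broadcast_object_list` filling the passed list in place on non-source ranks. The sketch below imitates that calling pattern without a process group. `fake_broadcast_object_list` and its `src_objects` argument are hypothetical stand-ins, not part of PyTorch or Lightning.

```python
import copy
from typing import Any, List


def fake_broadcast_object_list(object_list: List[Any], src_objects: List[Any]) -> None:
    """Hypothetical stand-in: overwrite object_list in place with rank 0's objects."""
    for i, obj in enumerate(src_objects):
        object_list[i] = copy.deepcopy(obj)


def broadcast(obj: Any, rank: int, src_objects: List[Any]) -> Any:
    # Wrap in a list so the in-place API has a slot to write into.
    obj = [obj]
    # Non-source ranks send nothing; they only provide a placeholder slot.
    if rank != 0:
        obj = [None] * len(obj)
    fake_broadcast_object_list(obj, src_objects)
    return obj[0]


payload = {"path": "epoch=3.ckpt", "exists": True}
on_rank0 = broadcast(payload, rank=0, src_objects=[payload])
on_rank1 = broadcast("stale local value", rank=1, src_objects=[payload])
```

Whatever a non-zero rank passes in is discarded, so every rank returns the value held by rank 0.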
94 changes: 94 additions & 0 deletions pytorch_lightning/overrides/torch_distributed.py
@@ -0,0 +1,94 @@
import logging
import pickle

import torch

from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7

log = logging.getLogger(__name__)

if torch.distributed.is_available():
from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember

# The code underneath is taken from PyTorch ``torch/distributed/distributed_c10d.py``
# and enables broadcasting for PyTorch 1.6 and lower.


# https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160
def _rank_not_in_group(group):
"""
Helper that checks if the current process's rank is not in a given group.
"""
if group is None:
return False
return group == GroupMember.NON_GROUP_MEMBER


# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1164
def _object_to_tensor(obj):
buffer = pickle.dumps(obj)
byte_storage = torch.ByteStorage.from_buffer(buffer) # type: ignore[attr-defined]
byte_tensor = torch.ByteTensor(byte_storage)
local_size = torch.LongTensor([byte_tensor.numel()])
return byte_tensor, local_size


# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py
def _tensor_to_object(tensor, tensor_size):
buf = tensor.numpy().tobytes()[:tensor_size]
out = pickle.loads(buf)
return out


# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327
def _broadcast_object_list(object_list, src=0, group=None):
if _rank_not_in_group(group):
return

my_rank = get_rank()
# Serialize object_list elements to tensors on src rank.
if my_rank == src:
tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list])
object_sizes_tensor = torch.cat(size_list)
else:
object_sizes_tensor = torch.LongTensor(len(object_list))

group_backend = get_backend(group)
is_nccl_backend = group_backend == Backend.NCCL
current_device = torch.device("cpu")
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device('cuda', torch.cuda.current_device())
object_sizes_tensor = object_sizes_tensor.to(current_device)
object_sizes_tensor = object_sizes_tensor.to(current_device)

# Broadcast object sizes
broadcast(object_sizes_tensor, src=src, group=group)

# Concatenate and broadcast serialized object tensors
if my_rank == src:
object_tensor = torch.cat(tensor_list)
else:
object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item())

if is_nccl_backend:
object_tensor = object_tensor.to(current_device)

broadcast(object_tensor, src=src, group=group)

# Deserialize objects using their stored sizes.
offset = 0
if my_rank != src:
for i, obj_size in enumerate(object_sizes_tensor):
obj_view = object_tensor[offset:offset + obj_size]
obj_view = obj_view.type(torch.ByteTensor) # type: ignore[call-overload]
offset += obj_size
object_list[i] = _tensor_to_object(obj_view, obj_size)


if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available():
from torch.distributed.distributed_c10d import broadcast_object_list
else:
broadcast_object_list = _broadcast_object_list
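The backported `_broadcast_object_list` sends two messages: a tensor of per-object byte sizes, then one concatenated byte buffer that receivers slice by offset. The following sketch reproduces only that pack/unpack scheme with `pickle` and `bytes`, with no tensors or collectives involved. The function names `pack` and `unpack` are hypothetical.

```python
import pickle
from typing import Any, List, Tuple


def pack(object_list: List[Any]) -> Tuple[List[int], bytes]:
    """Source side: serialize each object and record its size in bytes."""
    blobs = [pickle.dumps(obj) for obj in object_list]
    sizes = [len(blob) for blob in blobs]
    return sizes, b"".join(blobs)


def unpack(sizes: List[int], payload: bytes) -> List[Any]:
    """Receiver side: slice the buffer by the received sizes and deserialize."""
    objects = []
    offset = 0
    for size in sizes:
        objects.append(pickle.loads(payload[offset:offset + size]))
        offset += size
    return objects


sizes, payload = pack([{"epoch": 3}, "best.ckpt", 0.25])
restored = unpack(sizes, payload)
```

Sending the sizes first lets receivers allocate a buffer of exactly `sum(sizes)` bytes before the second broadcast, which mirrors the `torch.sum(object_sizes_tensor)` allocation in the diff.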
1 change: 0 additions & 1 deletion pytorch_lightning/plugins/precision/apex_amp.py
@@ -169,5 +169,4 @@ def pre_optimizer_step(
pl_module.trainer.call_hook("on_after_backward")

optimizer.step(**kwargs)

return False
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -71,8 +71,8 @@ def barrier(self, *args, **kwargs):
def broadcast(self, obj: object, src: int = 0) -> object:
return obj

def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
return should_stop
def reduce_boolean_decision(self, decision: bool) -> bool:
return decision

def training_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
11 changes: 8 additions & 3 deletions pytorch_lightning/plugins/training_type/horovod.py
@@ -21,7 +21,7 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.distributed import group, rank_zero_only, ReduceOp

if _HOROVOD_AVAILABLE:
import horovod.torch as hvd
@@ -159,8 +159,13 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[
hvd.join()
return hvd.allreduce(tensor, op=reduce_op)

def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None):
if group is not None:
def all_gather(
self,
result: Union[torch.Tensor],
group: Optional[Any] = group.WORLD,
sync_grads: bool = False
) -> torch.Tensor:
if group is not None and group != group.WORLD:
raise ValueError(
"Horovod does not support allgather using a subcommunicator at this time. "
"Unset `group`."
30 changes: 12 additions & 18 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -11,11 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import List, Optional
from typing import Any, List, Optional

import torch
from torch.nn.parallel import DistributedDataParallel
@@ -36,9 +35,10 @@ def __init__(
):
super().__init__()
self.parallel_devices = parallel_devices
self.cluster_environment = cluster_environment
self.global_rank = 0
self.world_size = 1
self.local_rank = 0
self.cluster_environment = cluster_environment

@property
@abstractmethod
@@ -70,11 +70,15 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=len(self.parallel_devices), rank=self.global_rank)
return distributed_sampler_kwargs

def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
should_stop = torch.tensor(int(should_stop), device=self.lightning_module.device)
should_stop = self.reduce(should_stop, reduce_op=ReduceOp.SUM)
should_stop = bool(should_stop == self.world_size)
return should_stop
def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
"""Perform an all_gather on all processes"""
return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

def reduce_boolean_decision(self, decision: bool) -> bool:
decision = torch.tensor(int(decision), device=self.lightning_module.device)
decision = self.reduce(decision, reduce_op=ReduceOp.SUM)
decision = bool(decision == self.world_size)
return decision

@property
def torch_distributed_backend(self):
@@ -112,13 +116,3 @@ def block_backward_sync(self):
yield None
else:
yield None

def broadcast(self, obj: object, src: int) -> object:
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
data_tensor = torch.tensor(data).to(self.root_device, dtype=torch.float)
data = all_gather_ddp_if_available(data_tensor)
buffer = io.BytesIO(data.cpu().byte().numpy())
obj = torch.load(buffer)
return obj
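The `reduce_boolean_decision` added to `ParallelPlugin` converts each rank's boolean to an integer, sum-reduces it, and compares the total with `world_size`. Here is a minimal single-process sketch of that arithmetic. The list argument stands in for the values held by the individual ranks, and the built-in `sum` stands in for `self.reduce(..., reduce_op=ReduceOp.SUM)`.

```python
from typing import List


def reduce_boolean_decision(per_rank_decisions: List[bool]) -> bool:
    """Simulate the SUM all-reduce over one boolean per rank."""
    world_size = len(per_rank_decisions)
    # Each rank contributes int(decision); the reduce adds them up.
    total = sum(int(decision) for decision in per_rank_decisions)
    # The decision holds only if every rank contributed a 1.
    return total == world_size


unanimous = reduce_boolean_decision([True, True, True])
split_vote = reduce_boolean_decision([True, False, True])
single_process = reduce_boolean_decision([True])
```

Under this rule a decision passes only when every rank votes for it, so a single dissenting rank turns the result into `False` on all ranks.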
9 changes: 8 additions & 1 deletion pytorch_lightning/plugins/training_type/single_device.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Union
from typing import Any, Optional, Union

import torch

@@ -23,6 +23,9 @@ class SingleDevicePlugin(TrainingTypePlugin):
def __init__(self, device: torch.device):
super().__init__()
self.device: torch.device = device
self.global_rank = 0
self.local_rank = 0
self.world_size = 1

@property
def on_tpu(self) -> bool:
@@ -47,6 +50,10 @@ def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) ->
"""
return tensor

def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
"""Perform an all_gather on all processes"""
return tensor

@property
def root_device(self) -> torch.device:
return self.device
11 changes: 5 additions & 6 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -203,12 +203,11 @@ def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
model.trainer.save_checkpoint(path)
return path

def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
should_stop = torch.tensor(int(should_stop), device=self.lightning_module.device)
stop = xm.mesh_reduce('stop_signal', should_stop, sum)
rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
should_stop = int(stop.item()) == self.world_size
return should_stop
def reduce_decision(self, decision: bool) -> bool:
decision = torch.tensor(int(decision), device=self.device)
decision = self.reduce(decision, "sum")
decision = bool(decision == self.world_size)
return decision

def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
if not isinstance(output, torch.Tensor):