-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[bug] Update broadcast + reduce decision ModelCheckpoint] (#6410)
* resolve bug * update * update changelog * update PR * Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * add todo * resolve issues * resolve flake8 * update * add coverage for reduce * wip * restore back to brodbact * remove test.py * resolve flake8 * update * check world size * resolve test * update * use pytorch version when defined * update on comments * update on comments * flake8 * resolve bugs * Update CHANGELOG.md Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * update * update * update * update * remove test * update * resolve flake8 * update * update * update * proxy * update * update * resolve typo * prune * update parallel * update Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> (cherry picked from commit 0544efd)
- Loading branch information
Showing
21 changed files
with
345 additions
and
153 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import logging | ||
import pickle | ||
|
||
import torch | ||
|
||
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
if torch.distributed.is_available(): | ||
from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember | ||
|
||
# The code underneath is taken from PyTorch ``torch/distributed/distributed_c10d.py`` | ||
# and enable broadcasting for PyTorch 1.6 and lower. | ||
|
||
|
||
# https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160 | ||
def _rank_not_in_group(group): | ||
""" | ||
Helper that checks if the current process's rank is not in a given group. | ||
""" | ||
if group is None: | ||
return False | ||
return group == GroupMember.NON_GROUP_MEMBER | ||
|
||
|
||
# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1164 | ||
def _object_to_tensor(obj): | ||
buffer = pickle.dumps(obj) | ||
byte_storage = torch.ByteStorage.from_buffer(buffer) # type: ignore[attr-defined] | ||
byte_tensor = torch.ByteTensor(byte_storage) | ||
local_size = torch.LongTensor([byte_tensor.numel()]) | ||
return byte_tensor, local_size | ||
|
||
|
||
# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py | ||
def _tensor_to_object(tensor, tensor_size): | ||
buf = tensor.numpy().tobytes()[:tensor_size] | ||
out = pickle.loads(buf) | ||
return out | ||
|
||
|
||
# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327 | ||
def _broadcast_object_list(object_list, src=0, group=None): | ||
if _rank_not_in_group(group): | ||
return | ||
|
||
my_rank = get_rank() | ||
# Serialize object_list elements to tensors on src rank. | ||
if my_rank == src: | ||
tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list]) | ||
object_sizes_tensor = torch.cat(size_list) | ||
else: | ||
object_sizes_tensor = torch.LongTensor(len(object_list)) | ||
|
||
group_backend = get_backend(group) | ||
is_nccl_backend = group_backend == Backend.NCCL | ||
current_device = torch.device("cpu") | ||
if is_nccl_backend: | ||
# See note about using torch.cuda.current_device() here in docstring. | ||
# We cannot simply use my_rank since rank == device is not necessarily | ||
# true. | ||
current_device = torch.device('cuda', torch.cuda.current_device()) | ||
object_sizes_tensor = object_sizes_tensor.to(current_device) | ||
object_sizes_tensor = object_sizes_tensor.to(current_device) | ||
|
||
# Broadcast object sizes | ||
broadcast(object_sizes_tensor, src=src, group=group) | ||
|
||
# Concatenate and broadcast serialized object tensors | ||
if my_rank == src: | ||
object_tensor = torch.cat(tensor_list) | ||
else: | ||
object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item()) | ||
|
||
if is_nccl_backend: | ||
object_tensor = object_tensor.to(current_device) | ||
|
||
broadcast(object_tensor, src=src, group=group) | ||
|
||
# Deserialize objects using their stored sizes. | ||
offset = 0 | ||
if my_rank != src: | ||
for i, obj_size in enumerate(object_sizes_tensor): | ||
obj_view = object_tensor[offset:offset + obj_size] | ||
obj_view = obj_view.type(torch.ByteTensor) # type: ignore[call-overload] | ||
offset += obj_size | ||
object_list[i] = _tensor_to_object(obj_view, obj_size) | ||
|
||
|
||
if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available(): | ||
from torch.distributed.distributed_c10d import broadcast_object_list | ||
else: | ||
broadcast_object_list = _broadcast_object_list |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.