Skip to content

Commit

Permalink
[bug] Update broadcast + reduce decision ModelCheckpoint] (#6410)
Browse files Browse the repository at this point in the history
* resolve bug

* update

* update changelog

* update PR

* Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* add todo

* resolve issues

* resolve flake8

* update

* add coverage for reduce

* wip

* restore back to brodbact

* remove test.py

* resolve flake8

* update

* check world size

* resolve test

* update

* use pytorch version when defined

* update on comments

* update on comments

* flake8

* resolve bugs

* Update CHANGELOG.md

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* update

* update

* update

* update

* remove test

* update

* resolve flake8

* update

* update

* update

* proxy

* update

* update

* resolve typo

* prune

* update parallel

* update

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

(cherry picked from commit 0544efd)
  • Loading branch information
tchaton authored and lexierule committed Mar 16, 2021
1 parent 4b762a9 commit 6bb24c2
Show file tree
Hide file tree
Showing 21 changed files with 345 additions and 153 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688))


- Fixed broacast to use PyTorch `broadcast_object_list` and add `reduce_decision` ([#6410](https://github.com/PyTorchLightning/pytorch-lightning/pull/6410))


- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380))


Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.enums import AMPType, LightningEnum


Expand Down Expand Up @@ -396,7 +395,7 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
Return:
A tensor of shape (world_size, batch, ...)
"""
return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads)

def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Wraps the dataloader if necessary
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,4 @@ def _run_early_stopping_check(self, trainer, pl_module):
trainer.should_stop = True

# stop every ddp process if any world process decides to stop
trainer.should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(trainer.should_stop)
trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop)
22 changes: 9 additions & 13 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def _save_model(self, filepath: str, trainer, pl_module):
else:
raise ValueError(".save_function() not set")

def check_monitor_top_k(self, current) -> bool:
def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) -> bool:
if current is None:
return False

Expand All @@ -356,7 +356,12 @@ def check_monitor_top_k(self, current) -> bool:
current = torch.tensor(current)

monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
return monitor_op(current, self.best_k_models[self.kth_best_model_path]).item()
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])

# If using multiple devices, make sure all processes are unanimous on the decision.
should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save)

return should_update_best_and_save

@classmethod
def _format_checkpoint_name(
Expand Down Expand Up @@ -554,15 +559,7 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
epoch = metrics.get("epoch")
step = metrics.get("step")

# when `val_loss` is being logged and no ModelCheckpoint is being provided
# `val_loss` will be selected for monitor and need to be reduced to
# prevent processes divergence
# TODO: Move this logic to logger_connector. This also needs to be fixed for any
# other monitor logged value which aren't produced from a Metric.
if self.monitor == "val_loss":
current = trainer.training_type_plugin.reduce(current, reduce_op="mean")

if self.check_monitor_top_k(current):
if self.check_monitor_top_k(trainer, current):
self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
elif self.verbose:
rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")
Expand Down Expand Up @@ -627,5 +624,4 @@ def file_exists(self, filepath: Union[str, Path], trainer) -> bool:
the internal state to diverge between ranks.
"""
exists = self._fs.exists(filepath)
exists = trainer.training_type_plugin.broadcast(exists)
return exists
return trainer.training_type_plugin.broadcast(exists)
51 changes: 12 additions & 39 deletions pytorch_lightning/distributed/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from typing import Any

import torch
from torch import distributed as torch_distrib

from pytorch_lightning.utilities import _GROUP_AVAILABLE

WORLD = None
if _GROUP_AVAILABLE:
from torch.distributed import group
WORLD = group.WORLD
from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
from pytorch_lightning.utilities.distributed import group as _group


class LightningDistributed:
Expand All @@ -31,32 +23,13 @@ def __init__(self, rank=None, device=None):
self.rank = rank
self.device = device

def broadcast(self, obj: Any, group=WORLD):
if self.rank == 0:
self._emit(obj, group)
else:
obj = self._receive(group)
return obj

def _broadcast(self, tensor, src=0, group=WORLD):
if group is None:
return torch_distrib.broadcast(tensor, src=src)
return torch_distrib.broadcast(tensor, src=0, group=group)

def _emit(self, obj: Any, group=WORLD):
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
length_tensor = torch.tensor([len(data)]).long().to(self.device)
self._broadcast(length_tensor, src=0, group=group)
data_tensor = torch.ByteTensor(data).to(self.device)
self._broadcast(data_tensor, src=0, group=group)

def _receive(self, group=WORLD):
length_tensor = torch.tensor([0]).long().to(self.device)
self._broadcast(length_tensor, src=0, group=group)
data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device)
self._broadcast(data_tensor, src=0, group=group)
buffer = io.BytesIO(data_tensor.cpu().numpy())
obj = torch.load(buffer)
return obj
def broadcast(self, obj: Any, group=_group.WORLD):
# always wrap into a list so list can be brodcasted.
obj = [obj]

if self.rank != 0:
obj = [None] * len(obj)

broadcast_object_list(obj, 0, group=group or _group.WORLD)

return obj[0]
94 changes: 94 additions & 0 deletions pytorch_lightning/overrides/torch_distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import logging
import pickle

import torch

from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7

log = logging.getLogger(__name__)

if torch.distributed.is_available():
from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember

# The code underneath is taken from PyTorch ``torch/distributed/distributed_c10d.py``
# and enable broadcasting for PyTorch 1.6 and lower.


# https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160
def _rank_not_in_group(group):
"""
Helper that checks if the current process's rank is not in a given group.
"""
if group is None:
return False
return group == GroupMember.NON_GROUP_MEMBER


# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1164
def _object_to_tensor(obj):
buffer = pickle.dumps(obj)
byte_storage = torch.ByteStorage.from_buffer(buffer) # type: ignore[attr-defined]
byte_tensor = torch.ByteTensor(byte_storage)
local_size = torch.LongTensor([byte_tensor.numel()])
return byte_tensor, local_size


# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py
def _tensor_to_object(tensor, tensor_size):
buf = tensor.numpy().tobytes()[:tensor_size]
out = pickle.loads(buf)
return out


# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327
def _broadcast_object_list(object_list, src=0, group=None):
if _rank_not_in_group(group):
return

my_rank = get_rank()
# Serialize object_list elements to tensors on src rank.
if my_rank == src:
tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list])
object_sizes_tensor = torch.cat(size_list)
else:
object_sizes_tensor = torch.LongTensor(len(object_list))

group_backend = get_backend(group)
is_nccl_backend = group_backend == Backend.NCCL
current_device = torch.device("cpu")
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device('cuda', torch.cuda.current_device())
object_sizes_tensor = object_sizes_tensor.to(current_device)
object_sizes_tensor = object_sizes_tensor.to(current_device)

# Broadcast object sizes
broadcast(object_sizes_tensor, src=src, group=group)

# Concatenate and broadcast serialized object tensors
if my_rank == src:
object_tensor = torch.cat(tensor_list)
else:
object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item())

if is_nccl_backend:
object_tensor = object_tensor.to(current_device)

broadcast(object_tensor, src=src, group=group)

# Deserialize objects using their stored sizes.
offset = 0
if my_rank != src:
for i, obj_size in enumerate(object_sizes_tensor):
obj_view = object_tensor[offset:offset + obj_size]
obj_view = obj_view.type(torch.ByteTensor) # type: ignore[call-overload]
offset += obj_size
object_list[i] = _tensor_to_object(obj_view, obj_size)


if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available():
from torch.distributed.distributed_c10d import broadcast_object_list
else:
broadcast_object_list = _broadcast_object_list
41 changes: 24 additions & 17 deletions pytorch_lightning/plugins/precision/apex_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Tuple
from typing import Any, Callable, Generator, List, Sequence, Tuple, Type, TYPE_CHECKING

import torch
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
Expand All @@ -23,37 +22,41 @@
if _APEX_AVAILABLE:
from apex import amp

if TYPE_CHECKING:
from torch.optim import Optimizer


class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
"""Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""

def __init__(self, amp_level: str):
def __init__(self, amp_level: str = "O2") -> None:
self.backend = AMPType.APEX
self.amp_level = amp_level

def master_params(self, optimizer: torch.optim.Optimizer):
def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]:
return amp.master_params(optimizer)

def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
def connect(self, model: torch.nn.Module, optimizers: Sequence['Optimizer'],
lr_schedulers: Sequence[Any]) -> Tuple[torch.nn.Module, Sequence['Optimizer'], Sequence[Any]]:
"""Connects the precision plugin to the training process,
configures apex and reinits the schedulers
"""
if model.device.type != "cuda":
return model, optimizers, lr_schedulers
model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level)
model, optimizers = self.configure_apex(amp, model, list(optimizers), self.amp_level)
self.reinit_scheduler_properties(optimizers, lr_schedulers)
return model, optimizers, lr_schedulers

def backward(
self,
model: LightningModule,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
optimizer: 'Optimizer',
opt_idx: int,
should_accumulate: bool,
*args,
**kwargs,
):
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
"""performs the actual backpropagation
Args:
Expand Down Expand Up @@ -94,11 +97,11 @@ def backward(

def configure_apex(
self,
amp: object,
amp: Type,
model: LightningModule,
optimizers: List[Optimizer],
optimizers: List['Optimizer'],
amp_level: str,
) -> Tuple[LightningModule, List[Optimizer]]:
) -> Tuple[LightningModule, List['Optimizer']]:
r"""
Override to init AMP your own way.
Must return a model and list of optimizers.
Expand Down Expand Up @@ -127,7 +130,7 @@ def configure_apex(self, amp, model, optimizers, amp_level):
return model, optimizers

@staticmethod
def reinit_scheduler_properties(optimizers: list, schedulers: list):
def reinit_scheduler_properties(optimizers: Sequence['Optimizer'], schedulers: Sequence[Any]) -> None:
"""Reinitializes schedulers with correct properties"""
# Reinitialize optimizer.step properties added by schedulers
for scheduler in schedulers:
Expand All @@ -149,7 +152,12 @@ def reinit_scheduler_properties(optimizers: list, schedulers: list):
break

def pre_optimizer_step(
self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
self,
pl_module: LightningModule,
optimizer: 'Optimizer',
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
"""
always called before the optimizer step.
Expand All @@ -160,6 +168,5 @@ def pre_optimizer_step(
if not pl_module.automatic_optimization:
pl_module.trainer.call_hook("on_after_backward")

optimizer.step()

optimizer.step(**kwargs)
return False
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def barrier(self, *args, **kwargs):
def broadcast(self, obj: object, src: int = 0) -> object:
return obj

def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
return should_stop
def reduce_boolean_decision(self, decision: bool) -> bool:
return decision

def training_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
Expand Down
11 changes: 8 additions & 3 deletions pytorch_lightning/plugins/training_type/horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.distributed import group, rank_zero_only, ReduceOp

if _HOROVOD_AVAILABLE:
import horovod.torch as hvd
Expand Down Expand Up @@ -147,8 +147,13 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
hvd.join()
return hvd.allreduce(output, op=reduce_op)

def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None):
if group is not None:
def all_gather(
self,
result: Union[torch.Tensor],
group: Optional[Any] = group.WORLD,
sync_grads: bool = False
) -> torch.Tensor:
if group is not None and group != group.WORLD:
raise ValueError(
"Horovod does not support allgather using a subcommunicator at this time. "
"Unset `group`."
Expand Down
Loading

0 comments on commit 6bb24c2

Please sign in to comment.