Create barrier without timeout in prepare_data() #19448

Merged · 23 commits · Feb 13, 2024
5 changes: 3 additions & 2 deletions src/lightning/fabric/CHANGELOG.md
@@ -17,9 +17,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Rename `lightning run model` to `fabric run model` ([#19442](https://github.com/Lightning-AI/pytorch-lightning/pull/19442))
- Renamed `lightning run model` to `fabric run model` ([#19442](https://github.com/Lightning-AI/pytorch-lightning/pull/19442))

-

- The `Fabric.rank_zero_first` context manager now uses a barrier without timeout to avoid interrupting long-running tasks ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))

-

14 changes: 7 additions & 7 deletions src/lightning/fabric/fabric.py
@@ -65,7 +65,7 @@
has_iterable_dataset,
)
from lightning.fabric.utilities.device_dtype_mixin import _update_properties
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper, _InfiniteBarrier
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from lightning.fabric.utilities.registry import _load_external_callbacks
@@ -632,12 +632,12 @@ def rank_zero_first(self, local: bool = False) -> Generator:

"""
rank = self.local_rank if local else self.global_rank
if rank > 0:
self.barrier()
yield
if rank == 0:
self.barrier()
self.barrier()
with _InfiniteBarrier() as barrier:
if rank > 0:
barrier()
yield
if rank == 0:
barrier()

def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> ContextManager:
r"""Skip gradient synchronization during backward to avoid redundant communication overhead.
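For context, the usage pattern this barrier change targets looks roughly like the sketch below (a minimal example, assuming a two-process CPU Fabric setup; the MNIST download only stands in for any long-running, rank-0-only task):

```python
from lightning.fabric import Fabric
from torchvision.datasets import MNIST  # illustrative long-running download


def run(fabric: Fabric) -> None:
    # Rank 0 enters the block first and downloads the data; with this change the
    # other ranks wait on a barrier without timeout instead of the regular
    # collective timeout, then load the files rank 0 already prepared.
    with fabric.rank_zero_first(local=True):
        dataset = MNIST("datasets/", download=True)


if __name__ == "__main__":
    fabric = Fabric(accelerator="cpu", devices=2)
    fabric.launch(run)
```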
32 changes: 31 additions & 1 deletion src/lightning/fabric/utilities/distributed.py
@@ -3,6 +3,7 @@
import os
import time
from contextlib import nullcontext
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union

@@ -11,7 +12,7 @@
from lightning_utilities.core.imports import package_available
from torch import Tensor
from torch.utils.data import Dataset, DistributedSampler, Sampler
from typing_extensions import override
from typing_extensions import Self, override

from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
from lightning.fabric.utilities.data import _num_cpus_available
@@ -383,3 +384,32 @@ def _distributed_is_initialized() -> bool:
# https://github.com/pytorch/pytorch/blob/v2.1.0/torch/distributed/__init__.py#L25
# this might happen to MacOS builds from source (default) or any build from source that sets `USE_DISTRIBUTED=0`
return torch.distributed.is_available() and torch.distributed.is_initialized()


class _InfiniteBarrier:
"""A barrier with an infinite timeout.

Creates a new process group with the GLOO backend with a very high timeout that makes the barrier effectively wait
forever. This is useful in cases where you want to execute a long-running operation on a subset of ranks that should
not be subject to the regular collective timeout.

"""

def __init__(self) -> None:
self.group = None
self.barrier = lambda: None

def __call__(self) -> None:
self.barrier()

def __enter__(self) -> Self:
if _distributed_is_initialized():
# Create a barrier with an 'infinite' timeout (only reliably possible over the GLOO backend)
self.group = torch.distributed.new_group(backend="gloo", timeout=timedelta(days=10000))
self.barrier = self.group.monitored_barrier
return self

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.barrier()
if self.group is not None:
torch.distributed.destroy_process_group(self.group)
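A brief usage sketch of the new helper (hypothetical `rank` and `do_long_running_work` names; it mirrors how `rank_zero_first` above uses it): the ranks that skip the work park on `monitored_barrier` of the temporary GLOO group, and `__exit__` issues one final `barrier()` so all ranks leave the block together before that group is destroyed.

```python
from lightning.fabric.utilities.distributed import _InfiniteBarrier


def run_on_rank_zero_first(rank: int) -> None:
    # Assumes torch.distributed is already initialized on every rank.
    with _InfiniteBarrier() as barrier:
        if rank > 0:
            barrier()  # wait for rank 0 to finish, effectively without a timeout
        do_long_running_work()  # hypothetical placeholder for the slow operation
        if rank == 0:
            barrier()  # release the waiting ranks
    # __exit__ calls barrier() once more, then destroys the temporary GLOO group.
```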
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- The `prepare_data()` hook in `LightningModule` and `LightningDataModule` is now subject to a barrier without timeout to avoid interrupting long-running tasks ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))

-

23 changes: 12 additions & 11 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -27,7 +27,7 @@
has_iterable_dataset,
suggested_max_num_workers,
)
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper, _InfiniteBarrier
from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSamplerWrapper
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
@@ -87,17 +87,18 @@ def prepare_data(self) -> None:
datamodule = trainer.datamodule
lightning_module = trainer.lightning_module
# handle datamodule prepare data:
# check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
if datamodule is not None:
dm_prepare_data_per_node = datamodule.prepare_data_per_node
if (dm_prepare_data_per_node and local_rank_zero) or (not dm_prepare_data_per_node and global_rank_zero):
call._call_lightning_datamodule_hook(trainer, "prepare_data")
if datamodule is not None and is_overridden("prepare_data", datamodule):
prepare_data_per_node = datamodule.prepare_data_per_node
with _InfiniteBarrier():
if (prepare_data_per_node and local_rank_zero) or (not prepare_data_per_node and global_rank_zero):
call._call_lightning_datamodule_hook(trainer, "prepare_data")

# handle lightning module prepare data:
# check for prepare_data_per_node before calling lightning_module.prepare_data
if lightning_module is not None:
lm_prepare_data_per_node = lightning_module.prepare_data_per_node
if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
call._call_lightning_module_hook(trainer, "prepare_data")
if lightning_module is not None and is_overridden("prepare_data", lightning_module):
prepare_data_per_node = lightning_module.prepare_data_per_node
with _InfiniteBarrier():
if (prepare_data_per_node and local_rank_zero) or (not prepare_data_per_node and global_rank_zero):
call._call_lightning_module_hook(trainer, "prepare_data")

def attach_data(
self,
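From the user's perspective, the hook that benefits looks roughly like the sketch below (a hypothetical datamodule; `prepare_data_per_node` is the existing flag the connector checks). While local rank 0 of each node runs the slow `prepare_data()`, the remaining ranks now wait on the infinite barrier instead of failing once the regular collective timeout elapses.

```python
import time

from lightning.pytorch import LightningDataModule


class SlowPrepareDataModule(LightningDataModule):
    """Hypothetical datamodule whose prepare_data() runs for a long time."""

    def __init__(self) -> None:
        super().__init__()
        # Run prepare_data() once per node (on local rank 0) rather than once globally.
        self.prepare_data_per_node = True

    def prepare_data(self) -> None:
        # Stand-in for a long download/preprocessing step; the _InfiniteBarrier added
        # in the connector keeps the other ranks waiting here without a timeout.
        time.sleep(2 * 60 * 60)
```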
10 changes: 6 additions & 4 deletions tests/tests_fabric/test_fabric.py
@@ -16,6 +16,7 @@
from unittest import mock
from unittest.mock import ANY, MagicMock, Mock, PropertyMock, call

import lightning.fabric
import pytest
import torch
import torch.distributed
@@ -1129,24 +1130,25 @@ def test_all_reduce():
fabric._strategy.all_reduce.assert_has_calls([call(torch.tensor(4), **defaults), call(torch.tensor(5), **defaults)])


def test_rank_zero_first():
def test_rank_zero_first(monkeypatch):
"""Test that rank 0 completes first before all other processes can execute under `.rank_zero_first()`."""

def record_calls_for_rank(rank):
call_order = []

fabric = Fabric()
fabric._strategy = Mock(global_rank=rank)
fabric.barrier = Mock(side_effect=lambda *_: call_order.append("barrier"))
barrier_mock = MagicMock(side_effect=lambda *_: call_order.append("barrier"))
monkeypatch.setattr(lightning.fabric.utilities.distributed._InfiniteBarrier, "__call__", barrier_mock)
target = Mock(run=Mock(side_effect=lambda *_: call_order.append("run")))

with fabric.rank_zero_first():
target.run()

return call_order

assert record_calls_for_rank(0) == ["run", "barrier", "barrier"]
assert record_calls_for_rank(1) == ["barrier", "run", "barrier"]
assert record_calls_for_rank(0) == ["run", "barrier"]
assert record_calls_for_rank(1) == ["barrier", "run"]


@pytest.mark.parametrize(("clip_val", "max_norm"), [(1e-3, None), (None, 1)])
29 changes: 29 additions & 0 deletions tests/tests_fabric/utilities/test_distributed.py
@@ -13,6 +13,7 @@
from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning.fabric.utilities.distributed import (
_gather_all_tensors,
_InfiniteBarrier,
_set_num_threads_if_needed,
_suggested_max_num_threads,
_sync_ddp,
@@ -196,3 +197,31 @@ def test_set_num_threads_if_needed(_, set_num_threads_mock, num_processes, expec
_set_num_threads_if_needed(1)
set_num_threads_mock.assert_not_called()
assert os.environ["OMP_NUM_THREADS"] == str(expected)


def test_infinite_barrier():
# distributed not available
barrier = _InfiniteBarrier()
assert barrier.group is None
with mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=False):
barrier.__enter__()
assert barrier.group is None
barrier()
barrier.__exit__(None, None, None)
assert barrier.group is None

# distributed available
barrier = _InfiniteBarrier()
with (
mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True),
mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock
):
barrier.__enter__()
dist_mock.new_group.assert_called_once()
assert barrier.barrier == barrier.group.monitored_barrier
assert barrier.barrier.call_count == 0
barrier()
assert barrier.barrier.call_count == 1
barrier.__exit__(None, None, None)
assert barrier.barrier.call_count == 2
dist_mock.destroy_process_group.assert_called_once()
18 changes: 14 additions & 4 deletions tests/tests_pytorch/core/test_datamodules.py
@@ -40,7 +40,13 @@
@mock.patch("lightning.pytorch.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
@mock.patch("lightning.pytorch.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
def test_can_prepare_data(local_rank, node_rank):
dm = Mock(spec=LightningDataModule)
class MyDataModule(LightningDataModule):
def prepare_data(self):
pass

dm = MyDataModule()
dm.prepare_data = Mock(wraps=dm.prepare_data)

dm.prepare_data_per_node = True
trainer = Trainer()
trainer.datamodule = dm
@@ -56,7 +62,7 @@ def test_can_prepare_data(local_rank, node_rank):
dm.prepare_data.assert_called_once()

# local rank = 1 (False)
dm.reset_mock()
dm.prepare_data.reset_mock()
local_rank.return_value = 1
assert trainer.local_rank == 1

@@ -65,7 +71,7 @@ def test_can_prepare_data(local_rank, node_rank):

# prepare_data_per_node = False (prepare across all nodes)
# global rank = 0 (True)
dm.reset_mock()
dm.prepare_data.reset_mock()
dm.prepare_data_per_node = False
node_rank.return_value = 0
local_rank.return_value = 0
@@ -74,7 +80,7 @@ def test_can_prepare_data(local_rank, node_rank):
dm.prepare_data.assert_called_once()

# global rank = 1 (False)
dm.reset_mock()
dm.prepare_data.reset_mock()
node_rank.return_value = 1
local_rank.return_value = 0

@@ -465,6 +471,10 @@ class CustomBoringDataModule(BoringDataModule):
def state_dict(self):
return {"temp": 1}

# override so that it gets called
def prepare_data(self):
pass

model = BoringModel()
dm = CustomBoringDataModule()
trainer = get_trainer()
8 changes: 8 additions & 0 deletions tests/tests_pytorch/models/test_hooks.py
@@ -48,6 +48,10 @@ def call(hook, fn, *args, **kwargs):
update_wrapper(partial_h, attr)
setattr(self, h, partial_h)

# override so that it gets called
def prepare_data(self):
...


@pytest.mark.parametrize("max_steps", [1, 2, 3])
def test_on_before_zero_grad_called(tmpdir, max_steps):
@@ -407,6 +411,10 @@ def on_test_model_train(self):
def on_predict_model_train(self):
...

# override so that it gets called
def prepare_data(self):
...


@pytest.mark.parametrize(
"kwargs",