diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index ab6f0f4971f27..0024c5ae82785 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -17,9 +17,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
-- Rename `lightning run model` to `fabric run model` ([#19442](https://github.com/Lightning-AI/pytorch-lightning/pull/19442))
+- Renamed `lightning run model` to `fabric run model` ([#19442](https://github.com/Lightning-AI/pytorch-lightning/pull/19442))
 
--
+
+- The `Fabric.rank_zero_first` context manager now uses a barrier without a timeout so that long-running tasks are not interrupted ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
 
 -
 
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index bc07e633a911e..9bbd7b47a144d 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -65,7 +65,7 @@
     has_iterable_dataset,
 )
 from lightning.fabric.utilities.device_dtype_mixin import _update_properties
-from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
+from lightning.fabric.utilities.distributed import DistributedSamplerWrapper, _InfiniteBarrier
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
 from lightning.fabric.utilities.registry import _load_external_callbacks
@@ -632,12 +632,12 @@ def rank_zero_first(self, local: bool = False) -> Generator:
 
         """
         rank = self.local_rank if local else self.global_rank
-        if rank > 0:
-            self.barrier()
-        yield
-        if rank == 0:
-            self.barrier()
-        self.barrier()
+        with _InfiniteBarrier() as barrier:
+            if rank > 0:
+                barrier()
+            yield
+            if rank == 0:
+                barrier()
 
     def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> ContextManager:
         r"""Skip gradient synchronization during backward to avoid redundant communication overhead.
diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index 16157c66274be..30bfe4e254a07 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -3,6 +3,7 @@
 import os
 import time
 from contextlib import nullcontext
+from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union
 
@@ -11,7 +12,7 @@
 from lightning_utilities.core.imports import package_available
 from torch import Tensor
 from torch.utils.data import Dataset, DistributedSampler, Sampler
-from typing_extensions import override
+from typing_extensions import Self, override
 
 from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
 from lightning.fabric.utilities.data import _num_cpus_available
@@ -383,3 +384,32 @@ def _distributed_is_initialized() -> bool:
     # https://github.com/pytorch/pytorch/blob/v2.1.0/torch/distributed/__init__.py#L25
     # this might happen to MacOS builds from source (default) or any build from source that sets `USE_DISTRIBUTED=0`
     return torch.distributed.is_available() and torch.distributed.is_initialized()
+
+
+class _InfiniteBarrier:
+    """A barrier with an infinite timeout.
+
+    Creates a new process group using the GLOO backend with a very high timeout that makes the barrier effectively wait
+    forever. This is useful in cases where you want to execute a long-running operation on a subset of ranks that should
+    not be subject to the regular collective timeout.
+
+    """
+
+    def __init__(self) -> None:
+        self.group = None
+        self.barrier = lambda: None
+
+    def __call__(self) -> None:
+        self.barrier()
+
+    def __enter__(self) -> Self:
+        if _distributed_is_initialized():
+            # Create a barrier with an 'infinite' timeout (only reliably possible over the GLOO backend)
+            self.group = torch.distributed.new_group(backend="gloo", timeout=timedelta(days=10000))
+            self.barrier = self.group.monitored_barrier
+        return self
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        self.barrier()
+        if self.group is not None:
+            torch.distributed.destroy_process_group(self.group)
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index f0de46f3fda16..4cad518eca057 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- The `prepare_data()` hook in `LightningModule` and `LightningDataModule` is now subject to a barrier without a timeout so that long-running tasks are not interrupted ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
 
 -
 
diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index 1bc63c62c561f..8d9b5e82cbc94 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -27,7 +27,7 @@
     has_iterable_dataset,
     suggested_max_num_workers,
 )
-from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
+from lightning.fabric.utilities.distributed import DistributedSamplerWrapper, _InfiniteBarrier
 from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSamplerWrapper
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
@@ -87,17 +87,18 @@ def prepare_data(self) -> None:
         datamodule = trainer.datamodule
         lightning_module = trainer.lightning_module
         # handle datamodule prepare data:
-        # check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
-        if datamodule is not None:
-            dm_prepare_data_per_node = datamodule.prepare_data_per_node
-            if (dm_prepare_data_per_node and local_rank_zero) or (not dm_prepare_data_per_node and global_rank_zero):
-                call._call_lightning_datamodule_hook(trainer, "prepare_data")
+        if datamodule is not None and is_overridden("prepare_data", datamodule):
+            prepare_data_per_node = datamodule.prepare_data_per_node
+            with _InfiniteBarrier():
+                if (prepare_data_per_node and local_rank_zero) or (not prepare_data_per_node and global_rank_zero):
+                    call._call_lightning_datamodule_hook(trainer, "prepare_data")
+
         # handle lightning module prepare data:
-        # check for prepare_data_per_node before calling lightning_module.prepare_data
-        if lightning_module is not None:
-            lm_prepare_data_per_node = lightning_module.prepare_data_per_node
-            if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
-                call._call_lightning_module_hook(trainer, "prepare_data")
+        if lightning_module is not None and is_overridden("prepare_data", lightning_module):
+            prepare_data_per_node = lightning_module.prepare_data_per_node
+            with _InfiniteBarrier():
+                if (prepare_data_per_node and local_rank_zero) or (not prepare_data_per_node and global_rank_zero):
+                    call._call_lightning_module_hook(trainer, "prepare_data")
 
     def attach_data(
         self,
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index 2c860eb45d78d..5eb39a0ad4b91 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -16,6 +16,7 @@
 from unittest import mock
 from unittest.mock import ANY, MagicMock, Mock, PropertyMock, call
 
+import lightning.fabric
 import pytest
 import torch
 import torch.distributed
@@ -1129,7 +1130,7 @@ def test_all_reduce():
     fabric._strategy.all_reduce.assert_has_calls([call(torch.tensor(4), **defaults), call(torch.tensor(5), **defaults)])
 
 
-def test_rank_zero_first():
+def test_rank_zero_first(monkeypatch):
     """Test that rank 0 completes first before all other processes can execute under `.rank_zero_first()`."""
 
     def record_calls_for_rank(rank):
@@ -1137,7 +1138,8 @@ def record_calls_for_rank(rank):
 
         fabric = Fabric()
         fabric._strategy = Mock(global_rank=rank)
-        fabric.barrier = Mock(side_effect=lambda *_: call_order.append("barrier"))
+        barrier_mock = MagicMock(side_effect=lambda *_: call_order.append("barrier"))
+        monkeypatch.setattr(lightning.fabric.utilities.distributed._InfiniteBarrier, "__call__", barrier_mock)
         target = Mock(run=Mock(side_effect=lambda *_: call_order.append("run")))
 
         with fabric.rank_zero_first():
@@ -1145,8 +1147,8 @@ def record_calls_for_rank(rank):
 
         return call_order
 
-    assert record_calls_for_rank(0) == ["run", "barrier", "barrier"]
-    assert record_calls_for_rank(1) == ["barrier", "run", "barrier"]
+    assert record_calls_for_rank(0) == ["run", "barrier"]
+    assert record_calls_for_rank(1) == ["barrier", "run"]
 
 
 @pytest.mark.parametrize(("clip_val", "max_norm"), [(1e-3, None), (None, 1)])
diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py
index 4badff47403eb..130286bb6c17d 100644
--- a/tests/tests_fabric/utilities/test_distributed.py
+++ b/tests/tests_fabric/utilities/test_distributed.py
@@ -13,6 +13,7 @@
 from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
 from lightning.fabric.utilities.distributed import (
     _gather_all_tensors,
+    _InfiniteBarrier,
     _set_num_threads_if_needed,
     _suggested_max_num_threads,
     _sync_ddp,
@@ -196,3 +197,30 @@ def test_set_num_threads_if_needed(_, set_num_threads_mock, num_processes, expec
         _set_num_threads_if_needed(1)
         set_num_threads_mock.assert_not_called()
         assert os.environ["OMP_NUM_THREADS"] == str(expected)
+
+
+def test_infinite_barrier():
+    # distributed not available
+    barrier = _InfiniteBarrier()
+    assert barrier.group is None
+    with mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=False):
+        barrier.__enter__()
+    assert barrier.group is None
+    barrier()
+    barrier.__exit__(None, None, None)
+    assert barrier.group is None
+
+    # distributed available
+    barrier = _InfiniteBarrier()
+    with mock.patch(
+        "lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True
+    ), mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock:
+        barrier.__enter__()
+        dist_mock.new_group.assert_called_once()
+        assert barrier.barrier == barrier.group.monitored_barrier
+        assert barrier.barrier.call_count == 0
+        barrier()
+        assert barrier.barrier.call_count == 1
+        barrier.__exit__(None, None, None)
+        assert barrier.barrier.call_count == 2
+        dist_mock.destroy_process_group.assert_called_once()
diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py
index 2e021537c93b9..583271e81968e 100644
--- a/tests/tests_pytorch/core/test_datamodules.py
+++ b/tests/tests_pytorch/core/test_datamodules.py
@@ -40,7 +40,13 @@
 @mock.patch("lightning.pytorch.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
 @mock.patch("lightning.pytorch.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
 def test_can_prepare_data(local_rank, node_rank):
-    dm = Mock(spec=LightningDataModule)
+    class MyDataModule(LightningDataModule):
+        def prepare_data(self):
+            pass
+
+    dm = MyDataModule()
+    dm.prepare_data = Mock(wraps=dm.prepare_data)
+
     dm.prepare_data_per_node = True
     trainer = Trainer()
     trainer.datamodule = dm
@@ -56,7 +62,7 @@ def test_can_prepare_data(local_rank, node_rank):
     dm.prepare_data.assert_called_once()
 
     # local rank = 1 (False)
-    dm.reset_mock()
+    dm.prepare_data.reset_mock()
     local_rank.return_value = 1
     assert trainer.local_rank == 1
 
@@ -65,7 +71,7 @@ def test_can_prepare_data(local_rank, node_rank):
 
     # prepare_data_per_node = False (prepare across all nodes)
     # global rank = 0 (True)
-    dm.reset_mock()
+    dm.prepare_data.reset_mock()
     dm.prepare_data_per_node = False
     node_rank.return_value = 0
     local_rank.return_value = 0
@@ -74,7 +80,7 @@ def test_can_prepare_data(local_rank, node_rank):
     dm.prepare_data.assert_called_once()
 
     # global rank = 1 (False)
-    dm.reset_mock()
+    dm.prepare_data.reset_mock()
     node_rank.return_value = 1
     local_rank.return_value = 0
 
@@ -465,6 +471,10 @@ class CustomBoringDataModule(BoringDataModule):
         def state_dict(self):
             return {"temp": 1}
 
+        # override so that it gets called
+        def prepare_data(self):
+            pass
+
     model = BoringModel()
     dm = CustomBoringDataModule()
     trainer = get_trainer()
diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py
index 9d02d37368c46..c8824a1820e63 100644
--- a/tests/tests_pytorch/models/test_hooks.py
+++ b/tests/tests_pytorch/models/test_hooks.py
@@ -48,6 +48,10 @@ def call(hook, fn, *args, **kwargs):
             update_wrapper(partial_h, attr)
             setattr(self, h, partial_h)
 
+    # override so that it gets called
+    def prepare_data(self):
+        ...
+
 
 @pytest.mark.parametrize("max_steps", [1, 2, 3])
 def test_on_before_zero_grad_called(tmpdir, max_steps):
@@ -407,6 +411,10 @@ def on_test_model_train(self):
     def on_predict_model_train(self):
         ...
 
+    # override so that it gets called
+    def prepare_data(self):
+        ...
+
 
 @pytest.mark.parametrize(
     "kwargs",
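
A usage sketch (not part of the patch) of what the new behavior enables from user code. The `main` function, the CPU/2-device setup, and the torchvision `MNIST` download are placeholder choices; any long-running, rank-zero-only work follows the same pattern.

```python
from torchvision.datasets import MNIST  # placeholder for any slow, rank-zero-only task

from lightning.fabric import Fabric


def main(fabric: Fabric) -> None:
    # Rank 0 enters the block immediately and does the slow work; the other ranks
    # now wait on the infinite barrier instead of the default collective timeout
    # (30 minutes in torch.distributed), so the job is no longer aborted mid-download.
    with fabric.rank_zero_first(local=True):
        dataset = MNIST("datasets/", download=True)

    # ... training continues on all ranks with the prepared data ...


if __name__ == "__main__":
    fabric = Fabric(accelerator="cpu", devices=2)
    fabric.launch(main)
```

The `prepare_data()` guard in `data_connector.py` relies on the same mechanism: `_InfiniteBarrier.__enter__` lazily creates a GLOO process group with a ~10000-day timeout, uses `monitored_barrier()` because that is only reliably available on GLOO, and destroys the extra group again in `__exit__`.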