From 662919d1c904abc3a88b12ae3b013c9c5414c56d Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 02:43:40 +0100
Subject: [PATCH 01/21] rank zero first

---
 src/lightning/fabric/fabric.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index bc07e633a911e..e30cc609cd592 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -14,6 +14,7 @@
 import inspect
 import os
 from contextlib import contextmanager, nullcontext
+from datetime import timedelta
 from functools import partial
 from pathlib import Path
 from typing import (
@@ -631,13 +632,20 @@ def rank_zero_first(self, local: bool = False) -> Generator:
                 dataset = MNIST("datasets/", download=True)
 
         """
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            # Create a barrier with an 'infinite' timeout (only reliably possible over the GLOO backend)
+            group = torch.distributed.new_group(backend="gloo", timeout=timedelta(days=1000))
+            barrier = group.monitored_barrier
+        else:
+            barrier = self.barrier
+
         rank = self.local_rank if local else self.global_rank
         if rank > 0:
-            self.barrier()
+            barrier()
         yield
         if rank == 0:
-            self.barrier()
-        self.barrier()
+            barrier()
+        barrier()
 
     def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> ContextManager:
         r"""Skip gradient synchronization during backward to avoid redundant communication overhead.

From 6313e4f0bdba5ee75182adcef3d2e08fdcaa9086 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 03:15:10 +0100
Subject: [PATCH 02/21] context

---
 src/lightning/fabric/fabric.py                | 22 ++++------
 src/lightning/fabric/utilities/distributed.py | 23 +++++++++++
 .../trainer/connectors/data_connector.py      | 41 ++++++++++---------
 3 files changed, 51 insertions(+), 35 deletions(-)

diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index e30cc609cd592..9bbd7b47a144d 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -14,7 +14,6 @@
 import inspect
 import os
 from contextlib import contextmanager, nullcontext
-from datetime import timedelta
 from functools import partial
 from pathlib import Path
 from typing import (
@@ -66,7 +65,7 @@
     has_iterable_dataset,
 )
 from lightning.fabric.utilities.device_dtype_mixin import _update_properties
-from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
+from lightning.fabric.utilities.distributed import DistributedSamplerWrapper, _InfiniteBarrier
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
 from lightning.fabric.utilities.registry import _load_external_callbacks
@@ -632,20 +631,13 @@ def rank_zero_first(self, local: bool = False) -> Generator:
                 dataset = MNIST("datasets/", download=True)
 
         """
-        if torch.distributed.is_available() and torch.distributed.is_initialized():
-            # Create a barrier with an 'infinite' timeout (only reliably possible over the GLOO backend)
-            group = torch.distributed.new_group(backend="gloo", timeout=timedelta(days=1000))
-            barrier = group.monitored_barrier
-        else:
-            barrier = self.barrier
-
         rank = self.local_rank if local else self.global_rank
-        if rank > 0:
-            barrier()
-        yield
-        if rank == 0:
-            barrier()
-        barrier()
+        with _InfiniteBarrier() as barrier:
+            if rank > 0:
+                barrier()
+            yield
+            if rank == 0:
+                barrier()
 
     def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> ContextManager:
         r"""Skip gradient synchronization during backward to avoid redundant communication overhead.
diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index 16157c66274be..498ece91b4985 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -3,6 +3,7 @@
 import os
 import time
 from contextlib import nullcontext
+from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union
 
@@ -383,3 +384,25 @@ def _distributed_is_initialized() -> bool:
     # https://github.com/pytorch/pytorch/blob/v2.1.0/torch/distributed/__init__.py#L25
     # this might happen to MacOS builds from source (default) or any build from source that sets `USE_DISTRIBUTED=0`
     return torch.distributed.is_available() and torch.distributed.is_initialized()
+
+
+class _InfiniteBarrier:
+    """A barrier with an infinite timeout."""
+
+    def __init__(self) -> None:
+        # Create a barrier with an 'infinite' timeout (only reliably possible over the GLOO backend)
+        self.group = None
+        self.barrier = lambda: None
+
+    def __enter__(self) -> None:
+        if _distributed_is_initialized():
+            self.group = torch.distributed.new_group(backend="gloo", timeout=timedelta(days=10000))
+            self.barrier = self.group.monitored_barrier
+
+    def __call__(self) -> None:
+        self.barrier()
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        self.barrier()
+        if self.group is not None:
+            torch.distributed.destroy_process_group(self.group)
diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index 1bc63c62c561f..b5d949fbb79a2 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -27,7 +27,7 @@
     has_iterable_dataset,
     suggested_max_num_workers,
 )
-from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
+from lightning.fabric.utilities.distributed import DistributedSamplerWrapper, _InfiniteBarrier
 from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSamplerWrapper
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
@@ -79,25 +79,26 @@ def on_trainer_init(
 
     def prepare_data(self) -> None:
         trainer = self.trainer
-        # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
-        # or in the case where each node needs to do its own manipulation in which case just local_rank=0
-        local_rank_zero = trainer.local_rank == 0
-        global_rank_zero = trainer.local_rank == 0 and trainer.node_rank == 0
-
-        datamodule = trainer.datamodule
-        lightning_module = trainer.lightning_module
-        # handle datamodule prepare data:
-        # check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
-        if datamodule is not None:
-            dm_prepare_data_per_node = datamodule.prepare_data_per_node
-            if (dm_prepare_data_per_node and local_rank_zero) or (not dm_prepare_data_per_node and global_rank_zero):
-                call._call_lightning_datamodule_hook(trainer, "prepare_data")
-        # handle lightning module prepare data:
-        # check for prepare_data_per_node before calling lightning_module.prepare_data
-        if lightning_module is not None:
-            lm_prepare_data_per_node = lightning_module.prepare_data_per_node
-            if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
-                call._call_lightning_module_hook(trainer, "prepare_data")
+        with _InfiniteBarrier():
+            # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
+            # or in the case where each node needs to do its own manipulation in which case just local_rank=0
+            local_rank_zero = trainer.local_rank == 0
+            global_rank_zero = trainer.local_rank == 0 and trainer.node_rank == 0
+
+            datamodule = trainer.datamodule
+            lightning_module = trainer.lightning_module
+            # handle datamodule prepare data:
+            # check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
+            if datamodule is not None:
+                dm_prepare_data_per_node = datamodule.prepare_data_per_node
+                if (dm_prepare_data_per_node and local_rank_zero) or (not dm_prepare_data_per_node and global_rank_zero):
+                    call._call_lightning_datamodule_hook(trainer, "prepare_data")
+            # handle lightning module prepare data:
+            # check for prepare_data_per_node before calling lightning_module.prepare_data
+            if lightning_module is not None:
+                lm_prepare_data_per_node = lightning_module.prepare_data_per_node
+                if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
+                    call._call_lightning_module_hook(trainer, "prepare_data")
 
     def attach_data(
         self,

From c53f952a4deeba3340347ac6007d6eecad8466e4 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 03:21:35 +0100
Subject: [PATCH 03/21] docstring

---
 src/lightning/fabric/utilities/distributed.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index 498ece91b4985..01e30aa72bf34 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -387,7 +387,12 @@ def _distributed_is_initialized() -> bool:
 
 
 class _InfiniteBarrier:
-    """A barrier with an infinite timeout."""
+    """A barrier with an infinite timeout.
+
+    Creates a new process group with the GLOO backend with a very high timeout that makes the barrier effectively
+    wait forever. This is useful in cases where you want to execute a long-running operation on a subset of ranks
+    that should not be subject to the regular collective timeout.
+    """
 
     def __init__(self) -> None:
         # Create a barrier with an 'infinite' timeout (only reliably possible over the GLOO backend)

From 1df9c69e98129701c383de022b7aa76d892ab9bc Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 03:25:26 +0100
Subject: [PATCH 04/21] refactor

---
 .../trainer/connectors/data_connector.py      | 22 +++++++++++--------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index b5d949fbb79a2..5ac60b5cd6d39 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -78,26 +78,30 @@ def on_trainer_init(
 
     def prepare_data(self) -> None:
         trainer = self.trainer
+        # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
+        # or in the case where each node needs to do its own manipulation in which case just local_rank=0
+        local_rank_zero = trainer.local_rank == 0
+        global_rank_zero = trainer.local_rank == 0 and trainer.node_rank == 0
 
-        with _InfiniteBarrier():
-            # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
-            # or in the case where each node needs to do its own manipulation in which case just local_rank=0
-            local_rank_zero = trainer.local_rank == 0
-            global_rank_zero = trainer.local_rank == 0 and trainer.node_rank == 0
+        datamodule = trainer.datamodule
+        lightning_module = trainer.lightning_module
 
-            datamodule = trainer.datamodule
-            lightning_module = trainer.lightning_module
+        with _InfiniteBarrier():
             # handle datamodule prepare data:
             # check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
             if datamodule is not None:
                 dm_prepare_data_per_node = datamodule.prepare_data_per_node
-                if (dm_prepare_data_per_node and local_rank_zero) or (not dm_prepare_data_per_node and global_rank_zero):
+                if (dm_prepare_data_per_node and local_rank_zero) or (
+                    not dm_prepare_data_per_node and global_rank_zero
+                ):
                     call._call_lightning_datamodule_hook(trainer, "prepare_data")
             # handle lightning module prepare data:
             # check for prepare_data_per_node before calling lightning_module.prepare_data
             if lightning_module is not None:
                 lm_prepare_data_per_node = lightning_module.prepare_data_per_node
-                if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
+                if (lm_prepare_data_per_node and local_rank_zero) or (
+                    not lm_prepare_data_per_node and global_rank_zero
+                ):
                     call._call_lightning_module_hook(trainer, "prepare_data")
 
     def attach_data(

From 036b499e0e1fe5f477fb5f6bb20fff1a13803cb2 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 12 Feb 2024 02:28:20 +0000
Subject: [PATCH 05/21] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/fabric/utilities/distributed.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index 01e30aa72bf34..32384db993d56 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -389,9 +389,10 @@ def _distributed_is_initialized() -> bool:
 class _InfiniteBarrier:
     """A barrier with an infinite timeout.
 
-    Creates a new process group with the GLOO backend with a very high timeout that makes the barrier effectively
-    wait forever. This is useful in cases where you want to execute a long-running operation on a subset of ranks
-    that should not be subject to the regular collective timeout.
+    Creates a new process group with the GLOO backend with a very high timeout that makes the barrier effectively wait
+    forever. This is useful in cases where you want to execute a long-running operation on a subset of ranks that should
+    not be subject to the regular collective timeout.
+
     """
 
     def __init__(self) -> None:

From 60fc068364c73d69176c174608cb25ed637c5874 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 03:44:32 +0100
Subject: [PATCH 06/21] ctx manager

---
 src/lightning/fabric/utilities/distributed.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index 32384db993d56..e11af44ffcecc 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -12,7 +12,7 @@
 from lightning_utilities.core.imports import package_available
 from torch import Tensor
 from torch.utils.data import Dataset, DistributedSampler, Sampler
-from typing_extensions import override
+from typing_extensions import override, Self
 
 from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
 from lightning.fabric.utilities.data import _num_cpus_available
@@ -392,7 +392,6 @@ class _InfiniteBarrier:
     Creates a new process group with the GLOO backend with a very high timeout that makes the barrier effectively wait
     forever. This is useful in cases where you want to execute a long-running operation on a subset of ranks that should
     not be subject to the regular collective timeout.
-
     """
 
     def __init__(self) -> None:
@@ -400,13 +399,14 @@ def __init__(self) -> None:
         self.group = None
         self.barrier = lambda: None
 
-    def __enter__(self) -> None:
+    def __call__(self) -> None:
+        self.barrier()
+
+    def __enter__(self) -> Self:
         if _distributed_is_initialized():
             self.group = torch.distributed.new_group(backend="gloo", timeout=timedelta(days=10000))
             self.barrier = self.group.monitored_barrier
-
-    def __call__(self) -> None:
-        self.barrier()
+        return self
 
     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         self.barrier()

From 9115289553226a3ed8d86412acede1c1641e6ee2 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 12 Feb 2024 02:45:29 +0000
Subject: [PATCH 07/21] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/fabric/utilities/distributed.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index e11af44ffcecc..77645f0a2ada7 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -12,7 +12,7 @@
 from lightning_utilities.core.imports import package_available
 from torch import Tensor
 from torch.utils.data import Dataset, DistributedSampler, Sampler
-from typing_extensions import override, Self
+from typing_extensions import Self, override
 
 from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
 from lightning.fabric.utilities.data import _num_cpus_available
@@ -392,6 +392,7 @@ class _InfiniteBarrier:
     Creates a new process group with the GLOO backend with a very high timeout that makes the barrier effectively wait
     forever. This is useful in cases where you want to execute a long-running operation on a subset of ranks that should
     not be subject to the regular collective timeout.
+
     """
 
     def __init__(self) -> None:

From 3597de137f1251aac8e76fc76a043f6927f7900b Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 03:45:52 +0100
Subject: [PATCH 08/21] comment

---
 src/lightning/fabric/utilities/distributed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index e11af44ffcecc..14e9ae217aefd 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -395,7 +395,6 @@ class _InfiniteBarrier:
     """
 
     def __init__(self) -> None:
-        # Create a barrier with an 'infinite' timeout (only reliably possible over the GLOO backend)
        self.group = None
        self.barrier = lambda: None
 
@@ -404,6 +403,7 @@ def __call__(self) -> None:
 
     def __enter__(self) -> Self:
         if _distributed_is_initialized():
+            # Create a barrier with an 'infinite' timeout (only reliably possible over the GLOO backend)
             self.group = torch.distributed.new_group(backend="gloo", timeout=timedelta(days=10000))
             self.barrier = self.group.monitored_barrier
         return self

From 4a336a3870f2a3bfeccf7a7dda0dca896805cc7c Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 03:52:30 +0100
Subject: [PATCH 09/21] chlog

---
 src/lightning/fabric/CHANGELOG.md  | 2 +-
 src/lightning/pytorch/CHANGELOG.md | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 9e32953100bfa..40522739b710d 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -17,7 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- The `Fabric.rank_zero_first` context manager now uses a barrier without timeout to avoid long-running tasks to be interrupted ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
 
 -
 
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index f0de46f3fda16..4cad518eca057 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- The `prepare_data()` hook in `LightningModule` and `LightningDataModule` is now subject to a barrier without timeout to avoid long-running tasks to be interrupted ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
 
 -
 

From 9e26e99aceb335f3ba5082c1df176845ecfe9d64 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 04:09:22 +0100
Subject: [PATCH 10/21] test

---
 .../utilities/test_distributed.py            | 27 +++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py
index 4badff47403eb..5a31598580b67 100644
--- a/tests/tests_fabric/utilities/test_distributed.py
+++ b/tests/tests_fabric/utilities/test_distributed.py
@@ -17,6 +17,7 @@
     _suggested_max_num_threads,
     _sync_ddp,
     is_shared_filesystem,
+    _InfiniteBarrier,
 )
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 
@@ -196,3 +197,29 @@ def test_set_num_threads_if_needed(_, set_num_threads_mock, num_processes, expec
         _set_num_threads_if_needed(1)
         set_num_threads_mock.assert_not_called()
         assert os.environ["OMP_NUM_THREADS"] == str(expected)
+
+
+def test_infinite_barrier():
+    # distributed not available
+    barrier = _InfiniteBarrier()
+    assert barrier.group is None
+    with mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=False):
+        barrier.__enter__()
+        assert barrier.group is None
+        barrier()
+        barrier.__exit__(None, None, None)
+        assert barrier.group is None
+
+    # distributed available
+    barrier = _InfiniteBarrier()
+    with mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True):
+        with mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock:
+            barrier.__enter__()
+            dist_mock.new_group.assert_called_once()
+            assert barrier.barrier == barrier.group.monitored_barrier
+            assert barrier.barrier.call_count == 0
+            barrier()
+            assert barrier.barrier.call_count == 1
+            barrier.__exit__(None, None, None)
+            assert barrier.barrier.call_count == 2
+            dist_mock.destroy_process_group.assert_called_once()

From 4f73ad209ef36c2f58ada030c206a898214c3b6f Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 12 Feb 2024 03:10:28 +0000
Subject: [PATCH 11/21] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/tests_fabric/utilities/test_distributed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py
index 5a31598580b67..78ec3e548763f 100644
--- a/tests/tests_fabric/utilities/test_distributed.py
+++ b/tests/tests_fabric/utilities/test_distributed.py
@@ -13,11 +13,11 @@
 from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
 from lightning.fabric.utilities.distributed import (
     _gather_all_tensors,
+    _InfiniteBarrier,
     _set_num_threads_if_needed,
     _suggested_max_num_threads,
     _sync_ddp,
     is_shared_filesystem,
-    _InfiniteBarrier,
 )
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 

From 9e13327f6c81a792a97b830cb2c2195028030750 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 04:14:25 +0100
Subject: [PATCH 12/21] optimize

---
 .../trainer/connectors/data_connector.py     | 30 +++++++++----------
 1 file changed, 14 insertions(+), 16 deletions(-)

diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index 5ac60b5cd6d39..adf976eaa630e 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -78,6 +78,7 @@ def on_trainer_init(
 
     def prepare_data(self) -> None:
         trainer = self.trainer
+
         # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
         # or in the case where each node needs to do its own manipulation in which case just local_rank=0
         local_rank_zero = trainer.local_rank == 0
@@ -85,23 +86,20 @@ def prepare_data(self) -> None:
 
         datamodule = trainer.datamodule
         lightning_module = trainer.lightning_module
-
-        with _InfiniteBarrier():
-            # handle datamodule prepare data:
-            # check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
-            if datamodule is not None:
-                dm_prepare_data_per_node = datamodule.prepare_data_per_node
-                if (dm_prepare_data_per_node and local_rank_zero) or (
-                    not dm_prepare_data_per_node and global_rank_zero
-                ):
+        # handle datamodule prepare data:
+        # check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
+        if datamodule is not None and is_overridden("prepare_data", datamodule):
+            dm_prepare_data_per_node = datamodule.prepare_data_per_node
+            with _InfiniteBarrier():
+                if (dm_prepare_data_per_node and local_rank_zero) or (not dm_prepare_data_per_node and global_rank_zero):
                     call._call_lightning_datamodule_hook(trainer, "prepare_data")
-            # handle lightning module prepare data:
-            # check for prepare_data_per_node before calling lightning_module.prepare_data
-            if lightning_module is not None:
-                lm_prepare_data_per_node = lightning_module.prepare_data_per_node
-                if (lm_prepare_data_per_node and local_rank_zero) or (
-                    not lm_prepare_data_per_node and global_rank_zero
-                ):
+
+        # handle lightning module prepare data:
+        # check for prepare_data_per_node before calling lightning_module.prepare_data
+        if lightning_module is not None and is_overridden("prepare_data", lightning_module):
+            lm_prepare_data_per_node = lightning_module.prepare_data_per_node
+            with _InfiniteBarrier():
+                if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
                     call._call_lightning_module_hook(trainer, "prepare_data")
 
     def attach_data(

From c628d7cf07ae05292bc86647ca598cec572b8286 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 04:20:09 +0100
Subject: [PATCH 13/21] line too long

---
 .../pytorch/trainer/connectors/data_connector.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index adf976eaa630e..40887810f0065 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -89,17 +89,17 @@ def prepare_data(self) -> None:
         # handle datamodule prepare data:
         # check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
         if datamodule is not None and is_overridden("prepare_data", datamodule):
-            dm_prepare_data_per_node = datamodule.prepare_data_per_node
+            prepare_data_per_node = datamodule.prepare_data_per_node
             with _InfiniteBarrier():
-                if (dm_prepare_data_per_node and local_rank_zero) or (not dm_prepare_data_per_node and global_rank_zero):
+                if (prepare_data_per_node and local_rank_zero) or (not prepare_data_per_node and global_rank_zero):
                     call._call_lightning_datamodule_hook(trainer, "prepare_data")
 
         # handle lightning module prepare data:
         # check for prepare_data_per_node before calling lightning_module.prepare_data
         if lightning_module is not None and is_overridden("prepare_data", lightning_module):
-            lm_prepare_data_per_node = lightning_module.prepare_data_per_node
+            prepare_data_per_node = lightning_module.prepare_data_per_node
             with _InfiniteBarrier():
-                if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
+                if (prepare_data_per_node and local_rank_zero) or (not prepare_data_per_node and global_rank_zero):
                     call._call_lightning_module_hook(trainer, "prepare_data")
 
     def attach_data(

From 869a03d5b20f007558646e72283a57cb473099f7 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 04:22:15 +0100
Subject: [PATCH 14/21] remove redundant comments

---
 src/lightning/pytorch/trainer/connectors/data_connector.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index 40887810f0065..8d9b5e82cbc94 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -87,7 +87,6 @@ def prepare_data(self) -> None:
         datamodule = trainer.datamodule
         lightning_module = trainer.lightning_module
         # handle datamodule prepare data:
-        # check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
         if datamodule is not None and is_overridden("prepare_data", datamodule):
             prepare_data_per_node = datamodule.prepare_data_per_node
             with _InfiniteBarrier():
@@ -95,7 +94,6 @@ def prepare_data(self) -> None:
                     call._call_lightning_datamodule_hook(trainer, "prepare_data")
 
         # handle lightning module prepare data:
-        # check for prepare_data_per_node before calling lightning_module.prepare_data
        if lightning_module is not None and is_overridden("prepare_data", lightning_module):
            prepare_data_per_node = lightning_module.prepare_data_per_node
            with _InfiniteBarrier():

From f6c69c038eafdf7d86c1ffd8254d65e4fa64ed75 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 04:27:59 +0100
Subject: [PATCH 15/21] update test

---
 tests/tests_pytorch/models/test_hooks.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py
index 9d02d37368c46..c8824a1820e63 100644
--- a/tests/tests_pytorch/models/test_hooks.py
+++ b/tests/tests_pytorch/models/test_hooks.py
@@ -48,6 +48,10 @@ def call(hook, fn, *args, **kwargs):
                 update_wrapper(partial_h, attr)
                 setattr(self, h, partial_h)
 
+    # override so that it gets called
+    def prepare_data(self):
+        ...
+
 
 @pytest.mark.parametrize("max_steps", [1, 2, 3])
 def test_on_before_zero_grad_called(tmpdir, max_steps):
@@ -407,6 +411,10 @@ def on_test_model_train(self):
     def on_predict_model_train(self):
         ...
 
+    # override so that it gets called
+    def prepare_data(self):
+        ...
+
 
 @pytest.mark.parametrize(
     "kwargs",

From c7ef0beb5d365958a1fc8cc0b9c9e627fdc95de1 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 04:31:48 +0100
Subject: [PATCH 16/21] fix test

---
 tests/tests_pytorch/core/test_datamodules.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py
index 2e021537c93b9..0e837b834c5f2 100644
--- a/tests/tests_pytorch/core/test_datamodules.py
+++ b/tests/tests_pytorch/core/test_datamodules.py
@@ -465,6 +465,10 @@ class CustomBoringDataModule(BoringDataModule):
         def state_dict(self):
             return {"temp": 1}
 
+        # override so that it gets called
+        def prepare_data(self):
+            pass
+
     model = BoringModel()
     dm = CustomBoringDataModule()
     trainer = get_trainer()

From a14a5913c89c8446d7ac6817cac8b06921820c0b Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 04:35:59 +0100
Subject: [PATCH 17/21] precommit

---
 .../utilities/test_distributed.py            | 24 ++++++++++---------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py
index 78ec3e548763f..523f7d3fc27bb 100644
--- a/tests/tests_fabric/utilities/test_distributed.py
+++ b/tests/tests_fabric/utilities/test_distributed.py
@@ -212,14 +212,16 @@ def test_infinite_barrier():
 
     # distributed available
     barrier = _InfiniteBarrier()
-    with mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True):
-        with mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock:
-            barrier.__enter__()
-            dist_mock.new_group.assert_called_once()
-            assert barrier.barrier == barrier.group.monitored_barrier
-            assert barrier.barrier.call_count == 0
-            barrier()
-            assert barrier.barrier.call_count == 1
-            barrier.__exit__(None, None, None)
-            assert barrier.barrier.call_count == 2
-            dist_mock.destroy_process_group.assert_called_once()
+    with (
+        mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True),
+        mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock
+    ):
+        barrier.__enter__()
+        dist_mock.new_group.assert_called_once()
+        assert barrier.barrier == barrier.group.monitored_barrier
+        assert barrier.barrier.call_count == 0
+        barrier()
+        assert barrier.barrier.call_count == 1
+        barrier.__exit__(None, None, None)
+        assert barrier.barrier.call_count == 2
+        dist_mock.destroy_process_group.assert_called_once()

From 94292788cc2d1b2c5c74e7554d32091d0e91f382 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 14:09:06 +0100
Subject: [PATCH 18/21] fix test

---
 tests/tests_pytorch/core/test_datamodules.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py
index 0e837b834c5f2..583271e81968e 100644
--- a/tests/tests_pytorch/core/test_datamodules.py
+++ b/tests/tests_pytorch/core/test_datamodules.py
@@ -40,7 +40,13 @@
 @mock.patch("lightning.pytorch.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
 @mock.patch("lightning.pytorch.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
 def test_can_prepare_data(local_rank, node_rank):
-    dm = Mock(spec=LightningDataModule)
+    class MyDataModule(LightningDataModule):
+        def prepare_data(self):
+            pass
+
+    dm = MyDataModule()
+    dm.prepare_data = Mock(wraps=dm.prepare_data)
+
     dm.prepare_data_per_node = True
     trainer = Trainer()
     trainer.datamodule = dm
@@ -56,7 +62,7 @@ def test_can_prepare_data(local_rank, node_rank):
         dm.prepare_data.assert_called_once()
 
     # local rank = 1   (False)
-    dm.reset_mock()
+    dm.prepare_data.reset_mock()
     local_rank.return_value = 1
     assert trainer.local_rank == 1
 
@@ -65,7 +71,7 @@ def test_can_prepare_data(local_rank, node_rank):
 
     # prepare_data_per_node = False (prepare across all nodes)
     # global rank = 0   (True)
-    dm.reset_mock()
+    dm.prepare_data.reset_mock()
     dm.prepare_data_per_node = False
     node_rank.return_value = 0
     local_rank.return_value = 0
@@ -74,7 +80,7 @@ def test_can_prepare_data(local_rank, node_rank):
         dm.prepare_data.assert_called_once()
 
     # global rank = 1   (False)
-    dm.reset_mock()
+    dm.prepare_data.reset_mock()
     node_rank.return_value = 1
     local_rank.return_value = 0

From ba99259bf79609749286379a6c11e58f3214a128 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 18:25:21 +0100
Subject: [PATCH 19/21] update test

---
 tests/tests_fabric/test_fabric.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index 2c860eb45d78d..3c5346a082950 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -20,6 +20,7 @@
 import torch
 import torch.distributed
 import torch.nn.functional
+import lightning.fabric
 from lightning.fabric.fabric import Fabric
 from lightning.fabric.plugins import Precision
 from lightning.fabric.strategies import (
@@ -1129,7 +1130,7 @@ def test_all_reduce():
     fabric._strategy.all_reduce.assert_has_calls([call(torch.tensor(4), **defaults), call(torch.tensor(5), **defaults)])
 
 
-def test_rank_zero_first():
+def test_rank_zero_first(monkeypatch):
     """Test that rank 0 completes first before all other processes can execute under `.rank_zero_first()`."""
 
     def record_calls_for_rank(rank):
@@ -1137,7 +1138,8 @@ def record_calls_for_rank(rank):
 
         fabric = Fabric()
         fabric._strategy = Mock(global_rank=rank)
-        fabric.barrier = Mock(side_effect=lambda *_: call_order.append("barrier"))
+        barrier_mock = MagicMock(side_effect=lambda *_: call_order.append("barrier"))
+        monkeypatch.setattr(lightning.fabric.utilities.distributed._InfiniteBarrier, "__call__", barrier_mock)
         target = Mock(run=Mock(side_effect=lambda *_: call_order.append("run")))
 
         with fabric.rank_zero_first():
@@ -1145,8 +1147,8 @@ def record_calls_for_rank(rank):
 
         return call_order
 
-    assert record_calls_for_rank(0) == ["run", "barrier", "barrier"]
-    assert record_calls_for_rank(1) == ["barrier", "run", "barrier"]
+    assert record_calls_for_rank(0) == ["run", "barrier"]
+    assert record_calls_for_rank(1) == ["barrier", "run"]
 
 
 @pytest.mark.parametrize(("clip_val", "max_norm"), [(1e-3, None), (None, 1)])

From f88ae48620c34df5b9ceda821102f15eb5ffc7e7 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 12 Feb 2024 17:27:00 +0000
Subject: [PATCH 20/21] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/tests_fabric/test_fabric.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index 3c5346a082950..5eb39a0ad4b91 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -16,11 +16,11 @@
 from unittest import mock
 from unittest.mock import ANY, MagicMock, Mock, PropertyMock, call
 
+import lightning.fabric
 import pytest
 import torch
 import torch.distributed
 import torch.nn.functional
-import lightning.fabric
 from lightning.fabric.fabric import Fabric
 from lightning.fabric.plugins import Precision
 from lightning.fabric.strategies import (

From e27d7f1e6dfcf69136b25ce1ad9a46d0ff935a00 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 12 Feb 2024 22:18:37 +0100
Subject: [PATCH 21/21] formatting

---
 tests/tests_fabric/utilities/test_distributed.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py
index 523f7d3fc27bb..130286bb6c17d 100644
--- a/tests/tests_fabric/utilities/test_distributed.py
+++ b/tests/tests_fabric/utilities/test_distributed.py
@@ -212,10 +212,9 @@ def test_infinite_barrier():
 
     # distributed available
     barrier = _InfiniteBarrier()
-    with (
-        mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True),
-        mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock
-    ):
+    with mock.patch(
+        "lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True
+    ), mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock:
         barrier.__enter__()
         dist_mock.new_group.assert_called_once()
         assert barrier.barrier == barrier.group.monitored_barrier