From 2ca329ed75cface0fae06ff228e90cbd1cda203f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 19 Sep 2023 05:22:04 +0200 Subject: [PATCH 01/28] add suggested num workers --- src/lightning/pytorch/utilities/data.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index b026badb65c8a..a360b4601e4a2 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +import os from dataclasses import fields from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union @@ -38,6 +39,19 @@ warning_cache = WarningCache() +def suggested_max_num_workers(local_world_size: int) -> int: + cpu_count = _num_cpus_available() + return max(1, cpu_count // local_world_size) + + +def _num_cpus_available(): + if hasattr(os, 'sched_getaffinity'): + return len(os.sched_getaffinity(0)) + + cpu_count = os.cpu_count() + return 1 if cpu_count is None else cpu_count + + def _extract_batch_size(batch: BType) -> Generator[Optional[int], None, None]: if isinstance(batch, Tensor): if batch.ndim == 0: From 2db36800ca66aca4863ae4d2ae3be76248654ff1 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 19 Sep 2023 12:56:09 +0200 Subject: [PATCH 02/28] wip --- .../trainer/connectors/data_connector.py | 53 +++++++++---------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index ed50f4a06b7ce..6a1046d11ef92 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -32,7 +32,7 @@ from lightning.pytorch.trainer import call from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.combined_loader import CombinedLoader -from lightning.pytorch.utilities.data import _is_dataloader_shuffled, _update_dataloader +from lightning.pytorch.utilities.data import _is_dataloader_shuffled, _update_dataloader, suggested_max_num_workers from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _lightning_graphcore_available from lightning.pytorch.utilities.model_helpers import is_overridden @@ -420,35 +420,36 @@ def _check_dataloader_iterable( ) -def _worker_check(dataloader: object, using_spawn: bool, name: str) -> None: +def _worker_check(trainer: "pl.Trainer", dataloader: object, name: str) -> None: if not isinstance(dataloader, DataLoader): return - num_cpus = multiprocessing.cpu_count() - - # ddp_spawn + num_workers > 0 don't mix! tell the user - if dataloader.num_workers > 0 and using_spawn: - if not dataloader.persistent_workers: - rank_zero_warn( - "num_workers>0, persistent_workers=False, and strategy=ddp_spawn" - " may result in data loading bottlenecks." - " Consider setting persistent_workers=True" - " (this is a limitation of Python .spawn() and PyTorch)" - ) - - elif dataloader.num_workers == 0 and using_spawn: - if not dataloader.persistent_workers: - rank_zero_warn( - "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks." 
- " Consider setting num_workers>0 and persistent_workers=True" - ) - - elif dataloader.num_workers <= 2 < num_cpus and not using_spawn: + upper_bound = suggested_max_num_workers(trainer.num_devices) + + # # ddp_spawn + num_workers > 0 don't mix! tell the user + # if dataloader.num_workers > 0 and using_spawn: + # if not dataloader.persistent_workers: + # rank_zero_warn( + # "num_workers>0, persistent_workers=False, and strategy=ddp_spawn" + # " may result in data loading bottlenecks." + # " Consider setting persistent_workers=True" + # " (this is a limitation of Python .spawn() and PyTorch)" + # ) + + # elif dataloader.num_workers == 0 and using_spawn: + # if not dataloader.persistent_workers: + # rank_zero_warn( + # "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks." + # " Consider setting num_workers>0 and persistent_workers=True" + # ) + + if dataloader.num_workers <= 2 < upper_bound: + # TODO: find filterwarnings # if changed, update the `filterwarnings` snippet in 'speed.html#num-workers' rank_zero_warn( f"The dataloader, {name}, does not have many workers which may be a bottleneck." " Consider increasing the value of the `num_workers` argument`" - f" (try {num_cpus} which is the number of cpus on this machine)" + f" (try {upper_bound} which is the number of cpus on this machine)" " in the `DataLoader` init to improve performance.", category=PossibleUserWarning, ) @@ -506,11 +507,7 @@ def _process_dataloader( dataloader = strategy.process_dataloader(dataloader) # check the workers - _worker_check( - dataloader, - isinstance(strategy, DDPStrategy) and strategy._start_method == "spawn", - f"{stage.dataloader_prefix}_dataloader", - ) + _worker_check(trainer, dataloader, f"{stage.dataloader_prefix}_dataloader") # add worker_init_fn for correct seeding in worker processes _auto_add_worker_init_fn(dataloader, trainer.global_rank) From c5c5de2410c60acd96c13030fdb5d2597fd627fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Sep 2023 15:46:49 +0200 Subject: [PATCH 03/28] update --- .../trainer/connectors/data_connector.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 6a1046d11ef92..e0c17c2c5d223 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -425,26 +425,7 @@ def _worker_check(trainer: "pl.Trainer", dataloader: object, name: str) -> None: return upper_bound = suggested_max_num_workers(trainer.num_devices) - - # # ddp_spawn + num_workers > 0 don't mix! tell the user - # if dataloader.num_workers > 0 and using_spawn: - # if not dataloader.persistent_workers: - # rank_zero_warn( - # "num_workers>0, persistent_workers=False, and strategy=ddp_spawn" - # " may result in data loading bottlenecks." - # " Consider setting persistent_workers=True" - # " (this is a limitation of Python .spawn() and PyTorch)" - # ) - - # elif dataloader.num_workers == 0 and using_spawn: - # if not dataloader.persistent_workers: - # rank_zero_warn( - # "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks." 
- # " Consider setting num_workers>0 and persistent_workers=True" - # ) - if dataloader.num_workers <= 2 < upper_bound: - # TODO: find filterwarnings # if changed, update the `filterwarnings` snippet in 'speed.html#num-workers' rank_zero_warn( f"The dataloader, {name}, does not have many workers which may be a bottleneck." From ada09edfa366a79a7ded7554eb7800e54f81d405 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Sep 2023 16:20:34 +0200 Subject: [PATCH 04/28] test --- .../trainer/connectors/data_connector.py | 8 +-- src/lightning/pytorch/utilities/data.py | 2 +- .../trainer/connectors/test_data_connector.py | 57 +++++++------------ 3 files changed, 25 insertions(+), 42 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index e0c17c2c5d223..71329985d6a35 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -426,12 +426,12 @@ def _worker_check(trainer: "pl.Trainer", dataloader: object, name: str) -> None: upper_bound = suggested_max_num_workers(trainer.num_devices) if dataloader.num_workers <= 2 < upper_bound: + # TODO # if changed, update the `filterwarnings` snippet in 'speed.html#num-workers' rank_zero_warn( - f"The dataloader, {name}, does not have many workers which may be a bottleneck." - " Consider increasing the value of the `num_workers` argument`" - f" (try {upper_bound} which is the number of cpus on this machine)" - " in the `DataLoader` init to improve performance.", + f"The dataloader, {name}, does not have many workers which may be a bottleneck. Consider increasing the" + f" value of the `num_workers` argument` to `num_workers={upper_bound}` in the `DataLoader` to improve " + " performance.", category=PossibleUserWarning, ) diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index a360b4601e4a2..91c990947d054 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -44,7 +44,7 @@ def suggested_max_num_workers(local_world_size: int) -> int: return max(1, cpu_count // local_world_size) -def _num_cpus_available(): +def _num_cpus_available() -> int: if hasattr(os, 'sched_getaffinity'): return len(os.sched_getaffinity(0)) diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index 7c9dc9126dc0e..6251fe61c2369 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -15,6 +15,7 @@ from io import StringIO from re import escape from typing import Sized +from unittest import mock from unittest.mock import Mock import pytest @@ -30,7 +31,7 @@ _check_dataloader_iterable, _DataHookSelector, _DataLoaderSource, - warning_cache, + warning_cache, _worker_check, ) from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.combined_loader import CombinedLoader @@ -106,44 +107,26 @@ def test_dataloader(self): trainer.test(model) -class TestSpawnBoringModel(BoringModel): - def __init__(self, num_workers): - super().__init__() - self.num_workers = num_workers - - def train_dataloader(self): - return DataLoader(RandomDataset(32, 64), num_workers=self.num_workers) - - def on_fit_start(self): - self._resout = StringIO() - self.ctx = redirect_stderr(self._resout) - self.ctx.__enter__() - - 
def on_train_end(self): - def _get_warning_msg(): - dl = self.trainer.train_dataloader - if hasattr(dl, "persistent_workers"): - if self.num_workers == 0: - warn_str = "Consider setting num_workers>0 and persistent_workers=True" - else: - warn_str = "Consider setting persistent_workers=True" - else: - warn_str = "Consider setting strategy=ddp" - - return warn_str - - if self.trainer.is_global_zero: - self.ctx.__exit__(None, None, None) - msg = self._resout.getvalue() - warn_str = _get_warning_msg() - assert warn_str in msg +@pytest.mark.parametrize("num_devices, num_workers, cpu_count, expected_warning", [ + (1, 0, 1, False), + (1, 0, 1, False), +]) +@mock.patch("lightning.pytorch.utilities.data.os.cpu_count") +def test_worker_check(cpu_count_mock, num_devices, num_workers, cpu_count, expected_warning, monkeypatch): + monkeypatch.delattr("lightning.pytorch.utilities.data.os", "sched_getaffinity", raising=False) + trainer = Mock(spec=Trainer) + dataloader = Mock(spec=DataLoader) + trainer.num_devices = num_devices + dataloader.num_workers = num_workers + cpu_count_mock.return_value = cpu_count + if expected_warning: + ctx = pytest.warns(UserWarning, match=f"Consider increasing the value of the `num_workers` argument`") + else: + ctx = no_warning_call(UserWarning) -@RunIf(skip_windows=True) -@pytest.mark.parametrize("num_workers", [0, 1]) -def test_dataloader_warnings(tmpdir, num_workers): - trainer = Trainer(default_root_dir=tmpdir, accelerator="cpu", devices=2, strategy="ddp_spawn", fast_dev_run=4) - trainer.fit(TestSpawnBoringModel(num_workers)) + with ctx: + _worker_check(trainer, dataloader, "train_dataloader") def test_update_dataloader_raises(): From b6dfaeaab860578b8b02297582e139c581d4db4a Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 19 Sep 2023 17:08:06 +0200 Subject: [PATCH 05/28] update --- .../pytorch/trainer/connectors/data_connector.py | 2 +- .../trainer/connectors/test_data_connector.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 71329985d6a35..8974c95d54216 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -425,7 +425,7 @@ def _worker_check(trainer: "pl.Trainer", dataloader: object, name: str) -> None: return upper_bound = suggested_max_num_workers(trainer.num_devices) - if dataloader.num_workers <= 2 < upper_bound: + if dataloader.num_workers <= 2 < upper_bound or dataloader.num_workers < 2 <= upper_bound: # TODO # if changed, update the `filterwarnings` snippet in 'speed.html#num-workers' rank_zero_warn( diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index 6251fe61c2369..b3c3398ce75b4 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -109,7 +109,16 @@ def test_dataloader(self): @pytest.mark.parametrize("num_devices, num_workers, cpu_count, expected_warning", [ (1, 0, 1, False), - (1, 0, 1, False), + (8, 0, 1, False), + (8, 0, None, False), + (1, 1, None, False), + (1, 2, 2, False), + (1, 1, 8, True), + (1, 2, 8, True), + (1, 3, 8, False), + (4, 1, 8, True), + (4, 2, 8, False), + (8, 2, 8, False), ]) @mock.patch("lightning.pytorch.utilities.data.os.cpu_count") def test_worker_check(cpu_count_mock, num_devices, num_workers, cpu_count, 
expected_warning, monkeypatch): From 4dff11373f451a99fe3a52572f96df47a6750a3e Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 19 Sep 2023 17:27:55 +0200 Subject: [PATCH 06/28] test --- src/lightning/pytorch/utilities/data.py | 2 + .../trainer/connectors/test_data_connector.py | 32 ++++++------ tests/tests_pytorch/utilities/test_data.py | 49 +++++++++++++++++++ 3 files changed, 69 insertions(+), 14 deletions(-) diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index 91c990947d054..9e362fd72f5a4 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -40,6 +40,8 @@ def suggested_max_num_workers(local_world_size: int) -> int: + if local_world_size < 1: + raise ValueError(f"`local_world_size` should be >= 1, got {local_world_size}.") cpu_count = _num_cpus_available() return max(1, cpu_count // local_world_size) diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index b3c3398ce75b4..210bd553c8930 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -31,7 +31,8 @@ _check_dataloader_iterable, _DataHookSelector, _DataLoaderSource, - warning_cache, _worker_check, + warning_cache, + _worker_check, ) from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.combined_loader import CombinedLoader @@ -107,19 +108,22 @@ def test_dataloader(self): trainer.test(model) -@pytest.mark.parametrize("num_devices, num_workers, cpu_count, expected_warning", [ - (1, 0, 1, False), - (8, 0, 1, False), - (8, 0, None, False), - (1, 1, None, False), - (1, 2, 2, False), - (1, 1, 8, True), - (1, 2, 8, True), - (1, 3, 8, False), - (4, 1, 8, True), - (4, 2, 8, False), - (8, 2, 8, False), -]) +@pytest.mark.parametrize( + "num_devices, num_workers, cpu_count, expected_warning", + [ + (1, 0, 1, False), + (8, 0, 1, False), + (8, 0, None, False), + (1, 1, None, False), + (1, 2, 2, False), + (1, 1, 8, True), + (1, 2, 8, True), + (1, 3, 8, False), + (4, 1, 8, True), + (4, 2, 8, False), + (8, 2, 8, False), + ], +) @mock.patch("lightning.pytorch.utilities.data.os.cpu_count") def test_worker_check(cpu_count_mock, num_devices, num_workers, cpu_count, expected_warning, monkeypatch): monkeypatch.delattr("lightning.pytorch.utilities.data.os", "sched_getaffinity", raising=False) diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index e62a43ee14e8e..ea959a7a24625 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -1,4 +1,6 @@ +import os from dataclasses import dataclass +from unittest import mock from unittest.mock import Mock import numpy as np @@ -19,6 +21,7 @@ extract_batch_size, has_len_all_ranks, warning_cache, + suggested_max_num_workers, ) from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -326,3 +329,49 @@ def __init__(self, indices=None, **kwargs): dataloader = ArrayAttributeDataloader(dataset) dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler) assert dl_kwargs["indices"] is dataloader.indices + + +@pytest.mark.parametrize( + "cpu_count, local_world_size, expected", + [ + (0, 1, 1), + (1, 1, 1), + (2, 1, 2), + (1, 2, 1), + (1, 2, 1), + (2, 2, 1), + (3, 2, 1), + (4, 2, 2), + (4, 3, 1), + (4, 1, 4), + ], +) +@pytest.mark.parametrize( + 
"affinity", + [ + False, + pytest.param( + True, + marks=pytest.mark.skipif( + not hasattr(os, "sched_getaffinity"), reason="OS does not support restricting CPU cores" + ), + ), + ], +) +@mock.patch("lightning.pytorch.utilities.data.os.cpu_count") +def test_suggested_max_num_workers(cpu_count_mock, affinity, cpu_count, local_world_size, expected, monkeypatch): + if affinity: + monkeypatch.setattr( + "lightning.pytorch.utilities.data.os", "sched_getaffinity", lambda _: list(range(cpu_count)) + ) + else: + monkeypatch.delattr("lightning.pytorch.utilities.data.os", "sched_getaffinity", raising=False) + cpu_count_mock.return_value = cpu_count + + assert suggested_max_num_workers(local_world_size) == expected + + +@pytest.mark.parametrize("invalid", [-1, 0]) +def test_suggested_max_num_workers_input_validation(invalid): + with pytest.raises(ValueError, match="should be >= 1"): + suggested_max_num_workers(invalid) From 63628f43effdf86f7dedd33af19ea07528c9f74a Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 19 Sep 2023 17:32:31 +0200 Subject: [PATCH 07/28] format --- src/lightning/pytorch/trainer/connectors/data_connector.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 8974c95d54216..a6c29b0f16c89 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -429,9 +429,8 @@ def _worker_check(trainer: "pl.Trainer", dataloader: object, name: str) -> None: # TODO # if changed, update the `filterwarnings` snippet in 'speed.html#num-workers' rank_zero_warn( - f"The dataloader, {name}, does not have many workers which may be a bottleneck. Consider increasing the" - f" value of the `num_workers` argument` to `num_workers={upper_bound}` in the `DataLoader` to improve " - " performance.", + f"The '{name}', does not have many workers which may be a bottleneck. 
Consider increasing the value of the" + f" `num_workers` argument` to `num_workers={upper_bound}` in the `DataLoader` to improve performance.", category=PossibleUserWarning, ) From 84f6dea36c84f69e0d51c372551b622898f48457 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 19 Sep 2023 17:36:55 +0200 Subject: [PATCH 08/28] docs --- src/lightning/pytorch/utilities/__init__.py | 1 + src/lightning/pytorch/utilities/data.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/src/lightning/pytorch/utilities/__init__.py b/src/lightning/pytorch/utilities/__init__.py index e2e0c0a8d941e..645fd9545ebc7 100644 --- a/src/lightning/pytorch/utilities/__init__.py +++ b/src/lightning/pytorch/utilities/__init__.py @@ -18,6 +18,7 @@ from lightning.fabric.utilities import LightningEnum # noqa: F401 from lightning.fabric.utilities import move_data_to_device # noqa: F401 from lightning.pytorch.utilities.combined_loader import CombinedLoader # noqa: F401 +from lightning.pytorch.utilities.data import suggested_max_num_workers # noqa: F401 from lightning.pytorch.utilities.enums import GradClipAlgorithmType # noqa: F401 from lightning.pytorch.utilities.grads import grad_norm # noqa: F401 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE, _TORCHVISION_AVAILABLE # noqa: F401 diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index 9e362fd72f5a4..d16d71957f3f4 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -40,6 +40,13 @@ def suggested_max_num_workers(local_world_size: int) -> int: + """Suggests an upper bound of num_workers to use in a PyTorch :class:`~torch.utils.data.DataLoader` based on + the number of CPU cores available on the system and the number of distributed processes in the current machine. + + Args: + local_world_size: The number of distributed processes running on the current machine. Set this to the number + of devices configured in the Trainer (``trainer.num_devices``). + """ if local_world_size < 1: raise ValueError(f"`local_world_size` should be >= 1, got {local_world_size}.") cpu_count = _num_cpus_available() From 6daa8f311ebd87b637b23dab18e3fef606610835 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 19 Sep 2023 17:37:46 +0200 Subject: [PATCH 09/28] x --- src/lightning/pytorch/utilities/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index d16d71957f3f4..e9bdbc085f25b 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -40,7 +40,7 @@ def suggested_max_num_workers(local_world_size: int) -> int: - """Suggests an upper bound of num_workers to use in a PyTorch :class:`~torch.utils.data.DataLoader` based on + """Suggests an upper bound of ``num_workers`` to use in a PyTorch :class:`~torch.utils.data.DataLoader` based on the number of CPU cores available on the system and the number of distributed processes in the current machine. 
Args: From d745a76cbe56d4015c8023079dbdbbf02f90f677 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Sep 2023 15:45:16 +0000 Subject: [PATCH 10/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/trainer/connectors/data_connector.py | 2 -- src/lightning/pytorch/utilities/data.py | 2 +- .../trainer/connectors/test_data_connector.py | 8 +++----- tests/tests_pytorch/utilities/test_data.py | 4 ++-- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index a6c29b0f16c89..688c8e3f28e7e 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import multiprocessing import os from dataclasses import dataclass, field from typing import Any, Iterable, Optional, Tuple, Union @@ -28,7 +27,6 @@ ) from lightning.fabric.utilities.distributed import DistributedSamplerWrapper from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSamplerWrapper -from lightning.pytorch.strategies import DDPStrategy from lightning.pytorch.trainer import call from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.combined_loader import CombinedLoader diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index e9bdbc085f25b..e992b6f327aad 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -54,7 +54,7 @@ def suggested_max_num_workers(local_world_size: int) -> int: def _num_cpus_available() -> int: - if hasattr(os, 'sched_getaffinity'): + if hasattr(os, "sched_getaffinity"): return len(os.sched_getaffinity(0)) cpu_count = os.cpu_count() diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index 210bd553c8930..70533d9404097 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from contextlib import redirect_stderr -from io import StringIO from re import escape from typing import Sized from unittest import mock @@ -31,8 +29,8 @@ _check_dataloader_iterable, _DataHookSelector, _DataLoaderSource, - warning_cache, _worker_check, + warning_cache, ) from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.combined_loader import CombinedLoader @@ -109,7 +107,7 @@ def test_dataloader(self): @pytest.mark.parametrize( - "num_devices, num_workers, cpu_count, expected_warning", + ("num_devices", "num_workers", "cpu_count", "expected_warning"), [ (1, 0, 1, False), (8, 0, 1, False), @@ -134,7 +132,7 @@ def test_worker_check(cpu_count_mock, num_devices, num_workers, cpu_count, expec cpu_count_mock.return_value = cpu_count if expected_warning: - ctx = pytest.warns(UserWarning, match=f"Consider increasing the value of the `num_workers` argument`") + ctx = pytest.warns(UserWarning, match="Consider increasing the value of the `num_workers` argument`") else: ctx = no_warning_call(UserWarning) diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index ea959a7a24625..f5735b5dce07c 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -20,8 +20,8 @@ _update_dataloader, extract_batch_size, has_len_all_ranks, - warning_cache, suggested_max_num_workers, + warning_cache, ) from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -332,7 +332,7 @@ def __init__(self, indices=None, **kwargs): @pytest.mark.parametrize( - "cpu_count, local_world_size, expected", + ("cpu_count", "local_world_size", "expected"), [ (0, 1, 1), (1, 1, 1), From 120c3538be24f99a2f42ee4cdf6b71fc059da905 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 19 Sep 2023 17:51:34 +0200 Subject: [PATCH 11/28] changelog --- src/lightning/pytorch/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index e2fbd3e63d849..d1f397467b8ea 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -124,6 +124,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added support for mixed 8-bit precision as `Trainer(precision="transformer-engine")` using [Nvidia's Transformer Engine](https://docs.nvidia.com/deeplearning/transformer-engine) ([#18459](https://github.com/Lightning-AI/lightning/pull/18459)) + +- Added `lightning.pytorch.utilities.suggested_max_num_workers` to assist with setting a good value in distributed settings ([#18591](https://github.com/Lightning-AI/lightning/pull/18591)) + + +- Improved the `num_workers` warning to give a more accurate upper limit on the `num_workers` suggestion ([#18591](https://github.com/Lightning-AI/lightning/pull/18591)) + + ### Changed - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309)) From c1dccaee2b025d65a516989c50e81e9826d55845 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 19 Sep 2023 17:54:11 +0200 Subject: [PATCH 12/28] circular import --- src/lightning/pytorch/utilities/data.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index e992b6f327aad..3ab60f6e24d00 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -14,7 +14,7 @@ import inspect import os from dataclasses import fields -from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union +from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union, TYPE_CHECKING import torch from lightning_utilities.core.apply_func import is_dataclass_instance @@ -30,10 +30,12 @@ sized_len, ) from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper -from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache +if TYPE_CHECKING: + from lightning.pytorch.trainer.states import RunningStage + BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]] warning_cache = WarningCache() @@ -150,7 +152,7 @@ def has_len_all_ranks( def _update_dataloader( - dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None + dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional["RunningStage"] = None ) -> DataLoader: dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode) return _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs) @@ -159,7 +161,7 @@ def _update_dataloader( def _get_dataloader_init_args_and_kwargs( dataloader: DataLoader, sampler: Union[Sampler, Iterable], - mode: Optional[RunningStage] = None, + mode: Optional["RunningStage"] = None, ) -> Tuple[Tuple[Any], Dict[str, Any]]: if not isinstance(dataloader, DataLoader): raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") @@ -253,7 +255,7 @@ def _get_dataloader_init_args_and_kwargs( def _dataloader_init_kwargs_resolve_sampler( dataloader: DataLoader, sampler: Union[Sampler, Iterable], - mode: Optional[RunningStage] = None, + mode: Optional["RunningStage"] = None, ) -> Dict[str, Any]: """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re- instantiation. @@ -262,6 +264,8 @@ def _dataloader_init_kwargs_resolve_sampler( Lightning can keep track of its indices. 
""" + from lightning.pytorch.trainer.states import RunningStage + is_predicting = mode == RunningStage.PREDICTING batch_sampler = getattr(dataloader, "batch_sampler") batch_sampler_cls = type(batch_sampler) From dca5162d66266f1070e0ccc1eefe0a10c86827de Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 19 Sep 2023 17:54:11 +0200 Subject: [PATCH 13/28] Revert "circular import" This reverts commit c1dccaee2b025d65a516989c50e81e9826d55845. --- src/lightning/pytorch/utilities/data.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index 3ab60f6e24d00..e992b6f327aad 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -14,7 +14,7 @@ import inspect import os from dataclasses import fields -from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union, TYPE_CHECKING +from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union import torch from lightning_utilities.core.apply_func import is_dataclass_instance @@ -30,12 +30,10 @@ sized_len, ) from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper +from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache -if TYPE_CHECKING: - from lightning.pytorch.trainer.states import RunningStage - BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]] warning_cache = WarningCache() @@ -152,7 +150,7 @@ def has_len_all_ranks( def _update_dataloader( - dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional["RunningStage"] = None + dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None ) -> DataLoader: dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode) return _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs) @@ -161,7 +159,7 @@ def _update_dataloader( def _get_dataloader_init_args_and_kwargs( dataloader: DataLoader, sampler: Union[Sampler, Iterable], - mode: Optional["RunningStage"] = None, + mode: Optional[RunningStage] = None, ) -> Tuple[Tuple[Any], Dict[str, Any]]: if not isinstance(dataloader, DataLoader): raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") @@ -255,7 +253,7 @@ def _get_dataloader_init_args_and_kwargs( def _dataloader_init_kwargs_resolve_sampler( dataloader: DataLoader, sampler: Union[Sampler, Iterable], - mode: Optional["RunningStage"] = None, + mode: Optional[RunningStage] = None, ) -> Dict[str, Any]: """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re- instantiation. @@ -264,8 +262,6 @@ def _dataloader_init_kwargs_resolve_sampler( Lightning can keep track of its indices. 
""" - from lightning.pytorch.trainer.states import RunningStage - is_predicting = mode == RunningStage.PREDICTING batch_sampler = getattr(dataloader, "batch_sampler") batch_sampler_cls = type(batch_sampler) From 79267e45d0426ceb5952119174783611ff3a5a29 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 19 Sep 2023 17:59:12 +0200 Subject: [PATCH 14/28] refactor --- src/lightning/fabric/utilities/__init__.py | 1 + src/lightning/fabric/utilities/data.py | 22 ++++++++ .../trainer/connectors/data_connector.py | 3 +- src/lightning/pytorch/utilities/__init__.py | 2 +- src/lightning/pytorch/utilities/data.py | 23 --------- tests/tests_fabric/utilities/test_data.py | 50 ++++++++++++++++++- .../trainer/connectors/test_data_connector.py | 4 +- tests/tests_pytorch/utilities/test_data.py | 46 ----------------- 8 files changed, 77 insertions(+), 74 deletions(-) diff --git a/src/lightning/fabric/utilities/__init__.py b/src/lightning/fabric/utilities/__init__.py index 53bd2a4526612..c706d64463189 100644 --- a/src/lightning/fabric/utilities/__init__.py +++ b/src/lightning/fabric/utilities/__init__.py @@ -14,6 +14,7 @@ """General utilities.""" from lightning.fabric.utilities.apply_func import move_data_to_device # noqa: F401 +from lightning.fabric.utilities.data import suggested_max_num_workers # noqa: F401 from lightning.fabric.utilities.enums import LightningEnum # noqa: F401 from lightning.fabric.utilities.rank_zero import ( # noqa: F401 rank_zero_deprecation, diff --git a/src/lightning/fabric/utilities/data.py b/src/lightning/fabric/utilities/data.py index 7b3af926cfda0..ff37767623322 100644 --- a/src/lightning/fabric/utilities/data.py +++ b/src/lightning/fabric/utilities/data.py @@ -433,3 +433,25 @@ def _set_sampler_epoch(dataloader: object, epoch: int) -> None: set_epoch = getattr(obj, "set_epoch", None) if callable(set_epoch): set_epoch(epoch) + + +def suggested_max_num_workers(local_world_size: int) -> int: + """Suggests an upper bound of ``num_workers`` to use in a PyTorch :class:`~torch.utils.data.DataLoader` based on + the number of CPU cores available on the system and the number of distributed processes in the current machine. + + Args: + local_world_size: The number of distributed processes running on the current machine. Set this to the number + of devices configured in the Trainer (``trainer.num_devices``). 
+ """ + if local_world_size < 1: + raise ValueError(f"`local_world_size` should be >= 1, got {local_world_size}.") + cpu_count = _num_cpus_available() + return max(1, cpu_count // local_world_size) + + +def _num_cpus_available() -> int: + if hasattr(os, "sched_getaffinity"): + return len(os.sched_getaffinity(0)) + + cpu_count = os.cpu_count() + return 1 if cpu_count is None else cpu_count diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 688c8e3f28e7e..4a8ea2efe30b9 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -24,13 +24,14 @@ _replace_dunder_methods, _set_sampler_epoch, has_iterable_dataset, + suggested_max_num_workers, ) from lightning.fabric.utilities.distributed import DistributedSamplerWrapper from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSamplerWrapper from lightning.pytorch.trainer import call from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.combined_loader import CombinedLoader -from lightning.pytorch.utilities.data import _is_dataloader_shuffled, _update_dataloader, suggested_max_num_workers +from lightning.pytorch.utilities.data import _is_dataloader_shuffled, _update_dataloader from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _lightning_graphcore_available from lightning.pytorch.utilities.model_helpers import is_overridden diff --git a/src/lightning/pytorch/utilities/__init__.py b/src/lightning/pytorch/utilities/__init__.py index 645fd9545ebc7..699120da7ea36 100644 --- a/src/lightning/pytorch/utilities/__init__.py +++ b/src/lightning/pytorch/utilities/__init__.py @@ -17,8 +17,8 @@ from lightning.fabric.utilities import LightningEnum # noqa: F401 from lightning.fabric.utilities import move_data_to_device # noqa: F401 +from lightning.fabric.utilities import suggested_max_num_workers # noqa: F401 from lightning.pytorch.utilities.combined_loader import CombinedLoader # noqa: F401 -from lightning.pytorch.utilities.data import suggested_max_num_workers # noqa: F401 from lightning.pytorch.utilities.enums import GradClipAlgorithmType # noqa: F401 from lightning.pytorch.utilities.grads import grad_norm # noqa: F401 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE, _TORCHVISION_AVAILABLE # noqa: F401 diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index e992b6f327aad..b026badb65c8a 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -import os from dataclasses import fields from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union @@ -39,28 +38,6 @@ warning_cache = WarningCache() -def suggested_max_num_workers(local_world_size: int) -> int: - """Suggests an upper bound of ``num_workers`` to use in a PyTorch :class:`~torch.utils.data.DataLoader` based on - the number of CPU cores available on the system and the number of distributed processes in the current machine. - - Args: - local_world_size: The number of distributed processes running on the current machine. Set this to the number - of devices configured in the Trainer (``trainer.num_devices``). 
- """ - if local_world_size < 1: - raise ValueError(f"`local_world_size` should be >= 1, got {local_world_size}.") - cpu_count = _num_cpus_available() - return max(1, cpu_count // local_world_size) - - -def _num_cpus_available() -> int: - if hasattr(os, "sched_getaffinity"): - return len(os.sched_getaffinity(0)) - - cpu_count = os.cpu_count() - return 1 if cpu_count is None else cpu_count - - def _extract_batch_size(batch: BType) -> Generator[Optional[int], None, None]: if isinstance(batch, Tensor): if batch.ndim == 0: diff --git a/tests/tests_fabric/utilities/test_data.py b/tests/tests_fabric/utilities/test_data.py index 072d7a545677a..70fc6905cf563 100644 --- a/tests/tests_fabric/utilities/test_data.py +++ b/tests/tests_fabric/utilities/test_data.py @@ -1,5 +1,7 @@ import contextlib +import os import random +from unittest import mock from unittest.mock import Mock import numpy as np @@ -16,7 +18,7 @@ _update_dataloader, _WrapAttrTag, has_iterable_dataset, - has_len, + has_len, suggested_max_num_workers, ) from lightning.fabric.utilities.exceptions import MisconfigurationException from tests_fabric.helpers.models import RandomDataset, RandomIterableDataset @@ -575,3 +577,49 @@ def test_set_sampler_epoch(): _set_sampler_epoch(dataloader, 55) dataloader.sampler.set_epoch.assert_called_once_with(55) dataloader.batch_sampler.sampler.set_epoch.assert_called_once_with(55) + + +@pytest.mark.parametrize( + ("cpu_count", "local_world_size", "expected"), + [ + (0, 1, 1), + (1, 1, 1), + (2, 1, 2), + (1, 2, 1), + (1, 2, 1), + (2, 2, 1), + (3, 2, 1), + (4, 2, 2), + (4, 3, 1), + (4, 1, 4), + ], +) +@pytest.mark.parametrize( + "affinity", + [ + False, + pytest.param( + True, + marks=pytest.mark.skipif( + not hasattr(os, "sched_getaffinity"), reason="OS does not support restricting CPU cores" + ), + ), + ], +) +@mock.patch("lightning.fabric.utilities.data.os.cpu_count") +def test_suggested_max_num_workers(cpu_count_mock, affinity, cpu_count, local_world_size, expected, monkeypatch): + if affinity: + monkeypatch.setattr( + "lightning.fabric.utilities.data.os", "sched_getaffinity", lambda _: list(range(cpu_count)) + ) + else: + monkeypatch.delattr("lightning.pytorch.utilities.data.os", "sched_getaffinity", raising=False) + cpu_count_mock.return_value = cpu_count + + assert suggested_max_num_workers(local_world_size) == expected + + +@pytest.mark.parametrize("invalid", [-1, 0]) +def test_suggested_max_num_workers_input_validation(invalid): + with pytest.raises(ValueError, match="should be >= 1"): + suggested_max_num_workers(invalid) diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index 70533d9404097..7100a45f231e6 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -122,9 +122,9 @@ def test_dataloader(self): (8, 2, 8, False), ], ) -@mock.patch("lightning.pytorch.utilities.data.os.cpu_count") +@mock.patch("lightning.fabric.utilities.data.os.cpu_count") def test_worker_check(cpu_count_mock, num_devices, num_workers, cpu_count, expected_warning, monkeypatch): - monkeypatch.delattr("lightning.pytorch.utilities.data.os", "sched_getaffinity", raising=False) + monkeypatch.delattr("lightning.fabric.utilities.data.os", "sched_getaffinity", raising=False) trainer = Mock(spec=Trainer) dataloader = Mock(spec=DataLoader) trainer.num_devices = num_devices diff --git a/tests/tests_pytorch/utilities/test_data.py 
b/tests/tests_pytorch/utilities/test_data.py index f5735b5dce07c..e4acf5c84ef24 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -329,49 +329,3 @@ def __init__(self, indices=None, **kwargs): dataloader = ArrayAttributeDataloader(dataset) dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler) assert dl_kwargs["indices"] is dataloader.indices - - -@pytest.mark.parametrize( - ("cpu_count", "local_world_size", "expected"), - [ - (0, 1, 1), - (1, 1, 1), - (2, 1, 2), - (1, 2, 1), - (1, 2, 1), - (2, 2, 1), - (3, 2, 1), - (4, 2, 2), - (4, 3, 1), - (4, 1, 4), - ], -) -@pytest.mark.parametrize( - "affinity", - [ - False, - pytest.param( - True, - marks=pytest.mark.skipif( - not hasattr(os, "sched_getaffinity"), reason="OS does not support restricting CPU cores" - ), - ), - ], -) -@mock.patch("lightning.pytorch.utilities.data.os.cpu_count") -def test_suggested_max_num_workers(cpu_count_mock, affinity, cpu_count, local_world_size, expected, monkeypatch): - if affinity: - monkeypatch.setattr( - "lightning.pytorch.utilities.data.os", "sched_getaffinity", lambda _: list(range(cpu_count)) - ) - else: - monkeypatch.delattr("lightning.pytorch.utilities.data.os", "sched_getaffinity", raising=False) - cpu_count_mock.return_value = cpu_count - - assert suggested_max_num_workers(local_world_size) == expected - - -@pytest.mark.parametrize("invalid", [-1, 0]) -def test_suggested_max_num_workers_input_validation(invalid): - with pytest.raises(ValueError, match="should be >= 1"): - suggested_max_num_workers(invalid) From 9228d647525e4e4c7dba6cb7d6696c5814b16af7 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 19 Sep 2023 18:02:34 +0200 Subject: [PATCH 15/28] update --- src/lightning/fabric/utilities/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/fabric/utilities/data.py b/src/lightning/fabric/utilities/data.py index ff37767623322..f10b2ddf98d12 100644 --- a/src/lightning/fabric/utilities/data.py +++ b/src/lightning/fabric/utilities/data.py @@ -441,7 +441,7 @@ def suggested_max_num_workers(local_world_size: int) -> int: Args: local_world_size: The number of distributed processes running on the current machine. Set this to the number - of devices configured in the Trainer (``trainer.num_devices``). + of devices configured in the Fabric/Trainer. """ if local_world_size < 1: raise ValueError(f"`local_world_size` should be >= 1, got {local_world_size}.") From 8055027340b37ab5209253f577362f039753e7d3 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 19 Sep 2023 18:03:04 +0200 Subject: [PATCH 16/28] chlog --- src/lightning/fabric/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 24babb09ce616..5964d3f667473 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -124,6 +124,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added support for saving and loading stateful objects other than modules and optimizers ([#18513](https://github.com/Lightning-AI/lightning/pull/18513)) +- Added `lightning.fabrioc.utilities.suggested_max_num_workers` to assist with setting a good value in distributed settings ([#18591](https://github.com/Lightning-AI/lightning/pull/18591)) + + ### Changed - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331)) From 22c1d2522e59b66e6357e2dc6dcfbcebb6464574 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Sep 2023 16:03:23 +0000 Subject: [PATCH 17/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_fabric/utilities/test_data.py | 7 +++---- tests/tests_pytorch/utilities/test_data.py | 3 --- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/tests_fabric/utilities/test_data.py b/tests/tests_fabric/utilities/test_data.py index 70fc6905cf563..24260e84da876 100644 --- a/tests/tests_fabric/utilities/test_data.py +++ b/tests/tests_fabric/utilities/test_data.py @@ -18,7 +18,8 @@ _update_dataloader, _WrapAttrTag, has_iterable_dataset, - has_len, suggested_max_num_workers, + has_len, + suggested_max_num_workers, ) from lightning.fabric.utilities.exceptions import MisconfigurationException from tests_fabric.helpers.models import RandomDataset, RandomIterableDataset @@ -609,9 +610,7 @@ def test_set_sampler_epoch(): @mock.patch("lightning.fabric.utilities.data.os.cpu_count") def test_suggested_max_num_workers(cpu_count_mock, affinity, cpu_count, local_world_size, expected, monkeypatch): if affinity: - monkeypatch.setattr( - "lightning.fabric.utilities.data.os", "sched_getaffinity", lambda _: list(range(cpu_count)) - ) + monkeypatch.setattr("lightning.fabric.utilities.data.os", "sched_getaffinity", lambda _: list(range(cpu_count))) else: monkeypatch.delattr("lightning.pytorch.utilities.data.os", "sched_getaffinity", raising=False) cpu_count_mock.return_value = cpu_count diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index e4acf5c84ef24..e62a43ee14e8e 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -1,6 +1,4 @@ -import os from dataclasses import dataclass -from unittest import mock from unittest.mock import Mock import numpy as np @@ -20,7 +18,6 @@ _update_dataloader, extract_batch_size, has_len_all_ranks, - suggested_max_num_workers, warning_cache, ) from lightning.pytorch.utilities.exceptions import MisconfigurationException From 388e25a4171b7cf9c5a29fab842680fa8990333b Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 19 Sep 2023 19:57:02 +0200 Subject: [PATCH 18/28] fix test --- src/lightning/pytorch/trainer/connectors/data_connector.py | 2 +- tests/tests_pytorch/trainer/test_dataloaders.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 4a8ea2efe30b9..2163652a8a461 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -428,7 +428,7 @@ def _worker_check(trainer: "pl.Trainer", dataloader: object, name: str) -> None: # TODO # if changed, update the `filterwarnings` snippet in 'speed.html#num-workers' rank_zero_warn( - f"The '{name}', does not 
have many workers which may be a bottleneck. Consider increasing the value of the" + f"The '{name}' does not have many workers which may be a bottleneck. Consider increasing the value of the" f" `num_workers` argument` to `num_workers={upper_bound}` in the `DataLoader` to improve performance.", category=PossibleUserWarning, ) diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py index a0cd3fdf74cdb..bbd5f39d2ecff 100644 --- a/tests/tests_pytorch/trainer/test_dataloaders.py +++ b/tests/tests_pytorch/trainer/test_dataloaders.py @@ -561,9 +561,9 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage): @RunIf(skip_windows=True) @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) @pytest.mark.parametrize("stage", ["train", "test", "val"]) -@patch("lightning.pytorch.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4) +@patch("lightning.fabric.utilities.data._num_cpus_available", return_value=4) def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): - """Test that error is raised if dataloader with only a few workers is used.""" + """Test that a warning is emitted if the dataloader only has a few workers.""" class CustomModel(MultiEvalDataLoaderModel): def training_step(self, batch, batch_idx): @@ -584,7 +584,7 @@ def training_step(self, batch, batch_idx): with pytest.warns( UserWarning, - match=f"The dataloader, {stage}_dataloader, does not have many workers", + match=f"The '{stage}_dataloader' does not have many workers", ): if stage == "test": if ckpt_path in ("specific", "best"): From d35cb676410e0220f39e897347835b47e233fdd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Sep 2023 15:53:15 -0400 Subject: [PATCH 19/28] Update src/lightning/fabric/CHANGELOG.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- src/lightning/fabric/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 5964d3f667473..351a713fbf150 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -124,7 +124,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added support for saving and loading stateful objects other than modules and optimizers ([#18513](https://github.com/Lightning-AI/lightning/pull/18513)) -- Added `lightning.fabrioc.utilities.suggested_max_num_workers` to assist with setting a good value in distributed settings ([#18591](https://github.com/Lightning-AI/lightning/pull/18591)) +- Added `lightning.fabric.utilities.suggested_max_num_workers` to assist with setting a good value in distributed settings ([#18591](https://github.com/Lightning-AI/lightning/pull/18591)) ### Changed From 47258a82b961bc2c4063748b38ee70902b75b3b4 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 21 Sep 2023 14:00:34 +0200 Subject: [PATCH 20/28] keep the spawn warnings for now, will remove in separate PR --- .../trainer/connectors/data_connector.py | 30 +++++++++++-- .../trainer/connectors/test_data_connector.py | 42 +++++++++++++++++++ 2 files changed, 68 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 2163652a8a461..3d3fe2769c370 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -28,6 +28,7 @@ ) from lightning.fabric.utilities.distributed import DistributedSamplerWrapper from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSamplerWrapper +from lightning.pytorch.strategies import DDPStrategy from lightning.pytorch.trainer import call from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.combined_loader import CombinedLoader @@ -419,13 +420,29 @@ def _check_dataloader_iterable( ) -def _worker_check(trainer: "pl.Trainer", dataloader: object, name: str) -> None: +def _worker_check(trainer: "pl.Trainer", using_spawn: bool, dataloader: object, name: str) -> None: if not isinstance(dataloader, DataLoader): return upper_bound = suggested_max_num_workers(trainer.num_devices) - if dataloader.num_workers <= 2 < upper_bound or dataloader.num_workers < 2 <= upper_bound: - # TODO + + # ddp_spawn + num_workers > 0 don't mix! tell the user + if dataloader.num_workers > 0 and using_spawn: + if not dataloader.persistent_workers: + rank_zero_warn( + "num_workers>0, persistent_workers=False, and strategy=ddp_spawn" + " may result in data loading bottlenecks." + " Consider setting persistent_workers=True" + " (this is a limitation of Python .spawn() and PyTorch)" + ) + + elif dataloader.num_workers == 0 and using_spawn: + if not dataloader.persistent_workers: + rank_zero_warn( + "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks." + " Consider setting num_workers>0 and persistent_workers=True" + ) + elif dataloader.num_workers <= 2 < upper_bound or dataloader.num_workers < 2 <= upper_bound: # if changed, update the `filterwarnings` snippet in 'speed.html#num-workers' rank_zero_warn( f"The '{name}' does not have many workers which may be a bottleneck. 
Consider increasing the value of the" @@ -486,7 +503,12 @@ def _process_dataloader( dataloader = strategy.process_dataloader(dataloader) # check the workers - _worker_check(trainer, dataloader, f"{stage.dataloader_prefix}_dataloader") + _worker_check( + trainer=trainer, + using_spawn=isinstance(strategy, DDPStrategy) and strategy._start_method == "spawn", + dataloader=dataloader, + name=f"{stage.dataloader_prefix}_dataloader", + ) # add worker_init_fn for correct seeding in worker processes _auto_add_worker_init_fn(dataloader, trainer.global_rank) diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index 7100a45f231e6..e9a87b1e58074 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import redirect_stderr +from io import StringIO from re import escape from typing import Sized from unittest import mock @@ -106,6 +108,46 @@ def test_dataloader(self): trainer.test(model) +class TestSpawnBoringModel(BoringModel): + def __init__(self, num_workers): + super().__init__() + self.num_workers = num_workers + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64), num_workers=self.num_workers) + + def on_fit_start(self): + self._resout = StringIO() + self.ctx = redirect_stderr(self._resout) + self.ctx.__enter__() + + def on_train_end(self): + def _get_warning_msg(): + dl = self.trainer.train_dataloader + if hasattr(dl, "persistent_workers"): + if self.num_workers == 0: + warn_str = "Consider setting num_workers>0 and persistent_workers=True" + else: + warn_str = "Consider setting persistent_workers=True" + else: + warn_str = "Consider setting strategy=ddp" + + return warn_str + + if self.trainer.is_global_zero: + self.ctx.__exit__(None, None, None) + msg = self._resout.getvalue() + warn_str = _get_warning_msg() + assert warn_str in msg + + +@RunIf(skip_windows=True) +@pytest.mark.parametrize("num_workers", [0, 1]) +def test_dataloader_warnings(tmpdir, num_workers): + trainer = Trainer(default_root_dir=tmpdir, accelerator="cpu", devices=2, strategy="ddp_spawn", fast_dev_run=4) + trainer.fit(TestSpawnBoringModel(num_workers)) + + @pytest.mark.parametrize( ("num_devices", "num_workers", "cpu_count", "expected_warning"), [ From 747b62b0387c7c439e30fb437e70883388a57241 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 21 Sep 2023 14:02:23 +0200 Subject: [PATCH 21/28] typo --- src/lightning/fabric/utilities/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/fabric/utilities/data.py b/src/lightning/fabric/utilities/data.py index f10b2ddf98d12..60e666cc4e981 100644 --- a/src/lightning/fabric/utilities/data.py +++ b/src/lightning/fabric/utilities/data.py @@ -441,7 +441,7 @@ def suggested_max_num_workers(local_world_size: int) -> int: Args: local_world_size: The number of distributed processes running on the current machine. Set this to the number - of devices configured in the Fabric/Trainer. + of devices configured in Fabric/Trainer. 
""" if local_world_size < 1: raise ValueError(f"`local_world_size` should be >= 1, got {local_world_size}.") From bbf7782cca5316690828080fd6d8ae8ce7e0d0e1 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 21 Sep 2023 14:16:57 +0200 Subject: [PATCH 22/28] add requested test --- tests/tests_fabric/utilities/test_data.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/utilities/test_data.py b/tests/tests_fabric/utilities/test_data.py index 24260e84da876..2b1865eef330e 100644 --- a/tests/tests_fabric/utilities/test_data.py +++ b/tests/tests_fabric/utilities/test_data.py @@ -9,7 +9,9 @@ import torch from torch import Tensor from torch.utils.data import BatchSampler, DataLoader, RandomSampler +from lightning_utilities.test.warning import no_warning_call +import lightning.fabric from lightning.fabric.utilities.data import ( _get_dataloader_init_args_and_kwargs, _replace_dunder_methods, @@ -610,9 +612,9 @@ def test_set_sampler_epoch(): @mock.patch("lightning.fabric.utilities.data.os.cpu_count") def test_suggested_max_num_workers(cpu_count_mock, affinity, cpu_count, local_world_size, expected, monkeypatch): if affinity: - monkeypatch.setattr("lightning.fabric.utilities.data.os", "sched_getaffinity", lambda _: list(range(cpu_count))) + monkeypatch.setattr(lightning.fabric.utilities.data.os, "sched_getaffinity", lambda _: list(range(cpu_count))) else: - monkeypatch.delattr("lightning.pytorch.utilities.data.os", "sched_getaffinity", raising=False) + monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False) cpu_count_mock.return_value = cpu_count assert suggested_max_num_workers(local_world_size) == expected @@ -622,3 +624,15 @@ def test_suggested_max_num_workers(cpu_count_mock, affinity, cpu_count, local_wo def test_suggested_max_num_workers_input_validation(invalid): with pytest.raises(ValueError, match="should be >= 1"): suggested_max_num_workers(invalid) + + +@pytest.mark.parametrize("cpu_count", [1, 2, 3]) +@pytest.mark.parametrize("local_world_size", [1, 2, 3]) +def test_suggested_max_num_workers_not_triggering_torch_warning(local_world_size, cpu_count, monkeypatch): + """Test that our suggestion for num workers doesn't trigger a warning in the DataLoader for too many workers.""" + monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False) + monkeypatch.delattr(torch.utils.data.dataloader.os, "sched_getaffinity", raising=False) + monkeypatch.setattr(lightning.fabric.utilities.data.os, "cpu_count", lambda: cpu_count) + monkeypatch.setattr(torch.utils.data.dataloader.os, "cpu_count", lambda: cpu_count) + with no_warning_call(): + DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size)) From 8462322c107cfc1738b607a9f88b9931b4c90289 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 21 Sep 2023 14:17:44 +0200 Subject: [PATCH 23/28] add comment --- tests/tests_fabric/utilities/test_data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_fabric/utilities/test_data.py b/tests/tests_fabric/utilities/test_data.py index 2b1865eef330e..9e549c34de470 100644 --- a/tests/tests_fabric/utilities/test_data.py +++ b/tests/tests_fabric/utilities/test_data.py @@ -635,4 +635,5 @@ def test_suggested_max_num_workers_not_triggering_torch_warning(local_world_size monkeypatch.setattr(lightning.fabric.utilities.data.os, "cpu_count", lambda: cpu_count) monkeypatch.setattr(torch.utils.data.dataloader.os, "cpu_count", lambda: cpu_count) with no_warning_call(): + # 
The dataloader runs a check in `DataLoader.check_worker_number_rationality` DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size)) From 6b7066d45a091ccc9fa08616a005d2d841c0d581 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 21 Sep 2023 14:19:46 +0200 Subject: [PATCH 24/28] extend test --- tests/tests_fabric/utilities/test_data.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/tests_fabric/utilities/test_data.py b/tests/tests_fabric/utilities/test_data.py index 9e549c34de470..55647e0eb72d2 100644 --- a/tests/tests_fabric/utilities/test_data.py +++ b/tests/tests_fabric/utilities/test_data.py @@ -634,6 +634,9 @@ def test_suggested_max_num_workers_not_triggering_torch_warning(local_world_size monkeypatch.delattr(torch.utils.data.dataloader.os, "sched_getaffinity", raising=False) monkeypatch.setattr(lightning.fabric.utilities.data.os, "cpu_count", lambda: cpu_count) monkeypatch.setattr(torch.utils.data.dataloader.os, "cpu_count", lambda: cpu_count) + + # The dataloader runs a check in `DataLoader.check_worker_number_rationality` + with pytest.warns(UserWarning, match="This DataLoader will create"): + DataLoader(range(2), num_workers=(cpu_count + 1)) with no_warning_call(): - # The dataloader runs a check in `DataLoader.check_worker_number_rationality` DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size)) From b1a451c34e19542db8beeeab41252a393b66ae41 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 21 Sep 2023 14:21:57 +0200 Subject: [PATCH 25/28] test fixes --- .../tests_pytorch/trainer/connectors/test_data_connector.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index e9a87b1e58074..d9aaff068764a 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -23,6 +23,7 @@ from torch import Tensor from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler +import lightning.fabric from lightning.fabric.utilities.distributed import DistributedSamplerWrapper from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer @@ -166,7 +167,7 @@ def test_dataloader_warnings(tmpdir, num_workers): ) @mock.patch("lightning.fabric.utilities.data.os.cpu_count") def test_worker_check(cpu_count_mock, num_devices, num_workers, cpu_count, expected_warning, monkeypatch): - monkeypatch.delattr("lightning.fabric.utilities.data.os", "sched_getaffinity", raising=False) + monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False) trainer = Mock(spec=Trainer) dataloader = Mock(spec=DataLoader) trainer.num_devices = num_devices @@ -179,7 +180,7 @@ def test_worker_check(cpu_count_mock, num_devices, num_workers, cpu_count, expec ctx = no_warning_call(UserWarning) with ctx: - _worker_check(trainer, dataloader, "train_dataloader") + _worker_check(trainer, using_spawn=False, dataloader=dataloader, name="train_dataloader") def test_update_dataloader_raises(): From 49eef42338e1f7ec8951a3d76a254545971e9b9d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Sep 2023 12:25:29 +0000 Subject: [PATCH 26/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
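
Aside on the test fixes in patches 22 and 25 (a hedged sketch, not from the diff): `monkeypatch.setattr` and `monkeypatch.delattr` accept either a single dotted-string target or an object plus an attribute name. The earlier revisions passed a dotted string together with a separate attribute name, so the call targeted the string object itself rather than the module, and either errored or silently did nothing. The two supported call shapes look roughly like this:

```python
import os


def test_monkeypatch_call_forms(monkeypatch):
    # Form 1: one dotted string naming both the module path and the attribute.
    monkeypatch.setattr("os.cpu_count", lambda: 4)
    assert os.cpu_count() == 4

    # Form 2: the object itself plus the attribute name, as the fixed tests use.
    monkeypatch.setattr(os, "cpu_count", lambda: 8)
    assert os.cpu_count() == 8
```
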
From 49eef42338e1f7ec8951a3d76a254545971e9b9d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 21 Sep 2023 12:25:29 +0000
Subject: [PATCH 26/28] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/tests_fabric/utilities/test_data.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_fabric/utilities/test_data.py b/tests/tests_fabric/utilities/test_data.py
index 55647e0eb72d2..5cfcb8e747a85 100644
--- a/tests/tests_fabric/utilities/test_data.py
+++ b/tests/tests_fabric/utilities/test_data.py
@@ -7,9 +7,9 @@
 import numpy as np
 import pytest
 import torch
+from lightning_utilities.test.warning import no_warning_call
 from torch import Tensor
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler
-from lightning_utilities.test.warning import no_warning_call
 
 import lightning.fabric
 from lightning.fabric.utilities.data import (

From acbdb01f59d8c43833f4bb4d7f74a08bcb12c87f Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 21 Sep 2023 14:25:48 +0200
Subject: [PATCH 27/28] add utility to api docs

---
 docs/source-fabric/api/utilities.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/source-fabric/api/utilities.rst b/docs/source-fabric/api/utilities.rst
index b4bd1f564131c..bf23827b6dfe8 100644
--- a/docs/source-fabric/api/utilities.rst
+++ b/docs/source-fabric/api/utilities.rst
@@ -9,3 +9,5 @@ lightning.fabric.utilities
 .. autofunction:: lightning.fabric.utilities.seed.seed_everything
 
 .. autofunction:: lightning.fabric.utilities.seed.pl_worker_init_function
+
+.. autofunction:: lightning.fabric.utilities.data.suggested_max_num_workers

From 0a9dd287f43ce910b78e58247e3e33938a50c3e3 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 21 Sep 2023 14:43:50 +0200
Subject: [PATCH 28/28] fix test

---
 tests/tests_pytorch/trainer/test_dataloaders.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py
index bbd5f39d2ecff..19515cd5399e5 100644
--- a/tests/tests_pytorch/trainer/test_dataloaders.py
+++ b/tests/tests_pytorch/trainer/test_dataloaders.py
@@ -532,7 +532,7 @@ def test_warning_on_zero_len_dataloader():
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
 @pytest.mark.parametrize("stage", ["train", "test", "val"])
-@patch("lightning.pytorch.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4)
+@patch("lightning.fabric.utilities.data._num_cpus_available", return_value=4)
 def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
     """Test that error is raised if dataloader with only a few workers is used."""
     model = BoringModel()
@@ -545,10 +545,7 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
 
-    with pytest.warns(
-        UserWarning,
-        match=f"The dataloader, {stage}_dataloader, does not have many workers",
-    ):
+    with pytest.warns(UserWarning, match=f"The '{stage}_dataloader' does not have many workers"):
         if stage == "test":
             if ckpt_path in ("specific", "best"):
                 trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
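
Taken together, the series makes `suggested_max_num_workers` public, rewires the trainer-side `_worker_check` around it, verifies the suggestion never trips PyTorch's own worker-count warning, and documents the new utility. A hedged end-to-end sketch of the usage pattern this enables; the placeholder dataset and the commented-out `model` are illustrative, not taken from the patches:

```python
from torch.utils.data import DataLoader

from lightning.fabric.utilities.data import suggested_max_num_workers
from lightning.pytorch import Trainer

devices = 2
trainer = Trainer(accelerator="cpu", devices=devices, max_epochs=1)

# Split the machine's CPU budget across the local processes instead of
# pointing every rank's DataLoader at the full os.cpu_count().
num_workers = suggested_max_num_workers(local_world_size=devices)
train_loader = DataLoader(range(64), batch_size=8, num_workers=num_workers)
# trainer.fit(model, train_loader)  # `model` would be any LightningModule
```
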