Improve the suggested num_workers warning #18591

Merged
merged 32 commits on Sep 21, 2023
32 commits
2ca329e
add suggested num workers
awaelchli Sep 19, 2023
2db3680
wip
awaelchli Sep 19, 2023
c5c5de2
update
awaelchli Sep 19, 2023
ada09ed
test
awaelchli Sep 19, 2023
b6dfaea
update
awaelchli Sep 19, 2023
4dff113
test
awaelchli Sep 19, 2023
63628f4
format
awaelchli Sep 19, 2023
84f6dea
docs
awaelchli Sep 19, 2023
6daa8f3
x
awaelchli Sep 19, 2023
d745a76
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2023
120c353
changelog
awaelchli Sep 19, 2023
0467092
Merge remote-tracking branch 'origin/feature/num-workers' into featur…
awaelchli Sep 19, 2023
c1dccae
circular import
awaelchli Sep 19, 2023
dca5162
Revert "circular import"
awaelchli Sep 19, 2023
79267e4
refactor
awaelchli Sep 19, 2023
9228d64
update
awaelchli Sep 19, 2023
8055027
chlog
awaelchli Sep 19, 2023
22c1d25
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2023
388e25a
fix test
awaelchli Sep 19, 2023
d35cb67
Update src/lightning/fabric/CHANGELOG.md
awaelchli Sep 19, 2023
311348b
Merge remote-tracking branch 'origin/feature/num-workers' into featur…
awaelchli Sep 21, 2023
47258a8
keep the spawn warnings for now, will remove in separate PR
awaelchli Sep 21, 2023
747b62b
typo
awaelchli Sep 21, 2023
bbf7782
add requested test
awaelchli Sep 21, 2023
8462322
add comment
awaelchli Sep 21, 2023
6b7066d
extend test
awaelchli Sep 21, 2023
b1a451c
test fixes
awaelchli Sep 21, 2023
32797cd
Merge branch 'master' into feature/num-workers
awaelchli Sep 21, 2023
49eef42
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2023
acbdb01
add utility to api docs
awaelchli Sep 21, 2023
be9b46c
Merge remote-tracking branch 'origin/feature/num-workers' into featur…
awaelchli Sep 21, 2023
0a9dd28
fix test
awaelchli Sep 21, 2023
2 changes: 2 additions & 0 deletions docs/source-fabric/api/utilities.rst
@@ -9,3 +9,5 @@ lightning.fabric.utilities
.. autofunction:: lightning.fabric.utilities.seed.seed_everything

.. autofunction:: lightning.fabric.utilities.seed.pl_worker_init_function

.. autofunction:: lightning.fabric.utilities.data.suggested_max_num_workers
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -127,6 +127,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled the default process group configuration for FSDP's hybrid sharding ([#18583](https://github.com/Lightning-AI/lightning/pull/18583))


- Added `lightning.fabric.utilities.suggested_max_num_workers` to assist with setting a good value in distributed settings ([#18591](https://github.com/Lightning-AI/lightning/pull/18591))


### Changed

- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
1 change: 1 addition & 0 deletions src/lightning/fabric/utilities/__init__.py
@@ -14,6 +14,7 @@
"""General utilities."""

from lightning.fabric.utilities.apply_func import move_data_to_device # noqa: F401
from lightning.fabric.utilities.data import suggested_max_num_workers # noqa: F401
from lightning.fabric.utilities.enums import LightningEnum # noqa: F401
from lightning.fabric.utilities.rank_zero import ( # noqa: F401
rank_zero_deprecation,
22 changes: 22 additions & 0 deletions src/lightning/fabric/utilities/data.py
@@ -433,3 +433,25 @@ def _set_sampler_epoch(dataloader: object, epoch: int) -> None:
set_epoch = getattr(obj, "set_epoch", None)
if callable(set_epoch):
set_epoch(epoch)


def suggested_max_num_workers(local_world_size: int) -> int:
"""Suggests an upper bound of ``num_workers`` to use in a PyTorch :class:`~torch.utils.data.DataLoader` based on
the number of CPU cores available on the system and the number of distributed processes in the current machine.

Args:
local_world_size: The number of distributed processes running on the current machine. Set this to the number
of devices configured in Fabric/Trainer.
"""
if local_world_size < 1:
raise ValueError(f"`local_world_size` should be >= 1, got {local_world_size}.")
cpu_count = _num_cpus_available()
return max(1, cpu_count // local_world_size)
Contributor @stas00 commented (Oct 5, 2023):
You need at least one CPU core per GPU-bound process, so I think this would make a more sensible recommendation:

return max(1, (cpu_count // local_world_size) - 1)

i.e. if you have 48 CPU cores and 8 GPUs, you need at least 8 cores for the main processes, so only 40 remain - so 40/8 = 5.

And then there is the OS as well, which needs at least a core or two for its own functioning.



def _num_cpus_available() -> int:
if hasattr(os, "sched_getaffinity"):
return len(os.sched_getaffinity(0))

cpu_count = os.cpu_count()
return 1 if cpu_count is None else cpu_count
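
For illustration, a minimal usage sketch of the new utility. The device count, batch size, and toy dataset below are placeholder assumptions, not part of this PR:

from torch.utils.data import DataLoader

from lightning.fabric.utilities.data import suggested_max_num_workers

# Assume 4 training processes run on this machine, e.g. Trainer(devices=4) or Fabric(devices=4).
local_world_size = 4

# Upper bound derived from the CPU cores visible to this process.
max_workers = suggested_max_num_workers(local_world_size)

# A toy dataset stands in for a real one; the suggestion is only an upper bound,
# and fewer workers can be enough for lightweight or pre-tokenized datasets.
loader = DataLoader(range(100), batch_size=8, num_workers=max_workers)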
7 changes: 7 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -131,6 +131,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled the default process group configuration for FSDP's hybrid sharding ([#18583](https://github.com/Lightning-AI/lightning/pull/18583))



- Added `lightning.pytorch.utilities.suggested_max_num_workers` to assist with setting a good value in distributed settings ([#18591](https://github.com/Lightning-AI/lightning/pull/18591))


- Improved the `num_workers` warning to give a more accurate upper limit on the `num_workers` suggestion ([#18591](https://github.com/Lightning-AI/lightning/pull/18591))


### Changed

- Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
22 changes: 10 additions & 12 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import os
from dataclasses import dataclass, field
from typing import Any, Iterable, Optional, Tuple, Union
@@ -25,6 +24,7 @@
_replace_dunder_methods,
_set_sampler_epoch,
has_iterable_dataset,
suggested_max_num_workers,
)
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSamplerWrapper
@@ -420,11 +420,11 @@ def _check_dataloader_iterable(
)


def _worker_check(dataloader: object, using_spawn: bool, name: str) -> None:
def _worker_check(trainer: "pl.Trainer", using_spawn: bool, dataloader: object, name: str) -> None:
if not isinstance(dataloader, DataLoader):
return

num_cpus = multiprocessing.cpu_count()
upper_bound = suggested_max_num_workers(trainer.num_devices)

# ddp_spawn + num_workers > 0 don't mix! tell the user
if dataloader.num_workers > 0 and using_spawn:
@@ -442,14 +442,11 @@ def _worker_check(dataloader: object, using_spawn: bool, name: str) -> None:
"strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks."
" Consider setting num_workers>0 and persistent_workers=True"
)

elif dataloader.num_workers <= 2 < num_cpus and not using_spawn:
elif dataloader.num_workers <= 2 < upper_bound or dataloader.num_workers < 2 <= upper_bound:
Contributor:
The first part is pretty much certain to be True on the majority of setups, I think, since the default in many frameworks is usually 2 workers, most modern desktop machines have lots of CPU cores, and any serious GPU node will have lots of CPU cores.

Would you consider that num_workers == 2 is actually a reasonable setting?

Contributor Author @awaelchli commented (Oct 6, 2023):
If I read you correctly, you suggest simplifying the condition by dropping `2 < upper_bound` because it's just true on most DL systems. I think I agree. So we would be left with only this check:

dataloader.num_workers <= 2

Now the question is whether this should become a different threshold. The history here is that when this check was added, back when PL was first developed, the majority of users were doing computer vision, and num_workers > 2 was almost always mandatory for applying augmentations.

Today, if we do small- to medium-size LLM training on small pre-tokenized datasets that fit on a machine, we're not going to require that many workers. So num_workers=2 is not that unreasonable. We could consider dropping the check to

dataloader.num_workers < 2

to emit the warning.

Contributor @stas00 commented (Oct 6, 2023):
We are in agreement, Adrian.

  • Language models, especially where data is preprocessed ahead of time, à la Megatron-LM/NeMo: 2 workers is the norm in my experience.
  • Vision models are a different story: if the amount of transforms is huge, data loading can be very slow and, yes, warrant more workers.

Ideally, the general solution would be to provide users with a breakdown of the [DL][fwd+bwd][logging] time spans, so that they can see whether their DL is a bottleneck. I include logging as well, since it can also be an issue if one does blocking logging to slow or remote IO.

That way you switch to measuring the real impact of num_workers on compute efficiency, as opposed to the guesswork that is implemented now.

So, for example, you could warn the user if you see that dl_time > 10% of compute time, or something like that. I don't have a good general recommendation for the threshold percentage, since an LLM/VLM workload can be very different from a single-GPU workload. But the principle is the same: compute is the most expensive resource, especially if it's rented per hour, so the user wants to detect and remove bottlenecks to pay less and, of course, reach the finish line earlier.


Now, with the rise of cloud storage, there is a new problem emerging: the data may not even be present on the compute node at the time of compute. So not only might the DL need to do transforms, it also needs to fetch remote data from the cloud. Here 2 workers are almost never enough. During IDEFICS-80B training we had this issue, but we couldn't raise the number of workers - not because we didn't have the spare cores, but because WebDataset, which we were using for just-in-time prefetch, led to huge processes, so the 2x8 workers were consuming the lion's share of 1+TB of CPU memory, and with even 3 workers we would get cgroups killing the training.

I hope that WebDataset and other streaming DL solutions will get better over time, but this was a very painful experience 6 months ago, since we did want more workers but couldn't have them.
I think H100 nodes are coming with at least 2 TB of CPU memory in some clouds, so this should help. But then H100s run 2-6x faster than A100s, so one needs to feed the fire even faster, which means more workers will be needed - and so the rule of at least 2 might no longer apply again.

In other words, load-based heuristics (like the one I proposed above) are needed to really help users optimize their setup, and one-size-fits-all guessing heuristics will break even more often.

# if changed, update the `filterwarnings` snippet in 'speed.html#num-workers'
rank_zero_warn(
Contributor:
Since our bounds now match PyTorch's, should we filter https://github.com/pytorch/pytorch/blob/v2.0.1/torch/utils/data/dataloader.py#L488-L563 so that users don't get the warning twice? Alternatively, we could remove ours

Contributor Author @awaelchli commented (Sep 19, 2023):
No, the warning in the PyTorch dataloader is the opposite: if the user selects an excessive number of workers, then the PyTorch warning triggers. Our warning is only triggered if there are fewer than our minimum suggested workers (and we suggest no more than the upper limit defined by PyTorch). These warnings no longer overlap with each other (let me know if you find an edge case) as reported in #15572, so this is fixed now.

Contributor:
Do you think you could assert that the dataloader warning is not raised?

Contributor Author (@awaelchli):
Yes, good idea. I added a test: `test_suggested_max_num_workers_not_triggering_torch_warning`.

f"The dataloader, {name}, does not have many workers which may be a bottleneck."
" Consider increasing the value of the `num_workers` argument`"
f" (try {num_cpus} which is the number of cpus on this machine)"
" in the `DataLoader` init to improve performance.",
f"The '{name}' does not have many workers which may be a bottleneck. Consider increasing the value of the"
f" `num_workers` argument` to `num_workers={upper_bound}` in the `DataLoader` to improve performance.",
category=PossibleUserWarning,
)

@@ -507,9 +504,10 @@ def _process_dataloader(

# check the workers
_worker_check(
dataloader,
isinstance(strategy, DDPStrategy) and strategy._start_method == "spawn",
f"{stage.dataloader_prefix}_dataloader",
trainer=trainer,
using_spawn=isinstance(strategy, DDPStrategy) and strategy._start_method == "spawn",
dataloader=dataloader,
name=f"{stage.dataloader_prefix}_dataloader",
)

# add worker_init_fn for correct seeding in worker processes
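
As a rough illustration of the load-based heuristic proposed in the review discussion above - timing the dataloader against the compute step and warning when data loading dominates - here is a sketch. The function name, batch count, and 10% threshold are assumptions for this sketch, not Lightning API:

import time

def warn_if_dataloader_bound(dataloader, train_step, num_batches=50, threshold=0.10):
    # Measure how much of each step is spent waiting for the next batch
    # versus running the user-provided `train_step` callable (forward + backward).
    data_time = compute_time = 0.0
    iterator = iter(dataloader)
    for _ in range(num_batches):
        start = time.perf_counter()
        try:
            batch = next(iterator)
        except StopIteration:
            break
        fetched = time.perf_counter()
        train_step(batch)
        done = time.perf_counter()
        data_time += fetched - start
        compute_time += done - fetched
    total = data_time + compute_time
    share = data_time / total if total else 0.0
    if share > threshold:
        print(f"Data loading takes {share:.0%} of step time; consider increasing `num_workers`.")
    return share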
1 change: 1 addition & 0 deletions src/lightning/pytorch/utilities/__init__.py
@@ -17,6 +17,7 @@

from lightning.fabric.utilities import LightningEnum # noqa: F401
from lightning.fabric.utilities import move_data_to_device # noqa: F401
from lightning.fabric.utilities import suggested_max_num_workers # noqa: F401
from lightning.pytorch.utilities.combined_loader import CombinedLoader # noqa: F401
from lightning.pytorch.utilities.enums import GradClipAlgorithmType # noqa: F401
from lightning.pytorch.utilities.grads import grad_norm # noqa: F401
65 changes: 65 additions & 0 deletions tests/tests_fabric/utilities/test_data.py
@@ -1,13 +1,17 @@
import contextlib
import os
import random
from unittest import mock
from unittest.mock import Mock

import numpy as np
import pytest
import torch
from lightning_utilities.test.warning import no_warning_call
from torch import Tensor
from torch.utils.data import BatchSampler, DataLoader, RandomSampler

import lightning.fabric
from lightning.fabric.utilities.data import (
_get_dataloader_init_args_and_kwargs,
_replace_dunder_methods,
@@ -17,6 +21,7 @@
_WrapAttrTag,
has_iterable_dataset,
has_len,
suggested_max_num_workers,
)
from lightning.fabric.utilities.exceptions import MisconfigurationException
from tests_fabric.helpers.models import RandomDataset, RandomIterableDataset
@@ -575,3 +580,63 @@ def test_set_sampler_epoch():
_set_sampler_epoch(dataloader, 55)
dataloader.sampler.set_epoch.assert_called_once_with(55)
dataloader.batch_sampler.sampler.set_epoch.assert_called_once_with(55)


@pytest.mark.parametrize(
("cpu_count", "local_world_size", "expected"),
[
(0, 1, 1),
(1, 1, 1),
(2, 1, 2),
(1, 2, 1),
(1, 2, 1),
(2, 2, 1),
(3, 2, 1),
(4, 2, 2),
(4, 3, 1),
(4, 1, 4),
],
)
@pytest.mark.parametrize(
"affinity",
[
False,
pytest.param(
True,
marks=pytest.mark.skipif(
not hasattr(os, "sched_getaffinity"), reason="OS does not support restricting CPU cores"
),
),
],
)
@mock.patch("lightning.fabric.utilities.data.os.cpu_count")
def test_suggested_max_num_workers(cpu_count_mock, affinity, cpu_count, local_world_size, expected, monkeypatch):
if affinity:
monkeypatch.setattr(lightning.fabric.utilities.data.os, "sched_getaffinity", lambda _: list(range(cpu_count)))
else:
monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
cpu_count_mock.return_value = cpu_count

assert suggested_max_num_workers(local_world_size) == expected


@pytest.mark.parametrize("invalid", [-1, 0])
def test_suggested_max_num_workers_input_validation(invalid):
with pytest.raises(ValueError, match="should be >= 1"):
suggested_max_num_workers(invalid)


@pytest.mark.parametrize("cpu_count", [1, 2, 3])
@pytest.mark.parametrize("local_world_size", [1, 2, 3])
def test_suggested_max_num_workers_not_triggering_torch_warning(local_world_size, cpu_count, monkeypatch):
"""Test that our suggestion for num workers doesn't trigger a warning in the DataLoader for too many workers."""
monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
monkeypatch.delattr(torch.utils.data.dataloader.os, "sched_getaffinity", raising=False)
monkeypatch.setattr(lightning.fabric.utilities.data.os, "cpu_count", lambda: cpu_count)
monkeypatch.setattr(torch.utils.data.dataloader.os, "cpu_count", lambda: cpu_count)

# The dataloader runs a check in `DataLoader.check_worker_number_rationality`
with pytest.warns(UserWarning, match="This DataLoader will create"):
DataLoader(range(2), num_workers=(cpu_count + 1))
with no_warning_call():
DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size))
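
Relatedly, the `# if changed, update the filterwarnings snippet` comment in the data_connector.py diff above points at a docs example for silencing this warning. The actual docs snippet is not part of this diff, but based on the message wording and the `PossibleUserWarning` category shown there, a user-side filter for someone who intentionally runs with few workers would look roughly like this:

import warnings

from lightning.fabric.utilities.warnings import PossibleUserWarning

# Silence only the num_workers suggestion; other PossibleUserWarnings stay visible.
warnings.filterwarnings(
    "ignore",
    message=".*does not have many workers.*",
    category=PossibleUserWarning,
)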
37 changes: 37 additions & 0 deletions tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -15,13 +15,15 @@
from io import StringIO
from re import escape
from typing import Sized
from unittest import mock
from unittest.mock import Mock

import pytest
from lightning_utilities.test.warning import no_warning_call
from torch import Tensor
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler

import lightning.fabric
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
@@ -30,6 +32,7 @@
_check_dataloader_iterable,
_DataHookSelector,
_DataLoaderSource,
_worker_check,
warning_cache,
)
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
@@ -146,6 +149,40 @@ def test_dataloader_warnings(tmpdir, num_workers):
trainer.fit(TestSpawnBoringModel(num_workers))


@pytest.mark.parametrize(
("num_devices", "num_workers", "cpu_count", "expected_warning"),
[
(1, 0, 1, False),
(8, 0, 1, False),
(8, 0, None, False),
(1, 1, None, False),
(1, 2, 2, False),
(1, 1, 8, True),
(1, 2, 8, True),
(1, 3, 8, False),
(4, 1, 8, True),
(4, 2, 8, False),
(8, 2, 8, False),
],
)
@mock.patch("lightning.fabric.utilities.data.os.cpu_count")
def test_worker_check(cpu_count_mock, num_devices, num_workers, cpu_count, expected_warning, monkeypatch):
monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
trainer = Mock(spec=Trainer)
dataloader = Mock(spec=DataLoader)
trainer.num_devices = num_devices
dataloader.num_workers = num_workers
cpu_count_mock.return_value = cpu_count

if expected_warning:
ctx = pytest.warns(UserWarning, match="Consider increasing the value of the `num_workers` argument`")
else:
ctx = no_warning_call(UserWarning)

with ctx:
_worker_check(trainer, using_spawn=False, dataloader=dataloader, name="train_dataloader")


def test_update_dataloader_raises():
with pytest.raises(ValueError, match="needs to subclass `torch.utils.data.DataLoader"):
_update_dataloader(object(), object(), mode="fit")
13 changes: 5 additions & 8 deletions tests/tests_pytorch/trainer/test_dataloaders.py
@@ -532,7 +532,7 @@ def test_warning_on_zero_len_dataloader():
@RunIf(skip_windows=True)
@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
@pytest.mark.parametrize("stage", ["train", "test", "val"])
@patch("lightning.pytorch.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4)
@patch("lightning.fabric.utilities.data._num_cpus_available", return_value=4)
def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
"""Test that error is raised if dataloader with only a few workers is used."""
model = BoringModel()
@@ -545,10 +545,7 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)

with pytest.warns(
UserWarning,
match=f"The dataloader, {stage}_dataloader, does not have many workers",
):
with pytest.warns(UserWarning, match=f"The '{stage}_dataloader' does not have many workers"):
if stage == "test":
if ckpt_path in ("specific", "best"):
trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
@@ -561,9 +558,9 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
@RunIf(skip_windows=True)
@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
@pytest.mark.parametrize("stage", ["train", "test", "val"])
@patch("lightning.pytorch.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4)
@patch("lightning.fabric.utilities.data._num_cpus_available", return_value=4)
def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
"""Test that error is raised if dataloader with only a few workers is used."""
"""Test that a warning is emitted if the dataloader only has a few workers."""

class CustomModel(MultiEvalDataLoaderModel):
def training_step(self, batch, batch_idx):
@@ -584,7 +581,7 @@ def training_step(self, batch, batch_idx):

with pytest.warns(
UserWarning,
match=f"The dataloader, {stage}_dataloader, does not have many workers",
match=f"The '{stage}_dataloader' does not have many workers",
):
if stage == "test":
if ckpt_path in ("specific", "best"):