diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst
index 36dda97917e03..5ab86678ad015 100644
--- a/docs/source/common/trainer.rst
+++ b/docs/source/common/trainer.rst
@@ -184,12 +184,16 @@ Example::
 
     from pytorch_lightning import Trainer, seed_everything
 
-    seed_everything(42)
+    seed_everything(42, workers=True)
 
     # sets seeds for numpy, torch, python.random and PYTHONHASHSEED.
     model = Model()
     trainer = Trainer(deterministic=True)
 
+By setting ``workers=True`` in :func:`~pytorch_lightning.utilities.seed.seed_everything`, Lightning derives
+unique seeds across all dataloader workers and processes for :mod:`torch`, :mod:`numpy` and stdlib
+:mod:`random` number generators. When turned on, it ensures that e.g. data augmentations are not repeated across workers.
+
 -------
 
 Trainer flags
 
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 8fd39dfe94a89..a361f6e6203c2 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -16,6 +16,7 @@
 import os
 from abc import ABC
 from copy import deepcopy
+from functools import partial
 from typing import Iterable, List, Optional, Tuple, Union
 
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
@@ -31,6 +32,7 @@
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.seed import pl_worker_init_function
 
 
 class TrainerDataLoadingMixin(ABC):
@@ -101,6 +103,10 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
                 f' in the `DataLoader` init to improve performance.'
             )
 
+    def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None:
+        if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
+            dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)
+
     def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
 
         # don't do anything if it's not a dataloader
@@ -234,6 +240,9 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
         # check the workers recursively
         apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader')
 
+        # add worker_init_fn for correct seeding in worker processes
+        apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)
+
         # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
         self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)
 
@@ -332,6 +341,9 @@ def _reset_eval_dataloader(
         # add samplers
         dataloaders = [self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders if dl is not None]
 
+        # add worker_init_fn for correct seeding in worker processes
+        apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)
+
         loader_num_batches = []
 
         # determine number of batches
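The hunks above only attach Lightning's hook when ``seed_everything(..., workers=True)`` has exported
``PL_SEED_WORKERS=1`` and the dataloader does not already carry a ``worker_init_fn``. A minimal sketch of
the same decision applied by hand to a user-built loader (the toy dataset and ``my_loader`` are
illustrative, not part of the patch)::

    import os
    from functools import partial

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning import seed_everything
    from pytorch_lightning.utilities.seed import pl_worker_init_function

    seed_everything(7, workers=True)  # exports PL_SEED_WORKERS=1

    my_loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=2, num_workers=2)

    # mirror auto_add_worker_init_fn: respect a user-provided worker_init_fn, otherwise attach Lightning's
    if int(os.environ.get("PL_SEED_WORKERS", 0)) and my_loader.worker_init_fn is None:
        my_loader.worker_init_fn = partial(pl_worker_init_function, rank=0)  # rank 0 in a single process

Inside the Trainer this happens automatically for the train and eval dataloaders via
``apply_to_collection``, as the two hunks above show.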
diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py
index c3da02b7d2cdb..b7eaba72c1b02 100644
--- a/pytorch_lightning/utilities/seed.py
+++ b/pytorch_lightning/utilities/seed.py
@@ -21,22 +21,29 @@
 import numpy as np
 import torch
 
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, rank_zero_warn
+from pytorch_lightning.utilities.distributed import rank_zero_only
 
 log = logging.getLogger(__name__)
 
 
-def seed_everything(seed: Optional[int] = None) -> int:
+def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
     """
     Function that sets seed for pseudo-random number generators in:
     pytorch, numpy, python.random
-    In addition, sets the env variable `PL_GLOBAL_SEED` which will be passed to
-    spawned subprocesses (e.g. ddp_spawn backend).
+    In addition, sets the following environment variables:
+
+    - `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend).
+    - `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``.
 
     Args:
         seed: the integer value seed for global random state in Lightning.
             If `None`, will read seed from `PL_GLOBAL_SEED` env variable
             or select it randomly.
+        workers: if set to ``True``, will properly configure all dataloaders passed to the
+            Trainer with a ``worker_init_fn``. If the user already provides such a function
+            for their dataloaders, setting this argument will have no influence. See also:
+            :func:`~pytorch_lightning.utilities.seed.pl_worker_init_function`.
     """
     max_seed_value = np.iinfo(np.uint32).max
     min_seed_value = np.iinfo(np.uint32).min
@@ -61,8 +68,36 @@ def seed_everything(seed: Optional[int] = None) -> int:
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
+
+    os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"
+
     return seed
 
 
 def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
     return random.randint(min_seed_value, max_seed_value)
+
+
+def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:  # pragma: no cover
+    """
+    The worker_init_fn that Lightning automatically adds to your dataloader if you previously
+    set the seed with ``seed_everything(seed, workers=True)``.
+    See also the PyTorch documentation on
+    `randomness in DataLoaders `_.
+    """
+    # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
+    global_rank = rank if rank is not None else rank_zero_only.rank
+    process_seed = torch.initial_seed()
+    # back out the base seed so we can use all the bits
+    base_seed = process_seed - worker_id
+    ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
+    # use 128 bits (4 x 32-bit words)
+    np.random.seed(ss.generate_state(4))
+    # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
+    torch_ss, stdlib_ss = ss.spawn(2)
+    # PyTorch 1.7 and above takes a 64-bit seed
+    dtype = np.uint64 if _TORCH_GREATER_EQUAL_1_7 else np.uint32
+    torch.manual_seed(torch_ss.generate_state(1, dtype=dtype)[0])
+    # use 128 bits expressed as an integer
+    stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
+    random.seed(stdlib_seed)
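``pl_worker_init_function`` above folds the process-level seed, the worker id and the global rank into a
NumPy ``SeedSequence`` and then spawns separate sequences for the NumPy, PyTorch and stdlib generators.
A standalone sketch (not part of the patch) of why that derivation gives every worker on every rank a
distinct 128-bit state::

    import numpy as np

    def derived_state(base_seed: int, worker_id: int, global_rank: int) -> np.ndarray:
        # same entropy inputs as pl_worker_init_function
        ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
        return ss.generate_state(4)  # 128 bits as 4 x 32-bit words

    # a shared base seed still yields different states per worker and per rank
    assert not np.array_equal(derived_state(42, 0, 0), derived_state(42, 1, 0))
    assert not np.array_equal(derived_state(42, 0, 0), derived_state(42, 0, 1))

``SeedSequence`` was added in NumPy 1.17, which is presumably the reason for the version bump in the
requirements below.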
diff --git a/requirements.txt b/requirements.txt
index 3faed306a488a..3438b1ea2189b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 # the default package dependencies
 
-numpy>=1.16.6
+numpy>=1.17.2
 torch>=1.4
 future>=0.17.1  # required for builtins in setup.py
 # pyyaml>=3.13
diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index 382311c107958..831fc474336b6 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import pytest
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import BatchSampler, SequentialSampler
@@ -72,7 +71,7 @@ def test_dataloader(self):
         return [self.create_dataset()] * self._numbers_test_dataloaders
 
 
-def check_replace_distrubuted_sampler(tmpdir, save_preds_on_dl_idx, accelerator, gpus, num_dl_idx, mode):
+def check_replace_distributed_sampler(tmpdir, save_preds_on_dl_idx, accelerator, gpus, num_dl_idx, mode):
     num_processes = 2
     limit_test_batches = 2
     trainer_args = {
@@ -100,8 +99,8 @@ def check_replace_distrubuted_sampler(tmpdir, save_preds_on_dl_idx, accelerator,
 
 @RunIf(min_gpus=2, special=True)
 @pytest.mark.parametrize("mode", [1, 2])
-def test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode):
-    check_replace_distrubuted_sampler(tmpdir, True, "ddp", 2, 2, mode)
+def test_replace_distributed_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode):
+    check_replace_distributed_sampler(tmpdir, True, "ddp", 2, 2, mode)
 
 
 @pytest.mark.parametrize("num_workers", [0, 1])
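As the docstring of ``seed_everything`` states, a ``worker_init_fn`` supplied by the user always takes
precedence; Lightning only fills the slot when it is empty. A sketch of what such a user-defined hook
could look like (``my_worker_init_fn`` is a hypothetical example built on the standard
``torch.utils.data.get_worker_info()`` API, not something the patch adds)::

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def my_worker_init_fn(worker_id: int) -> None:
        # seed third-party generators from the per-worker torch seed
        worker_seed = torch.utils.data.get_worker_info().seed % 2 ** 32
        np.random.seed(worker_seed)

    dataset = TensorDataset(torch.arange(16))
    loader = DataLoader(dataset, batch_size=2, num_workers=2, worker_init_fn=my_worker_init_fn)
    # seed_everything(0, workers=True) will leave my_worker_init_fn untouched

The ``# when user already has a worker_init_fn`` branch of the test added below asserts exactly this
behaviour.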
""" + dataset = Mock() + dataloader = DataLoader(dataset) + trainer = Trainer() + + # without pl.seed_everything() + trainer.auto_add_worker_init_fn(dataloader) + assert dataloader.worker_init_fn is None + + # with forcefully avoiding it + seed_everything(0, workers=False) + trainer.auto_add_worker_init_fn(dataloader) + assert dataloader.worker_init_fn is None + + # when user already has a worker_init_fn + user_function = _user_worker_init_fn + dataloader.worker_init_fn = user_function + trainer.auto_add_worker_init_fn(dataloader) + assert dataloader.worker_init_fn is user_function + dataloader.worker_init_fn = None + + # main use case + seed_everything(0, workers=True) + trainer.auto_add_worker_init_fn(dataloader) + assert dataloader.worker_init_fn is not None + + +class MultiProcessModel(BoringModel): + + def __init__(self): + super().__init__() + self.batches_seen = [] + + def training_step(self, batch, batch_idx): + self.batches_seen.append(batch) + + def training_epoch_end(self, outputs): + world_size = 2 + num_samples = NumpyRandomDataset.size + all_batches = torch.cat(self.batches_seen) + all_batches = self.all_gather(all_batches) + assert all_batches.shape[0] == world_size + all_batches = all_batches.view(-1, 3) + assert len(torch.unique(all_batches, dim=0)) == num_samples + + +@RunIf(min_gpus=2) +def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch): + """ Test that the lightning worker_init_fn takes care of dataloaders in multi-gpu/multi-node training. """ + dataset = NumpyRandomDataset() + num_workers = 2 + batch_size = 2 + + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) + seed_everything(0, workers=True) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + gpus=2, + accelerator="ddp_spawn", + ) + model = MultiProcessModel() + model.val_dataloader = None + trainer.fit(model, train_dataloader=dataloader) + + def test_warning_with_iterable_dataset_and_len(tmpdir): """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """ model = BoringModel()