Auto-set DataLoader.worker_init_fn with seed_everything (#6960)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
3 people authored Apr 19, 2021
1 parent d1529c2 commit 60c1c8f
Showing 6 changed files with 166 additions and 12 deletions.
6 changes: 5 additions & 1 deletion docs/source/common/trainer.rst
@@ -184,12 +184,16 @@ Example::

from pytorch_lightning import Trainer, seed_everything

seed_everything(42)
seed_everything(42, workers=True)
# sets seeds for numpy, torch, python.random and PYTHONHASHSEED.
model = Model()
trainer = Trainer(deterministic=True)


By setting ``workers=True`` in :func:`~pytorch_lightning.utilities.seed.seed_everything`, Lightning derives
unique seeds across all dataloader workers and processes for :mod:`torch`, :mod:`numpy` and Python's standard
:mod:`random` module. When turned on, this ensures that, for example, data augmentations are not repeated across workers.
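
For illustration, a minimal end-to-end sketch of this behaviour; ``MyNumpyDataset`` and ``model`` are hypothetical placeholders, not part of Lightning::

import numpy as np
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import Trainer, seed_everything

class MyNumpyDataset(Dataset):
    # hypothetical dataset whose samples depend on numpy's global RNG
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return np.random.randint(0, 100, 3)

seed_everything(42, workers=True)  # also exports PL_SEED_WORKERS=1

# each loader gets pl_worker_init_function attached automatically during fit,
# so the two workers draw from distinct, reproducible random streams
train_loader = DataLoader(MyNumpyDataset(), batch_size=2, num_workers=2)
trainer = Trainer(deterministic=True, max_epochs=1)
trainer.fit(model, train_loader)  # ``model`` is a placeholder LightningModule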

-------

Trainer flags
12 changes: 12 additions & 0 deletions pytorch_lightning/trainer/data_loading.py
@@ -16,6 +16,7 @@
import os
from abc import ABC
from copy import deepcopy
from functools import partial
from typing import Iterable, List, Optional, Tuple, Union

from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
@@ -31,6 +32,7 @@
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import pl_worker_init_function


class TrainerDataLoadingMixin(ABC):
@@ -101,6 +103,10 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
f' in the `DataLoader` init to improve performance.'
)

def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None:
if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)
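
Outside the Trainer, the same wiring can be reproduced by hand — a sketch for a single-process run (rank 0); ``my_dataset`` is a placeholder, not something defined in this module:

from functools import partial

from torch.utils.data import DataLoader

from pytorch_lightning import seed_everything
from pytorch_lightning.utilities.seed import pl_worker_init_function

seed_everything(7, workers=True)                # exports PL_SEED_WORKERS=1
loader = DataLoader(my_dataset, num_workers=2)  # ``my_dataset`` is a placeholder Dataset
if loader.worker_init_fn is None:
    # in a distributed run the rank would come from trainer.global_rank
    loader.worker_init_fn = partial(pl_worker_init_function, rank=0)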

def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:

# don't do anything if it's not a dataloader
@@ -234,6 +240,9 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
# check the workers recursively
apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader')

# add worker_init_fn for correct seeding in worker processes
apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)

@@ -332,6 +341,9 @@ def _reset_eval_dataloader(
# add samplers
dataloaders = [self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders if dl is not None]

# add worker_init_fn for correct seeding in worker processes
apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)

loader_num_batches = []

# determine number of batches
43 changes: 39 additions & 4 deletions pytorch_lightning/utilities/seed.py
@@ -21,22 +21,29 @@
import numpy as np
import torch

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only

log = logging.getLogger(__name__)


def seed_everything(seed: Optional[int] = None) -> int:
def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
"""
Function that sets seed for pseudo-random number generators in:
pytorch, numpy, python.random
In addition, sets the env variable `PL_GLOBAL_SEED` which will be passed to
spawned subprocesses (e.g. ddp_spawn backend).
In addition, sets the following environment variables:
- `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend).
- `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``.
Args:
seed: the integer value seed for global random state in Lightning.
If `None`, will read seed from `PL_GLOBAL_SEED` env variable
or select it randomly.
workers: if set to ``True``, will properly configure all dataloaders passed to the
Trainer with a ``worker_init_fn``. If the user already provides such a function
for their dataloaders, setting this argument will have no influence. See also:
:func:`~pytorch_lightning.utilities.seed.pl_worker_init_function`.
"""
max_seed_value = np.iinfo(np.uint32).max
min_seed_value = np.iinfo(np.uint32).min
@@ -61,8 +68,36 @@ def seed_everything(seed: Optional[int] = None) -> int:
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"

return seed
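
A quick way to see the effect of the new flag — a sketch assuming only that ``PL_GLOBAL_SEED`` is set earlier in this function (outside the hunk shown here):

import os

from pytorch_lightning import seed_everything

seed_everything(3, workers=True)
assert os.environ["PL_GLOBAL_SEED"] == "3"   # set earlier in seed_everything
assert os.environ["PL_SEED_WORKERS"] == "1"  # set by the new line above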


def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
return random.randint(min_seed_value, max_seed_value)


def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:  # pragma: no cover
"""
The worker_init_fn that Lightning automatically adds to your dataloader if you previously
set the seed with ``seed_everything(seed, workers=True)``.
See also the PyTorch documentation on
`randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_.
"""
# implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
global_rank = rank if rank is not None else rank_zero_only.rank
process_seed = torch.initial_seed()
# back out the base seed so we can use all the bits
base_seed = process_seed - worker_id
ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
# use 128 bits (4 x 32-bit words)
np.random.seed(ss.generate_state(4))
# Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
torch_ss, stdlib_ss = ss.spawn(2)
# PyTorch 1.7 and above takes a 64-bit seed
dtype = np.uint64 if _TORCH_GREATER_EQUAL_1_7 else np.uint32
torch.manual_seed(torch_ss.generate_state(1, dtype=dtype)[0])
# use 128 bits expressed as an integer
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
random.seed(stdlib_seed)
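
The essential property of the derivation above: every (base seed, worker id, rank) triple seeds an independent ``np.random.SeedSequence``, so the derived NumPy, PyTorch and stdlib states cannot collide across workers or ranks. A standalone sketch of that property, outside any DataLoader:

import numpy as np

base_seed, rank = 12345, 0
states = [
    np.random.SeedSequence([base_seed, worker_id, rank]).generate_state(4)
    for worker_id in range(4)
]
# every simulated worker receives a distinct 128-bit state (4 x 32-bit words)
assert len({tuple(state) for state in states}) == 4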
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,6 +1,6 @@
# the default package dependencies

numpy>=1.16.6
numpy>=1.17.2
torch>=1.4
future>=0.17.1 # required for builtins in setup.py
# pyyaml>=3.13
7 changes: 3 additions & 4 deletions tests/trainer/test_data_loading.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler, SequentialSampler
@@ -72,7 +71,7 @@ def test_dataloader(self):
return [self.create_dataset()] * self._numbers_test_dataloaders


def check_replace_distrubuted_sampler(tmpdir, save_preds_on_dl_idx, accelerator, gpus, num_dl_idx, mode):
def check_replace_distributed_sampler(tmpdir, save_preds_on_dl_idx, accelerator, gpus, num_dl_idx, mode):
num_processes = 2
limit_test_batches = 2
trainer_args = {
@@ -100,8 +99,8 @@ def check_replace_distrubuted_sampler(tmpdir, save_preds_on_dl_idx, accelerator,

@RunIf(min_gpus=2, special=True)
@pytest.mark.parametrize("mode", [1, 2])
def test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode):
check_replace_distrubuted_sampler(tmpdir, True, "ddp", 2, 2, mode)
def test_replace_distributed_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode):
check_replace_distributed_sampler(tmpdir, True, "ddp", 2, 2, mode)


@pytest.mark.parametrize("num_workers", [0, 1])
108 changes: 106 additions & 2 deletions tests/trainer/test_dataloaders.py
@@ -13,12 +13,13 @@
# limitations under the License.
import os
from unittest import mock
from unittest.mock import patch
from unittest.mock import Mock, patch

import numpy
import pytest
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import IterableDataset, Subset
from torch.utils.data.dataset import Dataset, IterableDataset, Subset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler

@@ -635,6 +636,109 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl)


class NumpyRandomDataset(Dataset):
# this dataset uses numpy instead of torch to produce random numbers
size = 16

def __getitem__(self, index):
return numpy.random.randint(0, 100, 3)

def __len__(self):
return self.size


def _user_worker_init_fn(_):
pass


def test_missing_worker_init_fn():
""" Test that naive worker seed initialization leads to undesired random state in subprocesses. """
dataset = NumpyRandomDataset()

seed_everything(0)
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, shuffle=False)
batches0 = torch.cat([batch for batch in dataloader])

seed_everything(0)
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, shuffle=False)
batches1 = torch.cat([batch for batch in dataloader])

is_duplicated = len(torch.unique(batches1, dim=0)) < len(dataset)
is_deterministic = torch.eq(batches0, batches1).all()

# depending on the OS, we either have
# 1) the same seed in all worker processes, producing duplicate samples / augmentations, or
# 2) different seeds in each worker process, but they are not derived from the seed of the main process
assert not is_deterministic or is_duplicated
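
For reference, the manual workaround this commit makes unnecessary is the per-worker reseeding pattern from the PyTorch notes on DataLoader randomness — a sketch, not part of this test suite:

import random

import numpy as np
import torch

def manual_worker_init_fn(worker_id: int) -> None:
    # reseed numpy and the stdlib from the seed PyTorch assigned to this worker
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)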


def test_auto_add_worker_init_fn():
""" Test Trainer adds a default worker_init_fn to the dataloader when seed_everything() is used. """
dataset = Mock()
dataloader = DataLoader(dataset)
trainer = Trainer()

# without pl.seed_everything()
trainer.auto_add_worker_init_fn(dataloader)
assert dataloader.worker_init_fn is None

# with forcefully avoiding it
seed_everything(0, workers=False)
trainer.auto_add_worker_init_fn(dataloader)
assert dataloader.worker_init_fn is None

# when user already has a worker_init_fn
user_function = _user_worker_init_fn
dataloader.worker_init_fn = user_function
trainer.auto_add_worker_init_fn(dataloader)
assert dataloader.worker_init_fn is user_function
dataloader.worker_init_fn = None

# main use case
seed_everything(0, workers=True)
trainer.auto_add_worker_init_fn(dataloader)
assert dataloader.worker_init_fn is not None


class MultiProcessModel(BoringModel):

def __init__(self):
super().__init__()
self.batches_seen = []

def training_step(self, batch, batch_idx):
self.batches_seen.append(batch)

def training_epoch_end(self, outputs):
world_size = 2
num_samples = NumpyRandomDataset.size
all_batches = torch.cat(self.batches_seen)
all_batches = self.all_gather(all_batches)
assert all_batches.shape[0] == world_size
all_batches = all_batches.view(-1, 3)
assert len(torch.unique(all_batches, dim=0)) == num_samples


@RunIf(min_gpus=2)
def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch):
""" Test that the lightning worker_init_fn takes care of dataloaders in multi-gpu/multi-node training. """
dataset = NumpyRandomDataset()
num_workers = 2
batch_size = 2

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
seed_everything(0, workers=True)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
gpus=2,
accelerator="ddp_spawn",
)
model = MultiProcessModel()
model.val_dataloader = None
trainer.fit(model, train_dataloader=dataloader)


def test_warning_with_iterable_dataset_and_len(tmpdir):
""" Tests that a warning message is shown when an IterableDataset defines `__len__`. """
model = BoringModel()
