diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index aa8c23a2d2132..8ffc4721a9f9f 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -11,7 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Made saving non-distributed checkpoints fully atomic ([#20011](https://github.com/Lightning-AI/pytorch-lightning/pull/20011)) -- +- Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108)) + ### Changed diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index b9032fe7a9d93..0ff5b04b30b0a 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -909,7 +909,7 @@ def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None) -> No logger.log_metrics(metrics=metrics, step=step) @staticmethod - def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) -> int: + def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None, verbose: bool = True) -> int: r"""Helper function to seed everything without explicitly importing Lightning. See :func:`~lightning.fabric.utilities.seed.seed_everything` for more details. @@ -919,7 +919,7 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) # Lightning sets `workers=False` by default to avoid breaking reproducibility, but since this is a new # release, we can afford to do it. workers = True - return seed_everything(seed=seed, workers=workers) + return seed_everything(seed=seed, workers=workers, verbose=verbose) def _wrap_and_launch(self, to_run: Callable, *args: Any, **kwargs: Any) -> Any: self._launched = True diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index c389012d98fa1..a2d627828a77e 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -17,7 +17,7 @@ min_seed_value = 0 -def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: +def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose: bool = True) -> int: r"""Function that sets the seed for pseudo-random number generators in: torch, numpy, and Python's random module. In addition, sets the following environment variables: @@ -32,6 +32,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: Trainer with a ``worker_init_fn``. If the user already provides such a function for their dataloaders, setting this argument will have no influence. See also: :func:`~lightning.fabric.utilities.seed.pl_worker_init_function`. + verbose: Whether to print a message on each rank with the seed being set. """ if seed is None: @@ -52,7 +53,9 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") seed = 0 - log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank())) + if verbose: + log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank())) + os.environ["PL_GLOBAL_SEED"] = str(seed) random.seed(seed) if _NUMPY_AVAILABLE: @@ -76,7 +79,7 @@ def reset_seed() -> None: if seed is None: return workers = os.environ.get("PL_SEED_WORKERS", "0") - seed_everything(int(seed), workers=bool(int(workers))) + seed_everything(int(seed), workers=bool(int(workers)), verbose=False) def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 7c6fdde6e04ee..b95de03cfa8c9 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -13,7 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `dump_stats` flag to `AdvancedProfiler` ([#19703](https://github.com/Lightning-AI/pytorch-lightning/issues/19703)) -- +- Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108)) + ### Changed @@ -41,7 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Avoid LightningCLI saving hyperparameters with `class_path` and `init_args` since this would be a breaking change ([#20068](https://github.com/Lightning-AI/pytorch-lightning/pull/20068)) -- +- Fixed an issue that would cause too many printouts of the seed info when using `seed_everything()` ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108)) diff --git a/tests/tests_fabric/utilities/test_seed.py b/tests/tests_fabric/utilities/test_seed.py index bb1d3583f56a6..be2ecba3294b1 100644 --- a/tests/tests_fabric/utilities/test_seed.py +++ b/tests/tests_fabric/utilities/test_seed.py @@ -3,7 +3,6 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric.utilities import numpy import pytest import torch @@ -11,6 +10,7 @@ _collect_rng_states, _set_rng_states, pl_worker_init_function, + reset_seed, seed_everything, ) @@ -18,7 +18,7 @@ @mock.patch.dict(os.environ, clear=True) def test_default_seed(): """Test that the default seed is 0 when no seed provided and no environment variable set.""" - assert lightning.fabric.utilities.seed.seed_everything() == 0 + assert seed_everything() == 0 assert os.environ["PL_GLOBAL_SEED"] == "0" @@ -26,11 +26,11 @@ def test_default_seed(): def test_seed_stays_same_with_multiple_seed_everything_calls(): """Ensure that after the initial seed everything, the seed stays the same for the same run.""" with pytest.warns(UserWarning, match="No seed found"): - lightning.fabric.utilities.seed.seed_everything() + seed_everything() initial_seed = os.environ.get("PL_GLOBAL_SEED") with pytest.warns(None) as record: - lightning.fabric.utilities.seed.seed_everything() + seed_everything() assert not record # does not warn seed = os.environ.get("PL_GLOBAL_SEED") @@ -40,14 +40,14 @@ def test_seed_stays_same_with_multiple_seed_everything_calls(): @mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True) def test_correct_seed_with_environment_variable(): """Ensure that the PL_GLOBAL_SEED environment is read.""" - assert lightning.fabric.utilities.seed.seed_everything() == 2020 + assert seed_everything() == 2020 @mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True) def test_invalid_seed(): """Ensure that we still fix the seed even if an invalid seed is given.""" with pytest.warns(UserWarning, match="Invalid seed found"): - seed = lightning.fabric.utilities.seed.seed_everything() + seed = seed_everything() assert seed == 0 @@ -56,7 +56,7 @@ def test_invalid_seed(): def test_out_of_bounds_seed(seed): """Ensure that we still fix the seed even if an out-of-bounds seed is given.""" with pytest.warns(UserWarning, match="is not in bounds"): - actual = lightning.fabric.utilities.seed.seed_everything(seed) + actual = seed_everything(seed) assert actual == 0 @@ -64,7 +64,7 @@ def test_reset_seed_no_op(): """Test that the reset_seed function is a no-op when seed_everything() was not used.""" assert "PL_GLOBAL_SEED" not in os.environ seed_before = torch.initial_seed() - lightning.fabric.utilities.seed.reset_seed() + reset_seed() assert torch.initial_seed() == seed_before assert "PL_GLOBAL_SEED" not in os.environ @@ -75,18 +75,26 @@ def test_reset_seed_everything(workers): assert "PL_GLOBAL_SEED" not in os.environ assert "PL_SEED_WORKERS" not in os.environ - lightning.fabric.utilities.seed.seed_everything(123, workers) + seed_everything(123, workers) before = torch.rand(1) assert os.environ["PL_GLOBAL_SEED"] == "123" assert os.environ["PL_SEED_WORKERS"] == str(int(workers)) - lightning.fabric.utilities.seed.reset_seed() + reset_seed() after = torch.rand(1) assert os.environ["PL_GLOBAL_SEED"] == "123" assert os.environ["PL_SEED_WORKERS"] == str(int(workers)) assert torch.allclose(before, after) +def test_reset_seed_non_verbose(caplog): + seed_everything(123) + assert len(caplog.records) == 1 + caplog.clear() + reset_seed() # should call `seed_everything(..., verbose=False)` + assert len(caplog.records) == 0 + + def test_backward_compatibility_rng_states_dict(): """Test that an older rng_states_dict without the "torch.cuda" key does not crash.""" states = _collect_rng_states()