Skip to content

Commit

Permalink
Avoid printing the seed info message multiple times (Lightning-AI#20108)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored and ammyk9 committed Aug 6, 2024
1 parent 94efec1 commit 08fe9ee
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 18 deletions.
3 changes: 2 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Made saving non-distributed checkpoints fully atomic ([#20011](https://github.com/Lightning-AI/pytorch-lightning/pull/20011))

-
- Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))


### Changed

Expand Down
4 changes: 2 additions & 2 deletions src/lightning/fabric/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,7 @@ def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None) -> No
logger.log_metrics(metrics=metrics, step=step)

@staticmethod
def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) -> int:
def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None, verbose: bool = True) -> int:
r"""Helper function to seed everything without explicitly importing Lightning.
See :func:`~lightning.fabric.utilities.seed.seed_everything` for more details.
Expand All @@ -919,7 +919,7 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
# Lightning sets `workers=False` by default to avoid breaking reproducibility, but since this is a new
# release, we can afford to do it.
workers = True
return seed_everything(seed=seed, workers=workers)
return seed_everything(seed=seed, workers=workers, verbose=verbose)

def _wrap_and_launch(self, to_run: Callable, *args: Any, **kwargs: Any) -> Any:
self._launched = True
Expand Down
9 changes: 6 additions & 3 deletions src/lightning/fabric/utilities/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
min_seed_value = 0


def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose: bool = True) -> int:
r"""Function that sets the seed for pseudo-random number generators in: torch, numpy, and Python's random module.
In addition, sets the following environment variables:
Expand All @@ -32,6 +32,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
Trainer with a ``worker_init_fn``. If the user already provides such a function
for their dataloaders, setting this argument will have no influence. See also:
:func:`~lightning.fabric.utilities.seed.pl_worker_init_function`.
verbose: Whether to print a message on each rank with the seed being set.
"""
if seed is None:
Expand All @@ -52,7 +53,9 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
seed = 0

log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank()))
if verbose:
log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank()))

os.environ["PL_GLOBAL_SEED"] = str(seed)
random.seed(seed)
if _NUMPY_AVAILABLE:
Expand All @@ -76,7 +79,7 @@ def reset_seed() -> None:
if seed is None:
return
workers = os.environ.get("PL_SEED_WORKERS", "0")
seed_everything(int(seed), workers=bool(int(workers)))
seed_everything(int(seed), workers=bool(int(workers)), verbose=False)


def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover
Expand Down
5 changes: 3 additions & 2 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `dump_stats` flag to `AdvancedProfiler` ([#19703](https://github.com/Lightning-AI/pytorch-lightning/issues/19703))

-
- Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))


### Changed

Expand Down Expand Up @@ -41,7 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Avoid LightningCLI saving hyperparameters with `class_path` and `init_args` since this would be a breaking change ([#20068](https://github.com/Lightning-AI/pytorch-lightning/pull/20068))

-
- Fixed an issue that would cause too many printouts of the seed info when using `seed_everything()` ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))



Expand Down
28 changes: 18 additions & 10 deletions tests/tests_fabric/utilities/test_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,34 @@
from unittest import mock
from unittest.mock import Mock

import lightning.fabric.utilities
import numpy
import pytest
import torch
from lightning.fabric.utilities.seed import (
_collect_rng_states,
_set_rng_states,
pl_worker_init_function,
reset_seed,
seed_everything,
)


@mock.patch.dict(os.environ, clear=True)
def test_default_seed():
"""Test that the default seed is 0 when no seed provided and no environment variable set."""
assert lightning.fabric.utilities.seed.seed_everything() == 0
assert seed_everything() == 0
assert os.environ["PL_GLOBAL_SEED"] == "0"


@mock.patch.dict(os.environ, {}, clear=True)
def test_seed_stays_same_with_multiple_seed_everything_calls():
"""Ensure that after the initial seed everything, the seed stays the same for the same run."""
with pytest.warns(UserWarning, match="No seed found"):
lightning.fabric.utilities.seed.seed_everything()
seed_everything()
initial_seed = os.environ.get("PL_GLOBAL_SEED")

with pytest.warns(None) as record:
lightning.fabric.utilities.seed.seed_everything()
seed_everything()
assert not record # does not warn
seed = os.environ.get("PL_GLOBAL_SEED")

Expand All @@ -40,14 +40,14 @@ def test_seed_stays_same_with_multiple_seed_everything_calls():
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True)
def test_correct_seed_with_environment_variable():
"""Ensure that the PL_GLOBAL_SEED environment is read."""
assert lightning.fabric.utilities.seed.seed_everything() == 2020
assert seed_everything() == 2020


@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
def test_invalid_seed():
"""Ensure that we still fix the seed even if an invalid seed is given."""
with pytest.warns(UserWarning, match="Invalid seed found"):
seed = lightning.fabric.utilities.seed.seed_everything()
seed = seed_everything()
assert seed == 0


Expand All @@ -56,15 +56,15 @@ def test_invalid_seed():
def test_out_of_bounds_seed(seed):
"""Ensure that we still fix the seed even if an out-of-bounds seed is given."""
with pytest.warns(UserWarning, match="is not in bounds"):
actual = lightning.fabric.utilities.seed.seed_everything(seed)
actual = seed_everything(seed)
assert actual == 0


def test_reset_seed_no_op():
"""Test that the reset_seed function is a no-op when seed_everything() was not used."""
assert "PL_GLOBAL_SEED" not in os.environ
seed_before = torch.initial_seed()
lightning.fabric.utilities.seed.reset_seed()
reset_seed()
assert torch.initial_seed() == seed_before
assert "PL_GLOBAL_SEED" not in os.environ

Expand All @@ -75,18 +75,26 @@ def test_reset_seed_everything(workers):
assert "PL_GLOBAL_SEED" not in os.environ
assert "PL_SEED_WORKERS" not in os.environ

lightning.fabric.utilities.seed.seed_everything(123, workers)
seed_everything(123, workers)
before = torch.rand(1)
assert os.environ["PL_GLOBAL_SEED"] == "123"
assert os.environ["PL_SEED_WORKERS"] == str(int(workers))

lightning.fabric.utilities.seed.reset_seed()
reset_seed()
after = torch.rand(1)
assert os.environ["PL_GLOBAL_SEED"] == "123"
assert os.environ["PL_SEED_WORKERS"] == str(int(workers))
assert torch.allclose(before, after)


def test_reset_seed_non_verbose(caplog):
seed_everything(123)
assert len(caplog.records) == 1
caplog.clear()
reset_seed() # should call `seed_everything(..., verbose=False)`
assert len(caplog.records) == 0


def test_backward_compatibility_rng_states_dict():
"""Test that an older rng_states_dict without the "torch.cuda" key does not crash."""
states = _collect_rng_states()
Expand Down

0 comments on commit 08fe9ee

Please sign in to comment.