Seeding not working for single chains #5378

Closed
ricardoV94 opened this issue Jan 20, 2022 · 4 comments · Fixed by #5377

Comments

@ricardoV94
Member

Seeding seems to not be respected when using a single chain:

import numpy as np
import pymc as pm

with pm.Model(rng_seeder=3):
    x = pm.Normal("x")

    tr1 = pm.sample(
        chains=1,
        random_seed=1,
        tune=0,
        draws=10,
        return_inferencedata=False,
        compute_convergence_checks=False,
    )
    tr2 = pm.sample(
        chains=1,
        random_seed=1,
        tune=0,
        draws=10,
        return_inferencedata=False,
        compute_convergence_checks=False,
    )

    assert np.allclose(tr1["x"], tr2["x"])  # Fails

There is a similar failure when passing multiple seeds while working with a single core (i.e., sequential sampling).

#5377 includes a test that covers the failing cases.
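
The property being checked above can be illustrated without PyMC. The following is a minimal sketch, not PyMC code and not the test from #5377: `toy_sample` is a hypothetical stand-in for a sequential sampler, and it assumes each chain draws from its own `np.random.Generator` built from the supplied seed. Under that assumption, repeated calls with the same seed (or the same list of per-chain seeds) return identical draws.

```python
import numpy as np


def toy_sample(chains, draws, random_seed):
    """Hypothetical sequential sampler: one seeded generator per chain."""
    # Accept either a single integer seed or one seed per chain.
    if np.isscalar(random_seed):
        seeds = [random_seed + i for i in range(chains)]
    else:
        seeds = list(random_seed)
    if len(seeds) != chains:
        raise ValueError("need exactly one seed per chain")
    # Chains run one after another, as with cores=1.
    return np.stack([np.random.default_rng(s).normal(size=draws) for s in seeds])


# Single chain, single seed.
tr1 = toy_sample(chains=1, draws=10, random_seed=1)
tr2 = toy_sample(chains=1, draws=10, random_seed=1)
assert np.allclose(tr1, tr2)

# Several chains sampled sequentially, one seed per chain.
tr3 = toy_sample(chains=2, draws=10, random_seed=[1, 2])
tr4 = toy_sample(chains=2, draws=10, random_seed=[1, 2])
assert np.allclose(tr3, tr4)
```

The point of the design is that no chain reads from shared global state, so the result does not depend on what ran before the call.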

@michaelosthege
Member

So the random state of the 2nd sampling is not identical.
Printing np.random.get_state() gives a bit more information, but I didn't find any explanation of these items.

Maybe chains=1, cores=1 goes through a different control flow than chains=2, cores=1?
Adding a few _log.debug() lines could help.

import numpy as np
import pymc as pm

class BlablaSampler(pm.Metropolis):
    def step(self, *args, **kwargs):
        _, rstate, pos, has_gauss, cached_gaussian = np.random.get_state()
        draw, stats = super().step(*args, **kwargs)
        print(f"""
args             {args}
kwargs           {kwargs}
rstate           {hash(rstate.tobytes())}
pos              {pos}
has_gauss        {has_gauss}
cached_gaussian  {cached_gaussian}
↓↓↓↓↓
draw             {draw}
stats            {stats}
        """)
        return draw, stats

with pm.Model(rng_seeder=3) as pmodel:
    x = pm.Normal("x", initval="prior")
    
    common = dict(
        chains=1,
        cores=1,
        random_seed=1,
        tune=0,
        draws=3,
        return_inferencedata=False,
        compute_convergence_checks=False,
    )

    tr1 = pm.sample(
        **common,
        step=BlablaSampler(),
    )
    tr2 = pm.sample(
        **common,
        step=BlablaSampler(),
    )

    assert np.allclose(tr1["x"], tr2["x"])  # Fails
Output
Only 3 samples in chain.
Sequential sampling (1 chains in 1 job)
BlablaSampler: [x]

100.00% [3/3 00:00<00:00 Sampling chain 0, 0 divergences]


args             ({'x': array(-0.88681976)},)
kwargs           {}
rstate           -2368939229662006587
pos              99
has_gauss        0
cached_gaussian  0.0
↓↓↓↓↓
draw             {'x': array(-0.88681976)}
stats            [{'tune': False, 'scaling': array([1.]), 'accept': 0.3279256920370864, 'accepted': False}]
        

args             ({'x': array(-0.88681976)},)
kwargs           {}
rstate           -2368939229662006587
pos              105
has_gauss        1
cached_gaussian  0.05062323074151636
↓↓↓↓↓
draw             {'x': array(-0.83619653)}
stats            [{'tune': False, 'scaling': array([1.]), 'accept': 1.044577320525613, 'accepted': True}]
        

args             ({'x': array(-0.83619653)},)
kwargs           {}
rstate           -2368939229662006587
pos              107
has_gauss        0
cached_gaussian  0.0
↓↓↓↓↓
draw             {'x': array(-0.13112415)}
stats            [{'tune': False, 'scaling': array([1.]), 'accept': 1.4063751057573721, 'accepted': True}]
        

Sampling 1 chain for 0 tune and 3 draw iterations (0 + 3 draws total) took 0 seconds.
Only 3 samples in chain.
Sequential sampling (1 chains in 1 job)
BlablaSampler: [x]

100.00% [3/3 00:00<00:00 Sampling chain 0, 0 divergences]


args             ({'x': array(-0.88681976)},)
kwargs           {}
rstate           -2368939229662006587
pos              113
has_gauss        1
cached_gaussian  -1.679941954932756
↓↓↓↓↓
draw             {'x': array(-0.88681976)}
stats            [{'tune': False, 'scaling': array([1.]), 'accept': 0.054973270188106695, 'accepted': False}]
        

args             ({'x': array(-0.88681976)},)
kwargs           {}
rstate           -2368939229662006587
pos              115
has_gauss        0
cached_gaussian  0.0
↓↓↓↓↓
draw             {'x': array(-0.6937018)}
stats            [{'tune': False, 'scaling': array([1.]), 'accept': 1.1648747282621459, 'accepted': True}]
        

args             ({'x': array(-0.6937018)},)
kwargs           {}
rstate           -2368939229662006587
pos              125
has_gauss        1
cached_gaussian  -1.5530074084714907
↓↓↓↓↓
draw             {'x': array(-0.6937018)}
stats            [{'tune': False, 'scaling': array([1.]), 'accept': 0.10195333843669846, 'accepted': False}]
        

Sampling 1 chain for 0 tune and 3 draw iterations (0 + 3 draws total) took 0 seconds.
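
As a reading aid for the output above: `np.random.get_state()` returns the state of NumPy's legacy global generator as a tuple holding the bit generator name, the 624-word Mersenne Twister key array, the current position in that array, a flag saying whether a Gaussian draw is cached, and the cached value. The sketch below uses only NumPy and makes no claim about PyMC internals; it shows that reseeding the global generator restores the state, so the draws repeat. In the log above, by contrast, `pos` keeps advancing from the first run into the second (107, then 113), which would be consistent with the global state not being reset between the two calls.

```python
import numpy as np

np.random.seed(1)
kind, keys, pos, has_gauss, cached_gaussian = np.random.get_state()
print(kind, keys.shape, keys.dtype)  # name and key array of the bit generator

first = np.random.normal(size=3)
state_after_first = np.random.get_state()

# Without reseeding, the next draws continue from the advanced state.
continued = np.random.normal(size=3)

# Reseeding resets the state, so the same draws come out again.
np.random.seed(1)
second = np.random.normal(size=3)
state_after_second = np.random.get_state()

assert np.allclose(first, second)
assert not np.allclose(first, continued)
assert state_after_first[2] == state_after_second[2]  # same position
assert np.array_equal(state_after_first[1], state_after_second[1])  # same keys
```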

@ricardoV94
Member Author

I don't think the random seeding is having any effect.

@michaelosthege
Member

michaelosthege commented Jan 25, 2022

Well, it's working for multiple chains, but maybe that's because of process forking or a difference between control flows that "protects" the mother-seed.

I looked at that control flow yesterday and with this discussion in mind we might want to solve this as part of a bigger refactor.

@ricardoV94 ricardoV94 pinned this issue Feb 15, 2022
@ricardoV94 ricardoV94 modified the milestones: v4.0.0b3, v4.0.0 Feb 15, 2022
@OriolAbril
Member

I don't know how seeding currently works, where it fails, or how it is expected to work in v4. I'm just sharing this as a potential test to make sure seeding is reproducible yet constrained to pymc: https://discourse.pymc.io/t/weired-random-number-generation-pattern-after-training-a-pymc3-model/8875
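
One way to phrase such a test, as a sketch only: snapshot NumPy's global state, run the sampler, then check that the global state is unchanged while the seeded result is reproducible. `seeded_draws` below is a hypothetical stand-in for a sampling call and assumes the sampler keeps its randomness in a local `np.random.Generator`; it is not how PyMC is implemented, and `global_state_unchanged` is likewise an illustrative helper.

```python
import numpy as np


def seeded_draws(draws, random_seed):
    """Hypothetical sampler that keeps its randomness in a local generator."""
    rng = np.random.default_rng(random_seed)
    return rng.normal(size=draws)


def global_state_unchanged(before, after):
    """Compare two tuples returned by np.random.get_state()."""
    return (
        before[0] == after[0]
        and np.array_equal(before[1], after[1])
        and before[2:] == after[2:]
    )


before = np.random.get_state()
a = seeded_draws(5, random_seed=1)
b = seeded_draws(5, random_seed=1)
after = np.random.get_state()

assert np.allclose(a, b)  # reproducible for a fixed seed
assert global_state_unchanged(before, after)  # global generator untouched
```

With a real sampler in place of `seeded_draws`, the second assertion would fail if sampling consumed or reseeded the global generator.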

@ricardoV94 ricardoV94 modified the milestones: v4.0.0, v4.1.0 Apr 1, 2022
@ricardoV94 ricardoV94 changed the title Seeding not working Seeding not working for single chains Apr 1, 2022
@ricardoV94 ricardoV94 unpinned this issue May 20, 2022
@michaelosthege michaelosthege modified the milestones: v4.1.0, v4.0.0 Jul 2, 2022