Skip to content

Commit

Permalink
minimize_stateless: avoid reusing initialization seed
Browse files Browse the repository at this point in the history
Discovered by running tests with `jax_enable_key_reuse_checks=True`.

PiperOrigin-RevId: 614737610
  • Loading branch information
vanderplas authored and tensorflower-gardener committed Mar 11, 2024
1 parent 5e568e1 commit f32c8d4
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tensorflow_probability/python/math/minimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ def run_jitted_minimize():
seed_is_none = seed is None
if not seed_is_none:
seed = samplers.sanitize_seed(seed, salt='minimize')
init_seed, seed = samplers.split_seed(seed, n=2)
else:
init_seed = None

if not return_full_length_trace:
# Augment trace to record convergence info, so we can truncate it later.
Expand All @@ -153,7 +156,7 @@ def run_jitted_minimize():
initial_optimizer_state) = optimizer_step_fn(
parameters=initial_parameters,
optimizer_state=initial_optimizer_state,
seed=seed)
seed=init_seed)

initial_convergence_criterion_state = ()
if convergence_criterion is not None:
Expand Down

0 comments on commit f32c8d4

Please sign in to comment.