diff --git a/jax/experimental/key_reuse/_forwarding.py b/jax/experimental/key_reuse/_forwarding.py index d93d84ee12f5..db5a9dc9359f 100644 --- a/jax/experimental/key_reuse/_forwarding.py +++ b/jax/experimental/key_reuse/_forwarding.py @@ -211,8 +211,11 @@ def _pjit_key_type_signature(eqn, args_consumed): jaxpr = eqn.params['jaxpr'] forwarded_inputs = {i: eqn.invars.index(var) for i, var in enumerate(eqn.invars) if var in eqn.invars[:i]} - return get_jaxpr_type_signature(jaxpr.jaxpr, consumed_inputs=args_consumed, - forwarded_inputs=forwarded_inputs) + sig = get_jaxpr_type_signature(jaxpr.jaxpr) + if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks): + # Double consumption detected: re-trace with context for better errors. + get_jaxpr_type_signature(jaxpr.jaxpr, args_consumed, forwarded_inputs) + return sig key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature diff --git a/jax/experimental/key_reuse/_simple.py b/jax/experimental/key_reuse/_simple.py index 62a43aa126a0..edc4f2a460db 100644 --- a/jax/experimental/key_reuse/_simple.py +++ b/jax/experimental/key_reuse/_simple.py @@ -183,7 +183,11 @@ def _pjit_key_type_signature(eqn, args_consumed): non_literal_invars = [v for v in eqn.invars if not isinstance(v, core.Literal)] if len(set(non_literal_invars)) != len(non_literal_invars): raise ValueError(f"pjit with duplicate inputs: {eqn.invars=}") - return get_jaxpr_type_signature(jaxpr.jaxpr, consumed_inputs=args_consumed) + sig = get_jaxpr_type_signature(jaxpr.jaxpr) + if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks): + # Double consumption detected: re-trace with context for better errors. + get_jaxpr_type_signature(jaxpr.jaxpr, args_consumed) + return sig key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 4877d2e5d1a8..5ebecf4d4e19 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -573,6 +573,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase): random_bits_error = "In random_bits, key values .+ are already consumed.*" random_split_error = "In random_split, key values .+ are already consumed.*" generic_error = ".*key values .+ are already consumed.*" + pjit_error = "In pjit, key values a are already consumed." def check_key_reuse(self, f, *args): if self.use_forwarding: @@ -782,6 +783,18 @@ def body_fun(i): with self.assertRaisesRegex(KeyReuseError, "while_loop cond function leads to key reuse"): self.check_key_reuse(f, 0) + def test_pjit_consumed_input(self): + @jax.jit + def g(key, x): # doesn't consume key + return x + + def f(seed): + key = jax.random.key(seed) + x = jax.random.bits(key) + return g(key, x) + + self.check_key_reuse(f, 0) + class KeyReuseIntegrationTestSimple(KeyReuseIntegrationTest): use_forwarding = False