diff --git a/tests/api_test.py b/tests/api_test.py index e404ba8151b4..42c03e80bf36 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4103,6 +4103,33 @@ def __call__(self, y): jax.make_jaxpr(Foo(1))(3) # don't crash + def test_inner_jit_function_retracing(self): + # https://github.com/google/jax/issues/7155 + inner_count = outer_count = 0 + + @jax.jit + def inner_fn(state): + nonlocal inner_count + inner_count += 1 + return 2*state + + @jax.jit + def outer_fn(x): + nonlocal outer_count + outer_count += 1 + old_x = x + for _ in range(10): + x = inner_fn(x) + x = x + old_x + return x + + state = jnp.arange(5, dtype=jnp.uint32) + inner_fn(state) + outer_fn(state) + + self.assertEqual(inner_count, 1) + self.assertEqual(outer_count, 1) + class RematTest(jtu.JaxTestCase):