Skip to content

Commit

Permalink
add test for #7155, fixes #7155
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 17, 2023
1 parent c2d5527 commit 00dc1f8
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4103,6 +4103,33 @@ def __call__(self, y):

jax.make_jaxpr(Foo(1))(3) # don't crash

def test_inner_jit_function_retracing(self):
# https://github.com/google/jax/issues/7155
inner_count = outer_count = 0

@jax.jit
def inner_fn(state):
nonlocal inner_count
inner_count += 1
return 2*state

@jax.jit
def outer_fn(x):
nonlocal outer_count
outer_count += 1
old_x = x
for _ in range(10):
x = inner_fn(x)
x = x + old_x
return x

state = jnp.arange(5, dtype=jnp.uint32)
inner_fn(state)
outer_fn(state)

self.assertEqual(inner_count, 1)
self.assertEqual(outer_count, 1)


class RematTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 00dc1f8

Please sign in to comment.