Skip to content

Commit

Permalink
Merge pull request #19665 from jakevdp:disable-jit-doc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 604400245
  • Loading branch information
jax authors committed Feb 5, 2024
2 parents e224c3d + 82611eb commit be99451
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ def disable_jit(disable: bool = True):
`cond` functions passed to higher-level primitives like :func:`~jax.lax.scan` and
:func:`~jax.lax.while_loop`, JIT used in implementations of :mod:`jax.numpy` functions,
and any other case where :func:`jit` is used within an API's implementation.
Note however that even under `disable_jit`, individual primitive operations
will still be compiled by XLA as in normal eager op-by-op execution.
Values that have a data dependence on the arguments to a jitted function are
traced and abstracted. For example, an abstract value may be a
Expand Down

0 comments on commit be99451

Please sign in to comment.