Skip to content

Commit

Permalink
Add jax_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Aug 31, 2023
1 parent 14b3755 commit 91a58b7
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions keras_core/utils/jax_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from keras_core import backend


def is_in_jax_tracing_scope():
if backend.backend() == "jax":
x = backend.numpy.ones(())
if x.__class__.__name__ == "DynamicJaxprTracer":
return True
return False

0 comments on commit 91a58b7

Please sign in to comment.