From 91a58b75d26fbd003c99df1c0e0fa751b1635d8d Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Thu, 31 Aug 2023 15:52:27 -0700 Subject: [PATCH] Add jax_utils --- keras_core/utils/jax_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 keras_core/utils/jax_utils.py diff --git a/keras_core/utils/jax_utils.py b/keras_core/utils/jax_utils.py new file mode 100644 index 000000000..3bb15cc5f --- /dev/null +++ b/keras_core/utils/jax_utils.py @@ -0,0 +1,9 @@ +from keras_core import backend + + +def is_in_jax_tracing_scope(): + if backend.backend() == "jax": + x = backend.numpy.ones(()) + if x.__class__.__name__ == "DynamicJaxprTracer": + return True + return False