From d7a2706d87eb3c746d9d69a5160d950bff12d721 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Fri, 12 Nov 2021 09:26:34 -0600 Subject: [PATCH] refactor: Use jax.numpy for JAX backend tensorlib.tolist (#1138) * Use jax.numpy.tolist to provide the tolist method for the JAX backend - Note that NumPy dependency can never be removed as JAX depends on NumPy for JAX to NumPy conversion --- src/pyhf/tensor/jax_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py index f6466357b2..0b183bc85e 100644 --- a/src/pyhf/tensor/jax_backend.py +++ b/src/pyhf/tensor/jax_backend.py @@ -179,8 +179,8 @@ def conditional(self, predicate, true_callable, false_callable): def tolist(self, tensor_in): try: - return np.asarray(tensor_in).tolist() - except AttributeError: + return jnp.asarray(tensor_in).tolist() + except TypeError: if isinstance(tensor_in, list): return tensor_in raise