From 3c3d2dba4f91d0affc8d66e6a8fc549c4b1deaf2 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Tue, 18 Jan 2022 15:13:29 -0600 Subject: [PATCH] fix: Accept ValueError for JAX backend `tolist` fallback (#1746) * In JAX v0.2.27 the error raised for trying to pass a sequence as a list includes a ValueError >>> import jax >>> import jax.numpy as jnp >>> jax.__version__ '0.2.27' >>> jnp.asarray([[1, 2], 3, [4]]) TypeError: int() argument must be a string, a bytes-like object or a number, not 'list' The above exception was the direct cause of the following exception: Traceback (most recent call last): File "", line 1, in File "/.../site-packages/jax/_src/numpy/lax_numpy.py", line 3648, in asarray return array(a, dtype=dtype, copy=False, order=order) File "/.../site-packages/jax/_src/numpy/lax_numpy.py", line 3606, in array out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False) ValueError: setting an array element with a sequence. To handle this, also accept ValueError as a valid exception when falling back to list. --- src/pyhf/tensor/jax_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py index dce701ffe0..db3b6dcca9 100644 --- a/src/pyhf/tensor/jax_backend.py +++ b/src/pyhf/tensor/jax_backend.py @@ -180,7 +180,7 @@ def conditional(self, predicate, true_callable, false_callable): def tolist(self, tensor_in): try: return jnp.asarray(tensor_in).tolist() - except TypeError: + except (TypeError, ValueError): if isinstance(tensor_in, list): return tensor_in raise