Skip to content

Commit

Permalink
fix: Accept ValueError for JAX backend tolist fallback (#1746)
Browse files Browse the repository at this point in the history
* In JAX v0.2.27 the error raised for trying to pass a sequence as a list
includes a ValueError

>>> import jax
>>> import jax.numpy as jnp
>>> jax.__version__
'0.2.27'
>>> jnp.asarray([[1, 2], 3, [4]])
TypeError: int() argument must be a string, a bytes-like object or a number, not 'list'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../site-packages/jax/_src/numpy/lax_numpy.py", line 3648, in asarray
    return array(a, dtype=dtype, copy=False, order=order)
  File "/.../site-packages/jax/_src/numpy/lax_numpy.py", line 3606, in array
    out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)
ValueError: setting an array element with a sequence.

To handle this, also accept ValueError as a valid exception when falling back
to list.
  • Loading branch information
matthewfeickert authored Jan 18, 2022
1 parent abde607 commit 3c3d2db
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/pyhf/tensor/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def conditional(self, predicate, true_callable, false_callable):
def tolist(self, tensor_in):
try:
return jnp.asarray(tensor_in).tolist()
except TypeError:
except (TypeError, ValueError):
if isinstance(tensor_in, list):
return tensor_in
raise
Expand Down

0 comments on commit 3c3d2db

Please sign in to comment.