Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Accept ValueError for JAX backend tolist fallback #1746

Merged
merged 2 commits into from
Jan 18, 2022

Conversation

matthewfeickert
Copy link
Member

@matthewfeickert matthewfeickert commented Jan 18, 2022

Description

In JAX v0.2.27 the error raised for trying to pass a sequence as a list includes a ValueError

>>> import jax
>>> import jax.numpy as jnp
>>> jax.__version__
'0.2.27'
>>> jnp.asarray([[1, 2], 3, [4]])
TypeError: int() argument must be a string, a bytes-like object or a number, not 'list'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../site-packages/jax/_src/numpy/lax_numpy.py", line 3648, in asarray
    return array(a, dtype=dtype, copy=False, order=order)
  File "/.../site-packages/jax/_src/numpy/lax_numpy.py", line 3606, in array
    out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)
ValueError: setting an array element with a sequence.

To handle this, also accept ValueError as a valid exception when falling back to list.

Checklist Before Requesting Reviewer

  • Tests are passing
  • "WIP" removed from the title of the pull request
  • Selected an Assignee for the PR to be responsible for the log summary

Before Merging

For the PR Assignees:

  • Summarize commit messages into a comprehensive review of the PR
In JAX v0.2.27 the error raised for trying to pass a sequence as a list
includes a ValueError

>>> import jax
>>> import jax.numpy as jnp
>>> jax.__version__
'0.2.27'
>>> jnp.asarray([[1, 2], 3, [4]])
TypeError: int() argument must be a string, a bytes-like object or a number, not 'list'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../site-packages/jax/_src/numpy/lax_numpy.py", line 3648, in asarray
    return array(a, dtype=dtype, copy=False, order=order)
  File "/.../site-packages/jax/_src/numpy/lax_numpy.py", line 3606, in array
    out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)
ValueError: setting an array element with a sequence.

To handle this, also accept ValueError as a valid exception when falling back
to list.

In JAX v0.2.27 the error raised for trying to pass a sequence as a list
includes a ValueError

>>> import jax
>>> import jax.numpy as jnp
>>> jax.__version__
'0.2.27'
>>> jnp.asarray([[1, 2], 3, [4]])
TypeError: int() argument must be a string, a bytes-like object or a number, not 'list'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../site-packages/jax/_src/numpy/lax_numpy.py", line 3648, in asarray
    return array(a, dtype=dtype, copy=False, order=order)
  File "/.../site-packages/jax/_src/numpy/lax_numpy.py", line 3606, in array
    out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)
ValueError: setting an array element with a sequence.

To handle this, also accept ValueError as a valid exception when falling back
to list.
@matthewfeickert matthewfeickert added the fix A bug fix label Jan 18, 2022
@matthewfeickert matthewfeickert self-assigned this Jan 18, 2022
@codecov
Copy link

codecov bot commented Jan 18, 2022

Codecov Report

Merging #1746 (b177494) into master (abde607) will not change coverage.
The diff coverage is 100.00%.

Impacted file tree graph

@@           Coverage Diff           @@
##           master    #1746   +/-   ##
=======================================
  Coverage   98.12%   98.12%           
=======================================
  Files          64       64           
  Lines        4270     4270           
  Branches      683      683           
=======================================
  Hits         4190     4190           
  Misses         46       46           
  Partials       34       34           
Flag Coverage Δ
contrib 26.25% <0.00%> (ø)
doctest 60.58% <0.00%> (ø)
unittests 96.18% <100.00%> (ø)

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
src/pyhf/tensor/jax_backend.py 98.58% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update abde607...b177494. Read the comment docs.

@kratsg kratsg changed the title fix: Accept ValueErorr for JAX backend tolist fallback fix: Accept ValueError for JAX backend tolist fallback Jan 18, 2022
@matthewfeickert matthewfeickert merged commit 3c3d2db into master Jan 18, 2022
@matthewfeickert matthewfeickert deleted the fix/except-ValueError-for-jax branch January 18, 2022 21:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
fix A bug fix
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants