Skip to content

Commit

Permalink
fix: Support JAX array API before and after JAX v0.4.1 (#2280)
Browse files Browse the repository at this point in the history
* Add try-except block to determine what JAX array API is available and use this
  information to set the jax backend array_type and array_subtype.
* Backport components of:
   - PR #2079
   - PR #2085
  • Loading branch information
matthewfeickert authored Aug 16, 2023
1 parent 794052f commit b3c5ead
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
12 changes: 10 additions & 2 deletions src/pyhf/tensor/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@

log = logging.getLogger(__name__)

# v0.7.x backport hack
_old_jax_version = False
try:
from jax import Array
except ImportError:
# jax.Array added in jax v0.4.1
_old_jax_version = True


class _BasicPoisson:
def __init__(self, rate):
Expand Down Expand Up @@ -54,10 +62,10 @@ class jax_backend:
__slots__ = ['name', 'precision', 'dtypemap', 'default_do_grad']

#: The array type for jax
array_type = jnp.DeviceArray
array_type = jnp.DeviceArray if _old_jax_version else Array

#: The array content type for jax
array_subtype = jnp.DeviceArray
array_subtype = jnp.DeviceArray if _old_jax_version else Array

def __init__(self, **kwargs):
self.name = 'jax'
Expand Down
2 changes: 1 addition & 1 deletion tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ def validate_hypotest(
{'init_pars': 2, 'par_bounds': 2},
1.0,
"q",
2e-9,
3e-6,
"asymptotics",
),
(
Expand Down

0 comments on commit b3c5ead

Please sign in to comment.