Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Use jax.numpy for JAX backend tensorlib.tolist #1138

Merged
merged 4 commits into from
Nov 12, 2021

Conversation

matthewfeickert
Copy link
Member

@matthewfeickert matthewfeickert commented Oct 20, 2020

Description

Resolves #1137

In more recent releases of JAX the jax.numpy module is more filled out and it now supports jax.numpy.tolist, which was the only thing NumPy was used for previously.

NumPy must still be kept as JAX relies on numpy itself to convert from JAX DeviceArray to NumPy

def to_numpy(self, tensor_in):
"""
Convert the JAX tensor to a :class:`numpy.ndarray`.
Example:
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tensor
DeviceArray([[1., 2., 3.],
[4., 5., 6.]], dtype=float64)
>>> numpy_ndarray = pyhf.tensorlib.to_numpy(tensor)
>>> numpy_ndarray
array([[1., 2., 3.],
[4., 5., 6.]])
>>> type(numpy_ndarray)
<class 'numpy.ndarray'>
Args:
tensor_in (:obj:`tensor`): The input tensor object.
Returns:
:class:`numpy.ndarray`: The tensor converted to a NumPy ``ndarray``.
"""
return np.asarray(tensor_in, dtype=tensor_in.dtype)

Checklist Before Requesting Reviewer

  • Tests are passing
  • "WIP" removed from the title of the pull request
  • Selected an Assignee for the PR to be responsible for the log summary

Before Merging

For the PR Assignees:

  • Summarize commit messages into a comprehensive review of the PR
* Use jax.numpy.tolist to provide the tolist method for the JAX backend
   - Note that NumPy dependency can never be removed as JAX depends on
     NumPy for JAX to NumPy conversion

@matthewfeickert matthewfeickert added the refactor A code change that neither fixes a bug nor adds a feature label Oct 20, 2020
@matthewfeickert matthewfeickert self-assigned this Oct 20, 2020
@matthewfeickert matthewfeickert marked this pull request as draft October 20, 2020 19:57
@matthewfeickert matthewfeickert force-pushed the refactor/remove-numpy-from-jax-backend branch from b001ef3 to a786908 Compare November 12, 2021 00:16
@matthewfeickert matthewfeickert changed the title refactor: Only use jax.numpy for JAX backend refactor: Use jax.numpy for JAX backend tensorlib.tolist Nov 12, 2021
@matthewfeickert matthewfeickert marked this pull request as ready for review November 12, 2021 00:20
@matthewfeickert
Copy link
Member Author

I was hoping to remove NumPy entirely, but turns out that's not possible with JAX as JAX needs NumPy to convert to NumPy.

@codecov
Copy link

codecov bot commented Nov 12, 2021

Codecov Report

Merging #1138 (a786908) into master (2b2b281) will increase coverage by 0.04%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1138      +/-   ##
==========================================
+ Coverage   98.06%   98.10%   +0.04%     
==========================================
  Files          64       64              
  Lines        4228     4228              
  Branches      587      587              
==========================================
+ Hits         4146     4148       +2     
+ Misses         49       46       -3     
- Partials       33       34       +1     
Flag Coverage Δ
contrib 25.40% <0.00%> (ø)
doctest 61.21% <0.00%> (ø)
unittests 96.42% <100.00%> (+0.04%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
src/pyhf/tensor/jax_backend.py 98.58% <100.00%> (+1.41%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 2b2b281...a786908. Read the comment docs.

@matthewfeickert matthewfeickert merged commit d7a2706 into master Nov 12, 2021
@matthewfeickert matthewfeickert deleted the refactor/remove-numpy-from-jax-backend branch November 12, 2021 15:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
refactor A code change that neither fixes a bug nor adds a feature
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Remove use of NumPy from JAX backend
2 participants