Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a condition for jaxlib imports #1803

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,10 @@ def get_parameters_filepath(path):


def have_jax():
"""Check if jax is installed"""
return importlib.util.find_spec("jax") is not None
"""Check if jax and jaxlib are installed"""
return (importlib.util.find_spec("jax") is not None) and (
importlib.util.find_spec("jaxlib") is not None
Copy link
Member Author

@Saransh-cpp Saransh-cpp Nov 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another approach would be to create a new have_jaxlib function to carry out this condition, but I am not sure if that would be better

With have_jaxlib we will be able to point out if jaxlib is not installed. Otherwise, I guess the exceptions would be "Jax or jaxlib not installed"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating have_jaxlib would complicate things unnecessarily

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

)


def install_jax():
Expand Down