Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a condition for jaxlib imports #1803

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

- Raise error when trying to convert an `Interpolant` with the "pchip" interpolator to CasADI ([#1791](https://github.com/pybamm-team/PyBaMM/pull/1791))
- Raise error if `Concatenation` is used directly with `Variable` objects (`concatenation` should be used instead) ([#1789](https://github.com/pybamm-team/PyBaMM/pull/1789))
- Made jax and the PyBaMM JaxSolver optional ([#1767](https://github.com/pybamm-team/PyBaMM/pull/1767))
- Made jax, jaxlib and the PyBaMM JaxSolver optional ([#1767](https://github.com/pybamm-team/PyBaMM/pull/1767))

# [v21.10](https://github.com/pybamm-team/PyBaMM/tree/v21.9) - 2021-10-31

Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class JaxCooMatrix:
def __init__(self, row, col, data, shape):
if not pybamm.have_jax():
raise ModuleNotFoundError(
"Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
"Jax or jaxlib is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
)

self.row = jax.numpy.array(row)
Expand Down Expand Up @@ -547,7 +547,7 @@ class EvaluatorJax:
def __init__(self, symbol):
if not pybamm.have_jax():
raise ModuleNotFoundError(
"Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
"Jax or jaxlib is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
)

constants, python_str = pybamm.to_python(symbol, debug=False, output_jax=True)
Expand Down
2 changes: 1 addition & 1 deletion pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,7 @@ def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None):
"""
if not pybamm.have_jax():
raise ModuleNotFoundError(
"Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
"Jax or jaxlib is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
)

def _check_arg(arg):
Expand Down
2 changes: 1 addition & 1 deletion pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
):
if not pybamm.have_jax():
raise ModuleNotFoundError(
"Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
"Jax or jaxlib is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
)

# note: bdf solver itself calculates consistent initial conditions so can set
Expand Down
6 changes: 4 additions & 2 deletions pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,10 @@ def get_parameters_filepath(path):


def have_jax():
"""Check if jax is installed"""
return importlib.util.find_spec("jax") is not None
"""Check if jax and jaxlib are installed"""
return (importlib.util.find_spec("jax") is not None) and (
importlib.util.find_spec("jaxlib") is not None
Copy link
Member Author

@Saransh-cpp Saransh-cpp Nov 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another approach would be to create a new have_jaxlib function to carry out this condition, but I am not sure if that would be better

With have_jaxlib we will be able to point out if jaxlib is not installed. Otherwise, I guess the exceptions would be "Jax or jaxlib not installed"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating have_jaxlib would complicate things unnecessarily

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

)


def install_jax():
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def test_solver_citations(self):
pybamm.IDAKLUSolver()
self.assertIn("Hindmarsh2005", citations._papers_to_cite)

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_jax_citations(self):
citations = pybamm.citations
citations._reset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def test_evaluator_python(self):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_find_symbols_jax(self):
# test sparse conversion
constant_symbols = OrderedDict()
Expand All @@ -472,7 +472,7 @@ def test_find_symbols_jax(self):
list(constant_symbols.values())[0].toarray(), A.entries.toarray()
)

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_evaluator_jax(self):
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))
Expand Down Expand Up @@ -634,7 +634,7 @@ def test_evaluator_jax(self):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_evaluator_jax_jacobian(self):
a = pybamm.StateVector(slice(0, 1))
y_tests = [np.array([[2.0]]), np.array([[1.0]]), np.array([1.0])]
Expand All @@ -649,23 +649,23 @@ def test_evaluator_jax_jacobian(self):
result_true = evaluator_jac.evaluate(t=None, y=y)
np.testing.assert_allclose(result_test, result_true)

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_evaluator_jax_debug(self):
a = pybamm.StateVector(slice(0, 1))
expr = a ** 2
y_test = np.array([[2.0], [3.0]])
evaluator = pybamm.EvaluatorJax(expr)
evaluator.debug(y=y_test)

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_evaluator_jax_inputs(self):
a = pybamm.InputParameter("a")
expr = a ** 2
evaluator = pybamm.EvaluatorJax(expr)
result = evaluator.evaluate(inputs={"a": 2})
self.assertEqual(result, 4)

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_jax_coo_matrix(self):
A = pybamm.JaxCooMatrix([0, 1], [0, 1], [1.0, 2.0], (2, 2))
Adense = jax.numpy.array([[1.0, 0], [0, 2.0]])
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def test_sensitivities(self):
def exact_diff_a(y, a, b):
return np.array([[y[0] ** 2 + 2 * a], [y[0]]])

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def exact_diff_b(y, a, b):
return np.array([[y[0]], [0]])

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_ida_roberts_klu(self):
true_solution = 0.1 * solution.t
np.testing.assert_array_almost_equal(solution.y[0, :], true_solution)

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_ida_roberts_klu_sensitivities(self):
# this test implements a python version of the ida Roberts
# example provided in sundials
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jax


@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
class TestJaxBDFSolver(unittest.TestCase):
def test_solver(self):
# Create model
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jax


@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
class TestJaxSolver(unittest.TestCase):
def test_model_solver(self):
# Create model
Expand Down