Skip to content

Commit

Permalink
Merge pull request #1803 from Saransh-cpp/issue-1801-jaxlib-imports
Browse files Browse the repository at this point in the history
Added a condition for `jaxlib` imports
  • Loading branch information
valentinsulzer committed Nov 13, 2021
2 parents 4aa4776 + 70b5826 commit d69db28
Show file tree
Hide file tree
Showing 11 changed files with 20 additions and 18 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

- Raise error when trying to convert an `Interpolant` with the "pchip" interpolator to CasADI ([#1791](https://github.com/pybamm-team/PyBaMM/pull/1791))
- Raise error if `Concatenation` is used directly with `Variable` objects (`concatenation` should be used instead) ([#1789](https://github.com/pybamm-team/PyBaMM/pull/1789))
- Made jax and the PyBaMM JaxSolver optional ([#1767](https://github.com/pybamm-team/PyBaMM/pull/1767))
- Made jax, jaxlib and the PyBaMM JaxSolver optional ([#1767](https://github.com/pybamm-team/PyBaMM/pull/1767))

# [v21.10](https://github.com/pybamm-team/PyBaMM/tree/v21.9) - 2021-10-31

Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class JaxCooMatrix:
def __init__(self, row, col, data, shape):
if not pybamm.have_jax():
raise ModuleNotFoundError(
"Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
"Jax or jaxlib is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
)

self.row = jax.numpy.array(row)
Expand Down Expand Up @@ -547,7 +547,7 @@ class EvaluatorJax:
def __init__(self, symbol):
if not pybamm.have_jax():
raise ModuleNotFoundError(
"Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
"Jax or jaxlib is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
)

constants, python_str = pybamm.to_python(symbol, debug=False, output_jax=True)
Expand Down
2 changes: 1 addition & 1 deletion pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,7 @@ def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None):
"""
if not pybamm.have_jax():
raise ModuleNotFoundError(
"Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
"Jax or jaxlib is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
)

def _check_arg(arg):
Expand Down
2 changes: 1 addition & 1 deletion pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
):
if not pybamm.have_jax():
raise ModuleNotFoundError(
"Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
"Jax or jaxlib is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
)

# note: bdf solver itself calculates consistent initial conditions so can set
Expand Down
6 changes: 4 additions & 2 deletions pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,10 @@ def get_parameters_filepath(path):


def have_jax():
"""Check if jax is installed"""
return importlib.util.find_spec("jax") is not None
"""Check if jax and jaxlib are installed"""
return (importlib.util.find_spec("jax") is not None) and (
importlib.util.find_spec("jaxlib") is not None
)


def install_jax():
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def test_solver_citations(self):
pybamm.IDAKLUSolver()
self.assertIn("Hindmarsh2005", citations._papers_to_cite)

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_jax_citations(self):
citations = pybamm.citations
citations._reset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def test_evaluator_python(self):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_find_symbols_jax(self):
# test sparse conversion
constant_symbols = OrderedDict()
Expand All @@ -472,7 +472,7 @@ def test_find_symbols_jax(self):
list(constant_symbols.values())[0].toarray(), A.entries.toarray()
)

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_evaluator_jax(self):
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))
Expand Down Expand Up @@ -634,7 +634,7 @@ def test_evaluator_jax(self):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_evaluator_jax_jacobian(self):
a = pybamm.StateVector(slice(0, 1))
y_tests = [np.array([[2.0]]), np.array([[1.0]]), np.array([1.0])]
Expand All @@ -649,23 +649,23 @@ def test_evaluator_jax_jacobian(self):
result_true = evaluator_jac.evaluate(t=None, y=y)
np.testing.assert_allclose(result_test, result_true)

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_evaluator_jax_debug(self):
a = pybamm.StateVector(slice(0, 1))
expr = a ** 2
y_test = np.array([[2.0], [3.0]])
evaluator = pybamm.EvaluatorJax(expr)
evaluator.debug(y=y_test)

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_evaluator_jax_inputs(self):
a = pybamm.InputParameter("a")
expr = a ** 2
evaluator = pybamm.EvaluatorJax(expr)
result = evaluator.evaluate(inputs={"a": 2})
self.assertEqual(result, 4)

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_jax_coo_matrix(self):
A = pybamm.JaxCooMatrix([0, 1], [0, 1], [1.0, 2.0], (2, 2))
Adense = jax.numpy.array([[1.0, 0], [0, 2.0]])
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def test_sensitivities(self):
def exact_diff_a(y, a, b):
return np.array([[y[0] ** 2 + 2 * a], [y[0]]])

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def exact_diff_b(y, a, b):
return np.array([[y[0]], [0]])

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_ida_roberts_klu(self):
true_solution = 0.1 * solution.t
np.testing.assert_array_almost_equal(solution.y[0, :], true_solution)

@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
def test_ida_roberts_klu_sensitivities(self):
# this test implements a python version of the ida Roberts
# example provided in sundials
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jax


@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
class TestJaxBDFSolver(unittest.TestCase):
def test_solver(self):
# Create model
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jax


@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
class TestJaxSolver(unittest.TestCase):
def test_model_solver(self):
# Create model
Expand Down

0 comments on commit d69db28

Please sign in to comment.