From 9a25d56cc849b9db653f395b8a63c4f7aca947e0 Mon Sep 17 00:00:00 2001 From: Saransh Date: Sat, 13 Nov 2021 17:47:37 +0530 Subject: [PATCH 1/2] Add a condition for jaxlib imports --- pybamm/util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pybamm/util.py b/pybamm/util.py index 0bb8b640e8..90a06ef702 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -343,8 +343,10 @@ def get_parameters_filepath(path): def have_jax(): - """Check if jax is installed""" - return importlib.util.find_spec("jax") is not None + """Check if jax and jaxlib are installed""" + return (importlib.util.find_spec("jax") is not None) and ( + importlib.util.find_spec("jaxlib") is not None + ) def install_jax(): From 70b582656418f1ca429182d057cbdacfe129a96d Mon Sep 17 00:00:00 2001 From: Saransh Date: Sat, 13 Nov 2021 18:13:05 +0530 Subject: [PATCH 2/2] Update CHANGELOG + exceptions --- CHANGELOG.md | 2 +- pybamm/expression_tree/operations/evaluate_python.py | 4 ++-- pybamm/solvers/jax_bdf_solver.py | 2 +- pybamm/solvers/jax_solver.py | 2 +- tests/unit/test_citations.py | 2 +- .../test_operations/test_evaluate_python.py | 12 ++++++------ tests/unit/test_solvers/test_base_solver.py | 2 +- tests/unit/test_solvers/test_idaklu_solver.py | 2 +- tests/unit/test_solvers/test_jax_bdf_solver.py | 2 +- tests/unit/test_solvers/test_jax_solver.py | 2 +- 10 files changed, 16 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1904819f23..302d810e41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ - Raise error when trying to convert an `Interpolant` with the "pchip" interpolator to CasADI ([#1791](https://github.com/pybamm-team/PyBaMM/pull/1791)) - Raise error if `Concatenation` is used directly with `Variable` objects (`concatenation` should be used instead) ([#1789](https://github.com/pybamm-team/PyBaMM/pull/1789)) -- Made jax and the PyBaMM JaxSolver optional ([#1767](https://github.com/pybamm-team/PyBaMM/pull/1767)) +- Made jax, jaxlib and the PyBaMM JaxSolver optional ([#1767](https://github.com/pybamm-team/PyBaMM/pull/1767)) # [v21.10](https://github.com/pybamm-team/PyBaMM/tree/v21.9) - 2021-10-31 diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index 15c919858f..6949ccb3cd 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -40,7 +40,7 @@ class JaxCooMatrix: def __init__(self, row, col, data, shape): if not pybamm.have_jax(): raise ModuleNotFoundError( - "Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501 + "Jax or jaxlib is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501 ) self.row = jax.numpy.array(row) @@ -547,7 +547,7 @@ class EvaluatorJax: def __init__(self, symbol): if not pybamm.have_jax(): raise ModuleNotFoundError( - "Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501 + "Jax or jaxlib is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501 ) constants, python_str = pybamm.to_python(symbol, debug=False, output_jax=True) diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py index 4d4c40a76f..41ef9e614a 100644 --- a/pybamm/solvers/jax_bdf_solver.py +++ b/pybamm/solvers/jax_bdf_solver.py @@ -1039,7 +1039,7 @@ def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None): """ if not pybamm.have_jax(): raise ModuleNotFoundError( - "Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501 + "Jax or jaxlib is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501 ) def _check_arg(arg): diff --git a/pybamm/solvers/jax_solver.py b/pybamm/solvers/jax_solver.py index a0d3008207..1d117f0f95 100644 --- a/pybamm/solvers/jax_solver.py +++ b/pybamm/solvers/jax_solver.py @@ -60,7 +60,7 @@ def __init__( ): if not pybamm.have_jax(): raise ModuleNotFoundError( - "Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501 + "Jax or jaxlib is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501 ) # note: bdf solver itself calculates consistent initial conditions so can set diff --git a/tests/unit/test_citations.py b/tests/unit/test_citations.py index f443b12bd6..5e3a3ed7bf 100644 --- a/tests/unit/test_citations.py +++ b/tests/unit/test_citations.py @@ -254,7 +254,7 @@ def test_solver_citations(self): pybamm.IDAKLUSolver() self.assertIn("Hindmarsh2005", citations._papers_to_cite) - @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") + @unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") def test_jax_citations(self): citations = pybamm.citations citations._reset() diff --git a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py index ee7ef06bfb..67d624cb1d 100644 --- a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py +++ b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py @@ -459,7 +459,7 @@ def test_evaluator_python(self): result = evaluator.evaluate(t=t, y=y) np.testing.assert_allclose(result, expr.evaluate(t=t, y=y)) - @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") + @unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") def test_find_symbols_jax(self): # test sparse conversion constant_symbols = OrderedDict() @@ -472,7 +472,7 @@ def test_find_symbols_jax(self): list(constant_symbols.values())[0].toarray(), A.entries.toarray() ) - @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") + @unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") def test_evaluator_jax(self): a = pybamm.StateVector(slice(0, 1)) b = pybamm.StateVector(slice(1, 2)) @@ -634,7 +634,7 @@ def test_evaluator_jax(self): result = evaluator.evaluate(t=t, y=y) np.testing.assert_allclose(result, expr.evaluate(t=t, y=y)) - @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") + @unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") def test_evaluator_jax_jacobian(self): a = pybamm.StateVector(slice(0, 1)) y_tests = [np.array([[2.0]]), np.array([[1.0]]), np.array([1.0])] @@ -649,7 +649,7 @@ def test_evaluator_jax_jacobian(self): result_true = evaluator_jac.evaluate(t=None, y=y) np.testing.assert_allclose(result_test, result_true) - @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") + @unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") def test_evaluator_jax_debug(self): a = pybamm.StateVector(slice(0, 1)) expr = a ** 2 @@ -657,7 +657,7 @@ def test_evaluator_jax_debug(self): evaluator = pybamm.EvaluatorJax(expr) evaluator.debug(y=y_test) - @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") + @unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") def test_evaluator_jax_inputs(self): a = pybamm.InputParameter("a") expr = a ** 2 @@ -665,7 +665,7 @@ def test_evaluator_jax_inputs(self): result = evaluator.evaluate(inputs={"a": 2}) self.assertEqual(result, 4) - @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") + @unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") def test_jax_coo_matrix(self): A = pybamm.JaxCooMatrix([0, 1], [0, 1], [1.0, 2.0], (2, 2)) Adense = jax.numpy.array([[1.0, 0], [0, 2.0]]) diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index bd29be3f81..1bee2b4abc 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -345,7 +345,7 @@ def test_sensitivities(self): def exact_diff_a(y, a, b): return np.array([[y[0] ** 2 + 2 * a], [y[0]]]) - @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") + @unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") def exact_diff_b(y, a, b): return np.array([[y[0]], [0]]) diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index d17d07c63d..9b04b2366c 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -46,7 +46,7 @@ def test_ida_roberts_klu(self): true_solution = 0.1 * solution.t np.testing.assert_array_almost_equal(solution.y[0, :], true_solution) - @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") + @unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") def test_ida_roberts_klu_sensitivities(self): # this test implements a python version of the ida Roberts # example provided in sundials diff --git a/tests/unit/test_solvers/test_jax_bdf_solver.py b/tests/unit/test_solvers/test_jax_bdf_solver.py index b6e7ab92f1..8dcdb04a61 100644 --- a/tests/unit/test_solvers/test_jax_bdf_solver.py +++ b/tests/unit/test_solvers/test_jax_bdf_solver.py @@ -9,7 +9,7 @@ import jax -@unittest.skipIf(not pybamm.have_jax(), "jax is not installed") +@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") class TestJaxBDFSolver(unittest.TestCase): def test_solver(self): # Create model diff --git a/tests/unit/test_solvers/test_jax_solver.py b/tests/unit/test_solvers/test_jax_solver.py index e9956e4295..0692f5a93e 100644 --- a/tests/unit/test_solvers/test_jax_solver.py +++ b/tests/unit/test_solvers/test_jax_solver.py @@ -9,7 +9,7 @@ import jax -@unittest.skipIf(not pybamm.have_jax(), "jax is not installed") +@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") class TestJaxSolver(unittest.TestCase): def test_model_solver(self): # Create model