Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename have_x to has_x to improve how logic reads #4398

Merged
merged 4 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

## Breaking changes

- Replaced `have_jax` with `has_jax`, `have_idaklu` with `has_idaklu`, and
`have_iree` with `has_iree` ([#4398](https://github.com/pybamm-team/PyBaMM/pull/4398))
- Remove deprecated function `pybamm_install_jax` ([#4362](https://github.com/pybamm-team/PyBaMM/pull/4362))
- Removed legacy python-IDAKLU solver. ([#4326](https://github.com/pybamm-team/PyBaMM/pull/4326))

Expand Down
2 changes: 1 addition & 1 deletion docs/source/api/util.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ Utility functions

.. autofunction:: pybamm.load

.. autofunction:: pybamm.have_jax
.. autofunction:: pybamm.has_jax

.. autofunction:: pybamm.is_jax_compatible
2 changes: 1 addition & 1 deletion examples/scripts/compare_dae_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
casadi_sol = pybamm.CasadiSolver(atol=1e-8, rtol=1e-8).solve(model, t_eval)
solutions = [casadi_sol]

if pybamm.have_idaklu():
if pybamm.has_idaklu():
klu_sol = pybamm.IDAKLUSolver(atol=1e-8, rtol=1e-8).solve(model, t_eval)
solutions.append(klu_sol)
else:
Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from .util import (
get_parameters_filepath,
have_jax,
has_jax,
import_optional_dependency,
is_jax_compatible,
get_git_commit_info,
Expand Down Expand Up @@ -170,7 +170,7 @@
from .solvers.jax_bdf_solver import jax_bdf_integrate

from .solvers.idaklu_jax import IDAKLUJax
from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu, have_iree
from .solvers.idaklu_solver import IDAKLUSolver, has_idaklu, has_iree

# Experiments
from .experiment.experiment import Experiment
Expand Down
6 changes: 3 additions & 3 deletions src/pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import pybamm

if pybamm.have_jax():
if pybamm.has_jax():
import jax

platform = jax.lib.xla_bridge.get_backend().platform.casefold()
Expand Down Expand Up @@ -43,7 +43,7 @@ class JaxCooMatrix:
def __init__(
self, row: ArrayLike, col: ArrayLike, data: ArrayLike, shape: tuple[int, int]
):
if not pybamm.have_jax(): # pragma: no cover
if not pybamm.has_jax(): # pragma: no cover
raise ModuleNotFoundError(
"Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver"
)
Expand Down Expand Up @@ -527,7 +527,7 @@ class EvaluatorJax:
"""

def __init__(self, symbol: pybamm.Symbol):
if not pybamm.have_jax(): # pragma: no cover
if not pybamm.has_jax(): # pragma: no cover
raise ModuleNotFoundError(
"Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver"
)
Expand Down
6 changes: 3 additions & 3 deletions src/pybamm/solvers/idaklu_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
except ImportError: # pragma: no cover
idaklu_spec = None

if pybamm.have_jax():
if pybamm.has_jax():
import jax
from jax import lax
from jax import numpy as jnp
Expand Down Expand Up @@ -57,11 +57,11 @@ def __init__(
calculate_sensitivities=True,
t_interp=None,
):
if not pybamm.have_jax():
if not pybamm.has_jax():
raise ModuleNotFoundError(
"Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver"
) # pragma: no cover
if not pybamm.have_idaklu():
if not pybamm.has_idaklu():
raise ModuleNotFoundError(
"IDAKLU is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html"
) # pragma: no cover
Expand Down
6 changes: 3 additions & 3 deletions src/pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import warnings


if pybamm.have_jax():
if pybamm.has_jax():
import jax
from jax import numpy as jnp

Expand All @@ -33,11 +33,11 @@
idaklu_spec = None


def have_idaklu():
def has_idaklu():
return idaklu_spec is not None


def have_iree():
def has_iree():
try:
import iree.compiler # noqa: F401

Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pybamm

if pybamm.have_jax():
if pybamm.has_jax():
import jax
import jax.numpy as jnp
from jax import core, dtypes
Expand Down Expand Up @@ -1007,7 +1007,7 @@ def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None):
calculated state vector at each of the m time points

"""
if not pybamm.have_jax():
if not pybamm.has_jax():
raise ModuleNotFoundError(
"Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver"
)
Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pybamm

if pybamm.have_jax():
if pybamm.has_jax():
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(
extrap_tol=None,
extra_options=None,
):
if not pybamm.have_jax():
if not pybamm.has_jax():
raise ModuleNotFoundError(
"Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver"
)
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def get_parameters_filepath(path):
return os.path.join(pybamm.__path__[0], path)


def have_jax():
def has_jax():
"""
Check if jax and jaxlib are installed with the correct versions

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_sensitivities(self):
param = pybamm.ParameterValues("Ecker2015")
rtol = 1e-6
atol = 1e-6
if pybamm.have_idaklu():
if pybamm.has_idaklu():
solver = pybamm.IDAKLUSolver(rtol=rtol, atol=atol)
else:
solver = pybamm.CasadiSolver(rtol=rtol, atol=atol)
Expand Down Expand Up @@ -53,7 +53,7 @@ def test_optimisations(self):
to_python = optimtest.evaluate_model(to_python=True)
np.testing.assert_array_almost_equal(original, to_python)

if pybamm.have_jax():
if pybamm.has_jax():
to_jax = optimtest.evaluate_model(to_jax=True)
np.testing.assert_array_almost_equal(original, to_jax)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_optimisations(self):
to_python = optimtest.evaluate_model(to_python=True)
np.testing.assert_array_almost_equal(original, to_python)

if pybamm.have_jax():
if pybamm.has_jax():
to_jax = optimtest.evaluate_model(to_jax=True)
np.testing.assert_array_almost_equal(original, to_jax)

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_solvers/test_idaklu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np


@pytest.mark.skipif(not pybamm.have_idaklu(), reason="idaklu solver is not installed")
@pytest.mark.skipif(not pybamm.has_idaklu(), reason="idaklu solver is not installed")
class TestIDAKLUSolver:
def test_on_spme(self):
model = pybamm.lithium_ion.SPMe()
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,14 +423,14 @@ def test_solver_citations(self):
assert "Virtanen2020" in citations._papers_to_cite
assert "Virtanen2020" in citations._citation_tags.keys()

if pybamm.have_idaklu():
if pybamm.has_idaklu():
citations._reset()
assert "Hindmarsh2005" not in citations._papers_to_cite
pybamm.IDAKLUSolver()
assert "Hindmarsh2005" in citations._papers_to_cite
assert "Hindmarsh2005" in citations._citation_tags.keys()

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_jax_citations(self):
citations = pybamm.citations
citations._reset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def test_run_experiment_multiple_times(self):
sol1["Voltage [V]"].data, sol2["Voltage [V]"].data
)

@unittest.skipIf(not pybamm.have_idaklu(), "idaklu solver is not installed")
@unittest.skipIf(not pybamm.has_idaklu(), "idaklu solver is not installed")
def test_run_experiment_cccv_solvers(self):
experiment_2step = pybamm.Experiment(
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections import OrderedDict
import re

if pybamm.have_jax():
if pybamm.has_jax():
import jax
from tests import (
function_test,
Expand Down Expand Up @@ -446,7 +446,7 @@ def test_evaluator_python(self):
result = evaluator(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_find_symbols_jax(self):
# test sparse conversion
constant_symbols = OrderedDict()
Expand All @@ -459,7 +459,7 @@ def test_find_symbols_jax(self):
next(iter(constant_symbols.values())).toarray(), A.entries.toarray()
)

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_evaluator_jax(self):
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))
Expand Down Expand Up @@ -621,7 +621,7 @@ def test_evaluator_jax(self):
result = evaluator(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_evaluator_jax_jacobian(self):
a = pybamm.StateVector(slice(0, 1))
y_tests = [np.array([[2.0]]), np.array([[1.0]]), np.array([1.0])]
Expand All @@ -636,7 +636,7 @@ def test_evaluator_jax_jacobian(self):
result_true = evaluator_jac(t=None, y=y)
np.testing.assert_allclose(result_test, result_true)

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_evaluator_jax_jvp(self):
a = pybamm.StateVector(slice(0, 1))
y_tests = [np.array([[2.0]]), np.array([[1.0]]), np.array([1.0])]
Expand All @@ -656,23 +656,23 @@ def test_evaluator_jax_jvp(self):
np.testing.assert_allclose(result_test, result_true)
np.testing.assert_allclose(result_test_times_v, result_true_times_v)

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_evaluator_jax_debug(self):
a = pybamm.StateVector(slice(0, 1))
expr = a**2
y_test = np.array([2.0, 3.0])
evaluator = pybamm.EvaluatorJax(expr)
evaluator.debug(y=y_test)

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_evaluator_jax_inputs(self):
a = pybamm.InputParameter("a")
expr = a**2
evaluator = pybamm.EvaluatorJax(expr)
result = evaluator(inputs={"a": 2})
assert result == 4

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_evaluator_jax_demotion(self):
for demote in [True, False]:
pybamm.demote_expressions_to_32bit = demote # global flag
Expand Down Expand Up @@ -734,7 +734,7 @@ def test_evaluator_jax_demotion(self):
assert all(str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.col)
pybamm.demote_expressions_to_32bit = False

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_jax_coo_matrix(self):
A = pybamm.JaxCooMatrix([0, 1], [0, 1], [1.0, 2.0], (2, 2))
Adense = jax.numpy.array([[1.0, 0], [0, 2.0]])
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,12 @@ def test_multiprocess_context(self):
assert solver.get_platform_context("Linux") == "fork"
assert solver.get_platform_context("Darwin") == "fork"

@unittest.skipIf(not pybamm.have_idaklu(), "idaklu solver is not installed")
@unittest.skipIf(not pybamm.has_idaklu(), "idaklu solver is not installed")
def test_sensitivities(self):
def exact_diff_a(y, a, b):
return np.array([[y[0] ** 2 + 2 * a], [y[0]]])

@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
@unittest.skipIf(not pybamm.has_jax(), "jax or jaxlib is not installed")
def exact_diff_b(y, a, b):
return np.array([[y[0]], [0]])

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_solvers/test_idaklu_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import unittest

testcase = []
if pybamm.have_idaklu() and pybamm.have_jax():
if pybamm.has_idaklu() and pybamm.has_jax():
from jax.tree_util import tree_flatten
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -87,7 +87,7 @@ def no_jit(f):

# Check the interface throws an appropriate error if either IDAKLU or JAX not available
@unittest.skipIf(
pybamm.have_idaklu() and pybamm.have_jax(),
pybamm.has_idaklu() and pybamm.has_jax(),
"Both IDAKLU and JAX are available",
)
class TestIDAKLUJax_NoJax(unittest.TestCase):
Expand All @@ -97,7 +97,7 @@ def test_instantiate_fails(self):


@unittest.skipIf(
not pybamm.have_idaklu() or not pybamm.have_jax(),
not pybamm.has_idaklu() or not pybamm.has_jax(),
"IDAKLU Solver and/or JAX are not available",
)
class TestIDAKLUJax(unittest.TestCase):
Expand Down
Loading
Loading