Skip to content

Commit

Permalink
Merge pull request #3684 from cringeyburger/issue-3683-update-jax-imp…
Browse files Browse the repository at this point in the history
…orts

Update JAX Imports
  • Loading branch information
agriyakhetarpal committed Jan 3, 2024
2 parents 3bd05c4 + ca74358 commit b2e852e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
3 changes: 1 addition & 2 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@

if pybamm.have_jax():
import jax
from jax.config import config

platform = jax.lib.xla_bridge.get_backend().platform.casefold()
if platform != "metal":
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)


class JaxCooMatrix:
Expand Down
3 changes: 1 addition & 2 deletions pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
from jax import core, dtypes
from jax.extend import linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.config import config
from jax.flatten_util import ravel_pytree
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
from jax.util import cache, safe_map, split_list

platform = jax.lib.xla_bridge.get_backend().platform.casefold()
if platform != "metal":
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

MAX_ORDER = 5
NEWTON_MAXITER = 4
Expand Down

0 comments on commit b2e852e

Please sign in to comment.