Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for MLIR-based expression evaluation #4199

Merged
merged 79 commits into from
Jul 15, 2024

Conversation

jsbrittain
Copy link
Contributor

@jsbrittain jsbrittain commented Jun 19, 2024

Description

Add a new expression evaluation backend to the IDAKLU solver. MLIR expression evaluation is now supported by lowering PyBaMM's Jax-based expressions into MLIR, which are then compiled and executed as part of the IDAKLU solver using IREE.

To enable the IREE/MLIR backend, set the (new) PYBAMM_IDAKLU_EXPR_IREE compiler flag ON via an environment variable and install PyBaMM using the developer method (by default PYBAMM_IDAKLU_EXPR_IREE is turned OFF):

export PYBAMM_IDAKLU_EXPR_IREE=ON
nox -e pybamm-requires && nox -e dev

Expression evaluation in IDAKLU is enabled by constructing the model using Jax expressions (model.convert_to_format="jax") and setting the solver backend (jax_evaluator="iree"). Example:

import pybamm
import numpy as np

model = pybamm.lithium_ion.SPM()
model.convert_to_format = "jax"
geometry = model.default_geometry
param = model.default_parameter_values
param.process_model(model)
param.process_geometry(model.default_geometry)
mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

solver = pybamm.IDAKLUSolver(
	root_method="hybr",  # change from default ("casadi")
	options={"jax_evaluator": "iree"}
)
solution = solver.solve(model, np.linspace(0, 3600, 2500))

print(solution["Voltage [V]"].entries[:100])

Note that IREE currently only supports single-precision floating-point operations, which requires the model to be demoted from 64-bit to 32-bit precision before the solver can run. This is handled within the solver logic, but the operation is performed in-place on the PyBaMM battery model (we display a warning when run). Operating at lower precision requires tolerances to be relaxed for convergence on larger [e.g. DFN] models, and leads to memory transfers and type casting in the solver which are currently causing slow-downs (at least until 64-bit computation is natively supported).

Comparative performance on the above SPM problem on an Apple M2 Macbook Pro (with events=[] to allow comparison to the JaxSolver):

  • IDAKLU-IREE took 1.4 secs (1.3 secs to demote and compile the expressions; <0.1 secs for each subsequent solve).
  • IDAKLU-Casadi: took 0.14 secs (<0.1 secs setup; <0.1 secs for each subsequent solve).
  • JaxSolver [BDF]: took 1.0 secs (0.9 secs compilation; 0.1 secs for each subsequent solve).

Substituting a DFN model (and reducing atol = 1e-1) the times become:

  • IDAKLU-IREE took 12.1 secs (8.7 secs to demote and compile the expressions; 3.4 secs for each subsequent solve).
  • IDAKLU-Casadi: took 0.6 secs (0.3 secs setup; 0.3 secs for each subsequent solve).
  • JaxSolver [BDF]: took 22.8 secs (7.6 secs compilation; 15.2 secs for each subsequent solve).

There is a noticeable performance deficit for the IDAKLU-MLIR solver compared to Casadi, due to 1) initial compilation of MLIR to bytecode, 2) demotion strategies, and 3) memory transfers casting between types in the solver. We anticipate improvements in the second and third points with native 64-bit IREE support, and as our IREE approach compiles on the model expressions (not the solver) compilation times quickly out-perform the JaxSolver with increasing model complexity / time steps (while also taking full advantage of the capabilities already provided by the IDAKLU solver, such as events). The IREE/MLIR approach offers a pathway to compiling expressions across a wide variety of backends, including metal and cuda, although additional code adjustment (principally host/device transfers) will be required before those can be supported.

Resolves #3826

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.

  • New feature (non-breaking change which adds functionality)
  • Optimization (back-end change that speeds up the code)
  • Bug fix (non-breaking change which fixes an issue)

Key checklist:

  • No style issues: $ pre-commit run (or $ nox -s pre-commit) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
  • All tests pass: $ python run-tests.py --all (or $ nox -s tests)
  • The documentation builds: $ python run-tests.py --doctest (or $ nox -s doctests)

You can run integration tests, unit tests, and doctests together at once, using $ python run-tests.py --quick (or $ nox -s quick).

Further checks:

  • Code is commented, particularly in hard-to-understand areas
  • Tests added that prove fix is effective or that feature works

brosaplanella and others added 30 commits March 5, 2024 12:27
This amends the tag for the CodeCov GitHub Action from `4.1.0` to `v4.1.0`. This was a Dependabot error
jsbrittain and others added 2 commits June 25, 2024 18:27
Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>
Copy link
Contributor

@martinjrobins martinjrobins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @jsbrittain this looks excellent. I've made a few suggestions below, see what you think. also, is jax_evaluator="iree" needed for the options, it there a case where you want to convert to jax but not use the jax_evaluator?

pybamm/expression_tree/functions.py Outdated Show resolved Hide resolved
pybamm/expression_tree/operations/evaluate_python.py Outdated Show resolved Hide resolved
@@ -310,7 +310,7 @@ def find_symbols(

elif isinstance(symbol, pybamm.SparseStack):
if len(children_vars) == 1:
symbol_str = children_vars[0]
symbol_str = children_vars[0] # pragma: no cover
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again a test would be good for this. if you don't want to do it in this PR (understandable) then an issue for reminder would be good too.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more below which I won't mention. Looks like we were not covering a lot of it :(

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quite a few of these were introduced as 'indirect' changes during one of the commits. I will go back and see whether these can be relaxed again, or if new tests are needed...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've gone through and removed all of the 'no cover' statements that I inserted due to indirect changes in previous revisions. Codecov seems to have settled itself down now and is passing everything again. So we should be good.

pybamm/solvers/c_solvers/idaklu.cpp Outdated Show resolved Hide resolved
}

// Parse module name
std::regex module_name_regex("module @([^\\s]+)"); // Match until first whitespace
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does iree have a parser class that can you use to do this? (e.g. https://mlir.llvm.org/doxygen/classmlir_1_1detail_1_1Parser.html). This seems quite manual

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is proving surprisingly difficult to implement. I can see two options:

  1. working with IREE: the routines provided only seem to permit access to low-level memory details, using representations below the tensor syntax, which makes extraction difficult.
  2. working with llvm-mlir: we have to compile-in the llvm/mlir libraries alongside iree; set-up new contexts and register dialects in order to parse the mlir string. This approach has potential but is so far proving non-trivial to get working in harmony with the iree solution we already have in place.

As this is turning into a bit of a blocker, I have decided to refactor the existing module parsing code out into it's own ModuleParser class that can be modified or wholesale replaced as needed at a later date.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fair enough. Happy to go with the separate parsing class for now.

pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp Outdated Show resolved Hide resolved
jsbrittain and others added 3 commits June 27, 2024 14:24
Co-authored-by: Martin Robinson <martinjrobins@gmail.com>
….cpp

Co-authored-by: Martin Robinson <martinjrobins@gmail.com>
@jsbrittain
Copy link
Contributor Author

thanks @jsbrittain this looks excellent. I've made a few suggestions below, see what you think. also, is jax_evaluator="iree" needed for the options, it there a case where you want to convert to jax but not use the jax_evaluator?

@martinjrobins yes, there is actually an existing python-idaklu interface that will run if we don't redirect using the (new) jax_evaluator option. I think it's a legacy item (idaklu/python.cpp) (it can be quite slow, even on these toy examples).

@martinjrobins
Copy link
Contributor

that reminds me, we should get rid of python-idaklu, I don't think it serves a useful purpose anymore. I'll add an issue

@agriyakhetarpal
Copy link
Member

Hi again, @jsbrittain#4205 is trying to migrate PyBaMM's package structure to an src layout from the current flat one, which will move all the files from pybamm/ into src/pybamm/. It has a lot of potential for delays because of inevitable merge conflicts across several PRs, so we plan to merge it as soon as stable v24.5 hits next week, just a heads up :)

Copy link
Contributor

@martinjrobins martinjrobins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @jsbrittain for the changes, all looks good to me now. I'm happy to merge, but just need to check with #4249 how we're going to put these two PRs in.

@martinjrobins martinjrobins merged commit 7443243 into pybamm-team:develop Jul 15, 2024
26 checks passed
js1tr3 pushed a commit to js1tr3/PyBaMM that referenced this pull request Aug 12, 2024
* Fix CodeCov GHA workflow failure (pybamm-team#3845)

This amends the tag for the CodeCov GitHub Action from `4.1.0` to `v4.1.0`. This was a Dependabot error

* Begin refactor IDAKLU solver to support generalisable expressions [i.e. support more than just casadi functions]

* Continue refactor; compiles but with dlopen error on load; committing to test on another machine

* Refactor: Introduce Expression and ExpressionSet classes

* Restructure Expressions with generics folder and implementation-specific subfolders

* Separate Expression classes

* Template Expression class

* WIP: Subclass expressions (remove templates in order to generalise execution framework)

* Subclass expressions and remove unnecessary template arguments

* Isolate casadi functionality to Expression subclasses

* Add IREE expressions class

* Remove breakpoints

* Map input arguments to (reduced) call signature in MLIR

* Add support for inputs

* Support sensitivities

* Pre-commit

* Support output_variables

* Fix designator order error on linux; remove reference to ninja

* OS-invariant library loading

* Convert some pointer arrays to vectors; fixes sporadic crashes

* Fix bad memory allocations

* Tidy-up code

* Tidy up code

* Resolve jax/iree version numbers

* Conditional compilation of iree code

* Fix compiler variables

* Update noxfile sources list

* Pre-commit

* Make IREE tests conditional on IREE install

* Enable IREE in CI unit/integration testing

* Make demotion optional in idaklu_solver.py (still unsupported by IDAKLU)

* style: pre-commit fixes

* Fix expression tree test given change to bracketed expressions

* Codacy corrections and suppressions

* Enable IREE in unit testing

* style: pre-commit fixes

* Enable IREE in coverage testing

* Restrict IREE install to supported MacOS/Python versions

* Restrict IREE supported macOS installs

* Additional tests for IREE (demotion, output_variables, sensititivies)

* style: pre-commit fixes

* Additional tests (improve test coverage)

* Fix IREE unit test

* Fix sensitivities-sparsity bug; improve tests

* Fix tests

* Improve coverage

* Improve coverage (indirect)

* Suppress IREE warning on reload

* style: pre-commit fixes

* Fix codacy warning

* Fix noxfile docstring

* Update noxfile.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* Update pyproject.toml

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

* style: pre-commit fixes

* Update pybamm/solvers/c_solvers/idaklu.cpp

Co-authored-by: Martin Robinson <martinjrobins@gmail.com>

* Update pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp

Co-authored-by: Martin Robinson <martinjrobins@gmail.com>

* Try removing no-cover from expression_tree/function.py

* Remove C from class name

* Clarify Base Expression docstrings

* Remove build-time iree.compiler search in CMakeLists.txt

* Add install note to iree-compiler in pyproject.toml

* Add IREE dependencies to docs

* style: pre-commit fixes

* Refactor MLIR parsing into ModuleParser class

* style: pre-commit fixes

* Add codacy hints

* Codacy fix

* Remove no-cover statements

* style: pre-commit fixes

* Coverage fix

---------

Co-authored-by: Ferran Brosa Planella <Ferran.Brosa-Planella@warwick.ac.uk>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>
Co-authored-by: Martin Robinson <martinjrobins@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add support for MLIR-based expression evaluation
4 participants