Add support for MLIR-based expression evaluation #4199
Thanks @jsbrittain, this looks excellent. I've made a few suggestions below, see what you think. Also, is `jax_evaluator="iree"` needed for the options? Is there a case where you want to convert to jax but not use the jax_evaluator?
@@ -310,7 +310,7 @@ def find_symbols(

     elif isinstance(symbol, pybamm.SparseStack):
         if len(children_vars) == 1:
-            symbol_str = children_vars[0]
+            symbol_str = children_vars[0]  # pragma: no cover
Again, a test would be good for this. If you don't want to do it in this PR (understandable) then an issue as a reminder would be good too.
more below which I won't mention. Looks like we were not covering a lot of it :(
Quite a few of these were introduced as 'indirect' changes during one of the commits. I will go back and see whether these can be relaxed again, or if new tests are needed...
I've gone through and removed all of the 'no cover' statements that I inserted due to indirect changes in previous revisions. Codecov seems to have settled itself down now and is passing everything again. So we should be good.
pybamm/solvers/c_solvers/idaklu/Expressions/Base/Expression.hpp
}

// Parse module name
std::regex module_name_regex("module @([^\\s]+)");  // Match until first whitespace
Does iree have a parser class that you can use to do this? (e.g. https://mlir.llvm.org/doxygen/classmlir_1_1detail_1_1Parser.html). This seems quite manual.
This is proving surprisingly difficult to implement. I can see two options:
- working with IREE: the routines provided only seem to permit access to low-level memory details, using representations below the `tensor` syntax, which makes extraction difficult.
- working with llvm-mlir: we have to compile in the llvm/mlir libraries alongside iree, then set up new contexts and register dialects in order to parse the mlir string. This approach has potential but is so far proving non-trivial to get working in harmony with the iree solution we already have in place.

As this is turning into a bit of a blocker, I have decided to refactor the existing module parsing code out into its own ModuleParser class that can be modified or wholesale replaced as needed at a later date.
fair enough. Happy to go with the separate parsing class for now.
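For readers following the parsing discussion, here is a minimal Python sketch of the same regex-based extraction that the C++ `std::regex` line performs: capture everything after `module @` up to the first whitespace. The helper name `parse_module_name` and the sample MLIR text are hypothetical and are not taken from the PR's `ModuleParser` class.

```python
import re

# Same pattern as the C++ std::regex in the reviewed diff:
# capture the characters after "module @" up to the first whitespace.
MODULE_NAME_RE = re.compile(r"module @([^\s]+)")


def parse_module_name(mlir_text):
    """Return the MLIR module name, or None if no module header is found."""
    match = MODULE_NAME_RE.search(mlir_text)
    return match.group(1) if match else None


# Hypothetical MLIR text, for illustration only.
sample = (
    "module @jit_fcn attributes {mhlo.num_partitions = 1 : i32} {\n"
    "  func.func public @main(%arg0: tensor<3xf32>) -> tensor<3xf32> {\n"
    "    return %arg0 : tensor<3xf32>\n"
    "  }\n"
    "}\n"
)

print(parse_module_name(sample))  # jit_fcn
```

A pattern match like this is brittle against changes in how the MLIR text is printed, which is the concern raised in the review. Keeping it behind a single function or class limits how much code has to change if a proper parser replaces it later.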
pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp
@martinjrobins yes, there is actually an existing python-idaklu interface that will run if we don't redirect using the (new) …
that reminds me, we should get rid of python-idaklu, I don't think it serves a useful purpose anymore. I'll add an issue …
Hi again, @jsbrittain. #4205 is trying to migrate PyBaMM's package structure to an …
Thanks @jsbrittain for the changes, all looks good to me now. I'm happy to merge, but just need to check with #4249 how we're going to put these two PRs in.
* Fix CodeCov GHA workflow failure (pybamm-team#3845)
  This amends the tag for the CodeCov GitHub Action from `4.1.0` to `v4.1.0`. This was a Dependabot error
* Begin refactor IDAKLU solver to support generalisable expressions [i.e. support more than just casadi functions]
* Continue refactor; compiles but with dlopen error on load; committing to test on another machine
* Refactor: Introduce Expression and ExpressionSet classes
* Restructure Expressions with generics folder and implementation-specific subfolders
* Separate Expression classes
* Template Expression class
* WIP: Subclass expressions (remove templates in order to generalise execution framework)
* Subclass expressions and remove unnecessary template arguments
* Isolate casadi functionality to Expression subclasses
* Add IREE expressions class
* Remove breakpoints
* Map input arguments to (reduced) call signature in MLIR
* Add support for inputs
* Support sensitivities
* Pre-commit
* Support output_variables
* Fix designator order error on linux; remove reference to ninja
* OS-invariant library loading
* Convert some pointer arrays to vectors; fixes sporadic crashes
* Fix bad memory allocations
* Tidy-up code
* Tidy up code
* Resolve jax/iree version numbers
* Conditional compilation of iree code
* Fix compiler variables
* Update noxfile sources list
* Pre-commit
* Make IREE tests conditional on IREE install
* Enable IREE in CI unit/integration testing
* Make demotion optional in idaklu_solver.py (still unsupported by IDAKLU)
* style: pre-commit fixes
* Fix expression tree test given change to bracketed expressions
* Codacy corrections and suppressions
* Enable IREE in unit testing
* style: pre-commit fixes
* Enable IREE in coverage testing
* Restrict IREE install to supported MacOS/Python versions
* Restrict IREE supported macOS installs
* Additional tests for IREE (demotion, output_variables, sensititivies)
* style: pre-commit fixes
* Additional tests (improve test coverage)
* Fix IREE unit test
* Fix sensitivities-sparsity bug; improve tests
* Fix tests
* Improve coverage
* Improve coverage (indirect)
* Suppress IREE warning on reload
* style: pre-commit fixes
* Fix codacy warning
* Fix noxfile docstring
* Update noxfile.py
  Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>
* Update pyproject.toml
  Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>
* style: pre-commit fixes
* Update pybamm/solvers/c_solvers/idaklu.cpp
  Co-authored-by: Martin Robinson <martinjrobins@gmail.com>
* Update pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp
  Co-authored-by: Martin Robinson <martinjrobins@gmail.com>
* Try removing no-cover from expression_tree/function.py
* Remove C from class name
* Clarify Base Expression docstrings
* Remove build-time iree.compiler search in CMakeLists.txt
* Add install note to iree-compiler in pyproject.toml
* Add IREE dependencies to docs
* style: pre-commit fixes
* Refactor MLIR parsing into ModuleParser class
* style: pre-commit fixes
* Add codacy hints
* Codacy fix
* Remove no-cover statements
* style: pre-commit fixes
* Coverage fix

---------

Co-authored-by: Ferran Brosa Planella <Ferran.Brosa-Planella@warwick.ac.uk>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>
Co-authored-by: Martin Robinson <martinjrobins@gmail.com>
Description
Add a new expression evaluation backend to the IDAKLU solver. MLIR expression evaluation is now supported by lowering PyBaMM's Jax-based expressions into MLIR, which are then compiled and executed as part of the IDAKLU solver using IREE.
To enable the IREE/MLIR backend, set the (new) `PYBAMM_IDAKLU_EXPR_IREE` compiler flag `ON` via an environment variable and install PyBaMM using the developer method (by default `PYBAMM_IDAKLU_EXPR_IREE` is turned `OFF`):

Expression evaluation in IDAKLU is enabled by constructing the model using Jax expressions (`model.convert_to_format="jax"`) and setting the solver backend (`jax_evaluator="iree"`). Example:

Note that IREE currently only supports single-precision floating-point operations, which requires the model to be demoted from 64-bit to 32-bit precision before the solver can run. This is handled within the solver logic, but the operation is performed in-place on the PyBaMM battery model (we display a warning when run). Operating at lower precision requires tolerances to be relaxed for convergence on larger [e.g. DFN] models, and leads to memory transfers and type casting in the solver which are currently causing slow-downs (at least until 64-bit computation is natively supported).
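The two settings named in the description can be sketched as follows. This is not the PR's original example, which did not survive on this page. It is a minimal sketch assuming a PyBaMM build with `PYBAMM_IDAKLU_EXPR_IREE` enabled and the usual manual discretisation workflow. The names `MODEL_FORMAT`, `SOLVER_OPTIONS` and `solve_spm_with_iree` are hypothetical.

```python
# Settings named in the PR description; the constant names are hypothetical.
MODEL_FORMAT = "jax"
SOLVER_OPTIONS = {"jax_evaluator": "iree"}


def solve_spm_with_iree(t_eval):
    """Solve the SPM using IDAKLU with IREE-evaluated Jax expressions."""
    # Imported lazily: this needs PyBaMM built with PYBAMM_IDAKLU_EXPR_IREE=ON.
    import pybamm

    model = pybamm.lithium_ion.SPM()
    model.convert_to_format = MODEL_FORMAT  # construct Jax expressions

    # Standard PyBaMM parameterisation and discretisation steps.
    geometry = model.default_geometry
    param = model.default_parameter_values
    param.process_model(model)
    param.process_geometry(geometry)
    mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts)
    disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
    disc.process_model(model)

    # Select the IREE evaluator for the Jax expressions.
    solver = pybamm.IDAKLUSolver(options=SOLVER_OPTIONS)
    return solver.solve(model, t_eval)
```

Calling `solve_spm_with_iree` with an array of evaluation times (for example one hour sampled at 100 points) would return a PyBaMM solution object. Other solver arguments, such as tolerances, may also need adjusting because of the 32-bit demotion described above.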
Comparative performance on the above SPM problem on an Apple M2 MacBook Pro (with `events=[]` to allow comparison to the JaxSolver):

Substituting a DFN model (and reducing `atol = 1e-1`) the times become:

There is a noticeable performance deficit for the IDAKLU-MLIR solver compared to Casadi, due to 1) initial compilation of MLIR to bytecode, 2) demotion strategies, and 3) memory transfers and type casting in the solver. We anticipate improvements in the second and third points with native 64-bit IREE support. Because our IREE approach compiles the model expressions (not the solver), compilation times quickly out-perform the JaxSolver with increasing model complexity / time steps, while also taking full advantage of the capabilities already provided by the IDAKLU solver, such as events. The IREE/MLIR approach offers a pathway to compiling expressions across a wide variety of backends, including Metal and CUDA, although additional code adjustment (principally host/device transfers) will be required before those can be supported.
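One way to see why demotion forces looser tolerances is to round a 64-bit value to 32-bit and look at what is lost. The sketch below uses only the standard library and is independent of PyBaMM and IREE. The helper name `demote_to_float32` is hypothetical and only mimics the effect of a 64-bit to 32-bit cast on a single value.

```python
import struct


def demote_to_float32(x):
    """Round a Python float (64-bit) to the nearest 32-bit float value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


# A perturbation of 1e-9 around 1.0 is below float32 resolution
# (adjacent float32 values near 1.0 are about 1.2e-7 apart).
x = 1.0 + 1e-9
print(demote_to_float32(x) == 1.0)  # True: the 1e-9 perturbation is lost

# Relative rounding error from demoting 0.1 stays within 2**-24.
rel_err = abs(demote_to_float32(0.1) - 0.1) / 0.1
print(rel_err < 2 ** -24)  # True
```

Changes smaller than roughly one part in ten million cannot be represented after demotion, so a solver tolerance tighter than that may not be reachable on 32-bit expressions. Each conversion between the two widths is also extra work, which is in line with the casting overhead noted above.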
Resolves #3826
Type of change
Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.
Key checklist:
- `$ pre-commit run` (or `$ nox -s pre-commit`) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
- `$ python run-tests.py --all` (or `$ nox -s tests`)
- `$ python run-tests.py --doctest` (or `$ nox -s doctests`)

You can run integration tests, unit tests, and doctests together at once, using `$ python run-tests.py --quick` (or `$ nox -s quick`).

Further checks: