Add support for MLIR-based expression evaluation #3826

Closed
jsbrittain opened this issue Feb 19, 2024 · 3 comments · Fixed by #4199
Assignees
Labels
difficulty: hard (Will take several weeks), feature, priority: low (No existing plans to resolve)

Comments

@jsbrittain
Contributor

Description

Provide an MLIR-based backend for expression evaluation that can be utilised by SUNDIALS for either CPU or GPU-accelerated solves.

Motivation

At present, the PyBaMM IDAKLU solver uses SUNDIALS, which supports GPUs. However, that support cannot be harnessed because IDAKLU's function evaluation relies on CasADi, which would require user-side compilation of PyBaMM models. Of PyBaMM's solvers, only JAX (a Python library) currently translates naturally to GPU. An MLIR-based backend would allow the code required for expression evaluation to be lowered to CPU or GPU devices, so that expression trees could be passed for evaluation at runtime. Previous efforts to expose SUNDIALS GPU support in PyBaMM (e.g. #2644) stalled because the model equations could not easily be lowered to device (GPU) code.

Possible Implementation

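MLIR (Multi-Level Intermediate Representation) is a compiler infrastructure for progressively lowering code through several levels of intermediate representation down to LLVM's Intermediate Representation (IR), which can then be compiled for multiple platforms, including GPUs. This is also, broadly, the mechanism by which JAX supports multiple platforms.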

Additional context

#3766 (comment)

@jsbrittain jsbrittain added the feature, difficulty: hard (Will take several weeks) and priority: low (No existing plans to resolve) labels Feb 19, 2024
@jsbrittain jsbrittain self-assigned this Feb 19, 2024
@martinjrobins
Contributor

After discussion with @jsbrittain: we could possibly reuse the JAX backend to do JAX -> HLO -> XLA -> execute in C, using the C API for the XLA compiler (https://github.com/openxla/xla/tree/main/xla/examples/axpy). XLA uses MLIR anyway, I think!
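A minimal sketch of the Python half of that pipeline, assuming a recent JAX release: a toy residual in the F(t, y, y') = 0 form used by IDA is lowered to a StableHLO (MLIR) module and written to disk, where a C/C++ host built against XLA could pick it up. The model, the function name `residual` and the output path are hypothetical, and the C-side loading and compilation is not shown.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


def residual(t, y, ydot):
    # Hypothetical toy DAE residual in IDA's F(t, y, y') = 0 form:
    #   y0' = -y0 * y1   (differential)
    #   0   = y1 - 2*y0  (algebraic)
    return jnp.array([ydot[0] + y[0] * y[1], y[1] - 2.0 * y[0]])


# Example arguments fix the shapes and dtypes the lowering is specialised to.
t0 = jnp.float32(0.0)
y0 = jnp.array([1.0, 2.0], dtype=jnp.float32)
yp0 = jnp.zeros(2, dtype=jnp.float32)

# Trace and lower to a StableHLO module, without compiling it.
lowered = jax.jit(residual).lower(t0, y0, yp0)
module_text = lowered.as_text()

# Write the module text to a (hypothetical) file for a C/C++ XLA host to load.
out_path = os.path.join(tempfile.gettempdir(), "residual.stablehlo.mlir")
with open(out_path, "w") as fh:
    fh.write(module_text)

print(module_text.splitlines()[0])
```

The shapes are frozen at lowering time, which suits a solver where the state vector length is fixed once the model is discretised.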

@martinjrobins
Contributor

martinjrobins commented Feb 24, 2024

Problem is always sparsity :) OpenXLA is planning to add it, but it's not there yet (https://github.com/openxla/stablehlo/blob/main/rfcs/20230210-sparsity.md)

JAX has sparse matrix support, so I wonder what sort of HLO it writes out for that?
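One way to check, sketched under the assumption that `jax.experimental.sparse` (still experimental) is available: build a small BCOO matrix, lower a jitted sparse mat-vec, and inspect which dense operations appear in the emitted module. The tridiagonal matrix and the `matvec` helper are illustrative only, and the ops reported will depend on the JAX version and backend.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Small tridiagonal matrix, e.g. a 1D diffusion stencil.
dense = jnp.array(
    [[2.0, -1.0, 0.0, 0.0],
     [-1.0, 2.0, -1.0, 0.0],
     [0.0, -1.0, 2.0, -1.0],
     [0.0, 0.0, -1.0, 2.0]]
)
M = sparse.BCOO.fromdense(dense)  # stores only the 10 non-zeros (indices + data)
v = jnp.arange(4.0)


def matvec(mat, vec):
    # BCOO overloads @, so this is a sparse-dense product.
    return mat @ vec


lowered = jax.jit(matvec).lower(M, v)
hlo_text = lowered.as_text()

# Report which dense primitives show up in the lowered module.
for op in ("gather", "scatter", "dot_general", "reduce", "custom_call"):
    print(op, op in hlo_text)

print(jax.jit(matvec)(M, v))
```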

@martinjrobins
Contributor

JAX to HLO example: https://jax.readthedocs.io/en/latest/aot.html
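Following that page, the ahead-of-time stages can be sketched as below, assuming a recent JAX (`f` is a placeholder function): lower to StableHLO, compile with XLA for the default backend, then call the compiled executable directly.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Placeholder expression; identically 1 for any x.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


x = jnp.linspace(0.0, 1.0, 5)

# Stage 1: trace and lower to StableHLO (MLIR), specialised to x's shape/dtype.
lowered = jax.jit(f).lower(x)

# Stage 2: compile with XLA for the default backend (CPU, or GPU if available).
compiled = lowered.compile()

# Stage 3: execute the compiled executable directly, with no re-tracing.
y = compiled(x)
print(y)
```

The executable is specialised to the shape and dtype of `x` at lowering time; calling `compiled` with a differently shaped array raises an error rather than re-tracing.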

2 participants