Description
Provide an MLIR-based backend for expression evaluation that can be utilised by SUNDIALS for either CPU or GPU-accelerated solves.
Motivation
At present, the PyBaMM IDAKLU solver makes use of SUNDIALS, which supports GPUs. However, this support cannot be harnessed because IDAKLU's function evaluation relies on CasADi, so GPU use would require user-side compilation of PyBaMM models. Of PyBaMM's solvers, only JAX (a Python library) currently translates naturally to GPU. An MLIR-based backend would allow the code required for expression evaluation to be lowered to CPU or GPU devices, so that expression trees could be passed for evaluation at runtime. There have been previous efforts to expose GPU support for SUNDIALS in PyBaMM (e.g. #2644), which stalled because the model equations could not easily be lowered to device (GPU) code.
Possible Implementation
MLIR (Multi-Level Intermediate Representation) is a compiler infrastructure in which code is lowered in stages, through progressively lower-level dialects, into LLVM's Intermediate Representation (IR). That IR can then be compiled for multiple platforms, including GPUs. Broadly, this is also the mechanism by which JAX supports multiple platforms.
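As a rough illustration of the first lowering step, the sketch below walks a small expression tree and emits textual MLIR using the `func` and `arith` dialects, with one scalar `f64` argument per state variable. The node classes `Var`, `Const` and `BinOp` and the function `to_mlir` are hypothetical stand-ins, not PyBaMM's own expression-tree classes, and a real backend would more likely build IR through MLIR's builder APIs than through string formatting. This is a minimal sketch only.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Var:
    """Hypothetical node: reference to entry `index` of the state vector."""
    index: int


@dataclass(frozen=True)
class Const:
    """Hypothetical node: a scalar floating-point constant."""
    value: float


@dataclass(frozen=True)
class BinOp:
    """Hypothetical node: binary arithmetic, op is one of + - * /."""
    op: str
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Var, Const, BinOp]

# Map tree operators onto floating-point ops of MLIR's `arith` dialect.
_ARITH_OPS = {"+": "arith.addf", "-": "arith.subf", "*": "arith.mulf", "/": "arith.divf"}


def to_mlir(expr: Expr, n_vars: int, name: str = "rhs") -> str:
    """Emit a `func.func` computing `expr` from n_vars f64 arguments."""
    lines: list[str] = []
    counter = 0  # next free SSA value number

    def emit(e: Expr) -> str:
        nonlocal counter
        if isinstance(e, Var):
            # State variables map directly onto function arguments.
            return f"%arg{e.index}"
        if isinstance(e, Const):
            result = f"%{counter}"
            counter += 1
            # Exponent notation always contains a '.', as MLIR float literals require.
            lines.append(f"    {result} = arith.constant {e.value:.17e} : f64")
            return result
        if isinstance(e, BinOp):
            a = emit(e.lhs)  # post-order: operands are emitted before their user
            b = emit(e.rhs)
            result = f"%{counter}"
            counter += 1
            lines.append(f"    {result} = {_ARITH_OPS[e.op]} {a}, {b} : f64")
            return result
        raise TypeError(f"unsupported node: {e!r}")

    out = emit(expr)
    lines.append(f"    return {out} : f64")
    args = ", ".join(f"%arg{i}: f64" for i in range(n_vars))
    return f"func.func @{name}({args}) -> f64 {{\n" + "\n".join(lines) + "\n}"


if __name__ == "__main__":
    # Example right-hand side: -0.5 * y0 + y1
    expr = BinOp("+", BinOp("*", Const(-0.5), Var(0)), Var(1))
    print(to_mlir(expr, n_vars=2))
```

From here, the emitted text could in principle be passed through MLIR's standard lowering passes (for example `mlir-opt --convert-arith-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts`, then `mlir-translate --mlir-to-llvmir`) to obtain LLVM IR for the CPU. A GPU path would instead go through MLIR's `gpu` dialect. Exact pass names and pipelines vary between MLIR versions.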
Additional context
#3766 (comment)