Performance refactor for Jax BDF Solver, fixes #4455 #4456
base: develop
Conversation
…fixes for calculate_sensitivities, adds JAX vectorised example
Codecov Report: All modified and coverable lines are covered by tests ✅

```
@@           Coverage Diff            @@
##           develop   #4456    +/-  ##
===========================================
- Coverage    99.46%   99.46%   -0.01%
===========================================
  Files          293      293
  Lines        22384    22321     -63
===========================================
- Hits         22264    22201     -63
  Misses         120      120
```

☔ View full report in Codecov by Sentry.
Benchmarking for the BDF method is provided below. This was run on a MacBook Pro M3 Pro laptop, with the DFN discretised with … Unless otherwise stated, the results below are in seconds. This is the average across ten simulations of the example script (with the number of simulations reduced from 1000 to 100), including the JIT compilation time.
Great stuff, thanks @BradyPlanden. I'll have a closer look at the code in a tick. I'm really surprised how the timings have evolved: when I wrote this, the JIT compilation of the BDF solver was very slow compared with RK45, hence the RK45 default. But it looks like there have been many changes in JAX since then. Do you have timings for the solve by itself (without including JIT)?
I was also a bit surprised, as I didn't go into this refactor looking for performance; the aim was mostly to clean up the code and understand the underlying JAX methods. I suspect there are areas for further improvement. Here are the post-JIT timings per 100 solves, now in milliseconds.

The JIT compilation for the BDF solver has improved quite a bit. There are a few areas the improvement is most likely to come from: I think reducing the memory copies and using newer JAX methods (…) helped. Edit: it would be interesting to compare these numbers to the recent idaklu parallelisation.
This looks great, thanks @BradyPlanden, glad to see the BDF jit compiling so much faster now :) Interested to see the runtime differences to idaklu.
```diff
 calculate_sensitivities_explicit = (
-    model.calculate_sensitivities and not isinstance(self, pybamm.IDAKLUSolver)
+    model.calculate_sensitivities
+    and not isinstance(self, (pybamm.IDAKLUSolver, pybamm.JaxSolver))
 )
```
We've started using properties to encode solver features (e.g. see supports_parallel_solve). Can you make this a property of the solver, so you can just say model.calculate_sensitivities and not self.requires_explicit_sensitivities?
Sure, I'll take a look. I was also thinking about raising an error here instead, as the calculate_sensitivities argument is essentially ignored by the pure JAX solvers. Do you have an opinion either way?
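For context, a minimal sketch of the property-based approach under discussion. Class and property names here follow the review suggestion and are illustrative only; they are not necessarily the names used in the merged code:

```python
class BaseSolver:
    @property
    def requires_explicit_sensitivities(self) -> bool:
        # Default: the generic solver path constructs explicit
        # sensitivity equations alongside the model.
        return True


class JaxSolver(BaseSolver):
    @property
    def requires_explicit_sensitivities(self) -> bool:
        # Pure JAX solvers obtain sensitivities via autodiff, so the
        # explicit equations are not needed (alternatively, one could
        # raise an error here, as discussed above).
        return False


def needs_explicit(calculate_sensitivities: bool, solver: BaseSolver) -> bool:
    # A capability query replaces the isinstance() check on solver types.
    return calculate_sensitivities and solver.requires_explicit_sensitivities


print(needs_explicit(True, BaseSolver()))  # True
print(needs_explicit(True, JaxSolver()))   # False
```

The advantage of the property is that new solver subclasses opt in or out of the feature themselves, rather than the base class maintaining a growing isinstance tuple.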
```python
# Run solve for all inputs
start_time = time.time()
solution = solver.solve(model, t_eval, inputs=inputs)
```
I think here you should describe in a comment that JIT compilation occurs during this first solve, and demonstrate that subsequent solves are faster. I think this would be an important concept to get across to a user.
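The point can be demonstrated outside PyBaMM: in this toy sketch, a jitted function stands in for solver.solve (all names here are illustrative, not PyBaMM's API), and the first call pays the compilation cost:

```python
import time

import jax
import jax.numpy as jnp


# Toy stand-in for a solve call: any jitted JAX computation behaves the
# same way with respect to compilation.
@jax.jit
def solve(inputs):
    return jnp.cumsum(jnp.sin(inputs))


inputs = jnp.arange(1000.0)

# JIT compilation occurs during this first solve, so it is comparatively slow.
t0 = time.perf_counter()
solve(inputs).block_until_ready()
first_solve = time.perf_counter() - t0

# Subsequent solves reuse the compiled program and are much faster.
t0 = time.perf_counter()
solve(inputs).block_until_ready()
later_solve = time.perf_counter() - t0

print(f"first (incl. JIT): {first_solve:.4f}s, subsequent: {later_solve:.6f}s")
```

Note the block_until_ready() calls: JAX dispatches asynchronously, so without them the timings would measure dispatch rather than the computation itself.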
Description
This PR refactors the JAX BDF solver for performance. It also updates the JaxSolver's default method to "BDF", as its performance is much higher than "RK45". This PR also fixes the calculate_sensitivities error described in #4455. Lastly, it adds a JAX vectorised example script to showcase the JaxSolver's performance in highly vectorised usage.

Fixes #4455
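To give a flavour of what a vectorised solve looks like in JAX, here is a self-contained sketch in which a toy ODE solved with jax.experimental.ode.odeint stands in for the PyBaMM model; all names and parameter values are illustrative, not taken from the example script:

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


# Simple exponential-decay right-hand side, parameterised by rate k.
def rhs(y, t, k):
    return -k * y


def solve_one(k):
    # Solve for a single parameter value on a fixed time grid.
    t_eval = jnp.linspace(0.0, 1.0, 50)
    return odeint(rhs, jnp.array([1.0]), t_eval, k)


ks = jnp.linspace(0.5, 2.0, 16)      # a batch of 16 input parameter values
solutions = jax.vmap(solve_one)(ks)  # one vectorised, jit-compatible solve
print(solutions.shape)               # (16, 50, 1): batch x time x state
```

jax.vmap turns the per-parameter solve into a single batched computation, which is the pattern that makes the vectorised usage in this PR fast.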
Type of change
Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.
Key checklist:
- $ pre-commit run (or $ nox -s pre-commit) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
- $ python run-tests.py --all (or $ nox -s tests)
- $ python run-tests.py --doctest (or $ nox -s doctests)

You can run integration tests, unit tests, and doctests together at once, using $ python run-tests.py --quick (or $ nox -s quick).

Further checks: