
Performance refactor for Jax BDF Solver, fixes #4455 (#4456)

Open · wants to merge 2 commits into base: develop

Conversation

@BradyPlanden (Member) commented Sep 23, 2024

Description

This PR refactors the JAX BDF solver for performance. It also changes the JaxSolver's default method to "BDF", since its performance is much higher than that of "RK45". It also fixes the calculate_sensitivities error described in #4455. Lastly, it adds a vectorised JAX example script to showcase the JaxSolver's performance in highly vectorised usage.

Fixes #4455
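
For context, a minimal sketch of the kind of usage the new example targets. This is illustrative only: the model choice, tolerances, and input sweep below are assumptions, not the PR's actual example script.

import numpy as np
import pybamm

# Build an SPMe and convert its expression tree to JAX so the JaxSolver can use it.
model = pybamm.lithium_ion.SPMe()
model.convert_to_format = "jax"
model.events = []  # termination events may need to be removed for the pure-JAX solvers

params = model.default_parameter_values
params["Current function [A]"] = "[input]"  # expose the applied current as an input

geometry = model.default_geometry
params.process_model(model)
params.process_geometry(geometry)
mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

# The default method is now "BDF"; it is passed explicitly here for clarity.
solver = pybamm.JaxSolver(method="BDF", rtol=1e-6, atol=1e-6)
t_eval = np.linspace(0, 3600, 100)

# Passing a list of input dictionaries solves the model for many parameter
# values in one call, which is the vectorised usage the example showcases.
inputs = [{"Current function [A]": c} for c in np.linspace(0.1, 0.6, 100)]
solutions = solver.solve(model, t_eval, inputs=inputs)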

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.

  • New feature (non-breaking change which adds functionality)
  • Optimization (back-end change that speeds up the code)
  • Bug fix (non-breaking change which fixes an issue)

Key checklist:

  • No style issues: $ pre-commit run (or $ nox -s pre-commit) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
  • All tests pass: $ python run-tests.py --all (or $ nox -s tests)
  • The documentation builds: $ python run-tests.py --doctest (or $ nox -s doctests)

You can run integration tests, unit tests, and doctests together using $ python run-tests.py --quick (or $ nox -s quick).

Further checks:

  • Code is commented, particularly in hard-to-understand areas
  • Tests added that prove fix is effective or that feature works

…fixes for calculate_sensitivities, adds JAX vectorised example

codecov bot commented Sep 23, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.46%. Comparing base (d362c98) to head (9f0cdff).

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4456      +/-   ##
===========================================
- Coverage    99.46%   99.46%   -0.01%     
===========================================
  Files          293      293              
  Lines        22384    22321      -63     
===========================================
- Hits         22264    22201      -63     
  Misses         120      120              


@BradyPlanden (Member, Author) commented Sep 23, 2024

Benchmarking for the BDF method is provided below. This was run on a MacBook Pro (M3 Pro) laptop, with the DFN discretised with var_pts = {"x_n": 10, "x_s": 10, "x_p": 10, "r_n": 10, "r_p": 10}. For the RK45 method, the SPMe had not solved within 5 minutes (conservatively estimated), so I did not proceed to the DFN.

Unless otherwise stated, the results below are in seconds. Each value is the average across ten runs of the example script (with the number of simulations reduced from 1000 to 100) and includes the JIT compilation time.

Model   develop (v24.9)   #4456   RK45
SPM     0.97              0.85    1.11
SPMe    3.09              1.82    >5 min
DFN     18.86             8.29    ≫5 min

@martinjrobins (Contributor) commented:

Great stuff, thanks @BradyPlanden. I'll have a closer look at the code in a tick. I'm really surprised by how the timings have evolved; when I wrote this, the JIT compilation of the BDF solver was very slow compared with RK45, hence the RK45 default. But it looks like there have been many changes in JAX since then. Do you have timings for the solve by itself (excluding JIT)?

@BradyPlanden (Member, Author) commented Sep 23, 2024

I was also a bit surprised, as I didn't go into this refactor looking for performance, mostly just to clean up the code and understand the underlying JAX methods. I suspect there are areas for further improvement. Here are the post-JIT timings per 100 solves, now in milliseconds.

Model   develop (v24.9)   #4456   RK45
SPM     154               39.5    889
SPMe    1620              320     N/A
DFN     12372             2340    N/A

The JIT compilation for the BDF solver has improved quite a bit; there are a few areas the gain is most likely to come from. I think reducing memory copies and using newer JAX primitives (lax.cond, etc.) is doing most of the heavy lifting here.

Edit: it would be interesting to compare these numbers with the recent idaklu parallelisation.
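
For illustration, here is a tiny self-contained example of the lax.cond pattern mentioned above (hypothetical step-size logic, not code from this PR): inside a jitted function, lax.cond keeps the branch inside the compiled program instead of falling back to Python-level control flow.

import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def update_step(h, error_norm):
    # Hypothetical step-size rule: halve the step if the local error is too
    # large, otherwise grow it slightly. Both branches are traced once and the
    # selection happens at runtime inside the compiled program.
    return lax.cond(
        error_norm > 1.0,
        lambda h: h * 0.5,  # reject: shrink the step
        lambda h: h * 1.2,  # accept: grow the step
        h,
    )

print(update_step(0.1, jnp.array(2.0)))  # ~0.05
print(update_step(0.1, jnp.array(0.3)))  # ~0.12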

@martinjrobins (Contributor) left a comment:

This looks great, thanks @BradyPlanden, glad to see the BDF solver JIT-compiling so much faster now :) Interested to see the runtime differences compared with idaklu.

  calculate_sensitivities_explicit = (
-     model.calculate_sensitivities and not isinstance(self, pybamm.IDAKLUSolver)
+     model.calculate_sensitivities
+     and not isinstance(self, (pybamm.IDAKLUSolver, pybamm.JaxSolver))
  )
@martinjrobins (Contributor) commented:

We've started using properties to encode solver features (e.g. see supports_parallel_solve). Can you make this a property of the solver, so you can just write model.calculate_sensitivities and not self.requires_explicit_sensitivities?
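
One possible shape for this, sketched under the assumption that the base class defaults to requiring the explicit sensitivity formulation and that the JAX (and IDAKLU) solvers override it; the property name follows the comment above, everything else is illustrative rather than the final PyBaMM implementation.

class BaseSolver:
    @property
    def requires_explicit_sensitivities(self):
        """True if sensitivities must be added to the model equations explicitly."""
        return True


class JaxSolver(BaseSolver):
    # The JAX-based solvers handle (or currently ignore) sensitivities
    # themselves, so they opt out of the explicit formulation.
    @property
    def requires_explicit_sensitivities(self):
        return False


# The check above would then read (without the `not`, given this default):
#   calculate_sensitivities_explicit = (
#       model.calculate_sensitivities and self.requires_explicit_sensitivities
#   )
print(JaxSolver().requires_explicit_sensitivities)  # False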

@BradyPlanden (Member, Author) commented Sep 23, 2024:

Sure, I'll take a look. I was thinking about raising an error here instead, as the calculate_sensitivities argument is essentially ignored by the pure JAX solvers. Do you have an opinion either way?


# Run solve for all inputs
start_time = time.time()
solution = solver.solve(model, t_eval, inputs=inputs)
@martinjrobins (Contributor) commented:

I think you should describe in a comment here that JIT compilation occurs during this first solve, and demonstrate that subsequent solves are faster. I think this is an important concept to get across to a user.
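
A hedged sketch of what that could look like (variable names follow the excerpt above; the printed messages are illustrative, not the example script's actual output):

import time

# The first call triggers JIT compilation of the solver, so it is noticeably
# slower than later calls with the same model and time grid.
start_time = time.time()
solution = solver.solve(model, t_eval, inputs=inputs)
print(f"First solve (includes JIT compilation): {time.time() - start_time:.3f} s")

# Subsequent solves reuse the compiled solver and only pay the solve cost.
start_time = time.time()
solution = solver.solve(model, t_eval, inputs=inputs)
print(f"Second solve (already compiled): {time.time() - start_time:.3f} s")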


Successfully merging this pull request may close these issues.

[Bug]: JAX BDF solver sensitivities bug