-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds JAX IDAKLU solver integration #481
base: develop
Are you sure you want to change the base?
Conversation
…benchmarks. Adds JaxSumSquaredError and JaxLogNormalLikelihood.
# Conflicts: # pyproject.toml
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## develop #481 +/- ##
===========================================
- Coverage 98.82% 97.70% -1.13%
===========================================
Files 52 53 +1
Lines 3582 3656 +74
===========================================
+ Hits 3540 3572 +32
- Misses 42 84 +42 ☔ View full report in Codecov by Sentry. |
Here's the benchmark script for the solvers in this PR. A version of it is also in the PR. import time
import numpy as np
import pybamm
import pybop
n = 50 # Number of solves
solvers = [
pybamm.CasadiSolver(mode="fast with events", atol=1e-6, rtol=1e-6),
pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6),
]
# Parameter set and model definition
parameter_set = pybop.ParameterSet.pybamm("Chen2020")
model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=solvers[0])
# Fitting parameters
parameters = pybop.Parameters(
pybop.Parameter(
"Negative electrode active material volume fraction", initial_value=0.55
),
pybop.Parameter(
"Positive electrode active material volume fraction", initial_value=0.55
),
)
# Define test protocol and generate data
t_eval = np.linspace(0, 100, 1000)
values = model.predict(
initial_state={"Initial open-circuit voltage [V]": 4.2}, t_eval=t_eval
)
# Form dataset
dataset = pybop.Dataset(
{
"Time [s]": values["Time [s]"].data,
"Current function [A]": values["Current [A]"].data,
"Voltage [V]": values["Voltage [V]"].data,
}
)
# Create inputs function for benchmarking
def inputs():
return {
"Negative electrode active material volume fraction": 0.55
+ np.random.normal(0, 0.01),
"Positive electrode active material volume fraction": 0.55
+ np.random.normal(0, 0.01),
}
# Iterate over the solvers and print benchmarks
for solver in solvers:
# Setup Fitting Problem
model.solver = solver
problem = pybop.FittingProblem(model, parameters, dataset)
cost = pybop.SumSquaredError(problem)
start_time = time.time()
for _i in range(n):
out = problem.model.simulate(inputs=inputs(), t_eval=t_eval)
print(f"({solver.name}) Time model.simulate: {time.time() - start_time:.4f}")
start_time = time.time()
for _i in range(n):
out = problem.model.simulateS1(inputs=inputs(), t_eval=t_eval)
print(f"({solver.name}) Time model.SimulateS1: {time.time() - start_time:.4f}")
start_time = time.time()
for _i in range(n):
out = problem.evaluate(inputs=inputs())
print(f"({solver.name}) Time problem.evaluate: {time.time() - start_time:.4f}")
start_time = time.time()
for _i in range(n):
out = problem.evaluateS1(inputs=inputs())
print(f"({solver.name}) Time Problem.EvaluateS1: {time.time() - start_time:.4f}")
start_time = time.time()
for _i in range(n):
out = cost(inputs(), calculate_grad=False)
print(f"({solver.name}) Time PyBOP Cost w/o grad: {time.time() - start_time:.4f}")
start_time = time.time()
for _i in range(n):
out = cost(inputs(), calculate_grad=True)
print(f"({solver.name}) Time PyBOP Cost w/grad: {time.time() - start_time:.4f}")
# Recreate for Jax IDAKLU solver
ida_solver =pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6)
model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=ida_solver, jax=True)
problem = pybop.FittingProblem(model, parameters, dataset)
cost = pybop.JaxSumSquaredError(problem)
start_time = time.time()
for _i in range(n):
out = cost(inputs(), calculate_grad=False)
print(f"Time Jax SumSquaredError w/o grad: {time.time() - start_time:.4f}")
start_time = time.time()
for _i in range(n):
out = cost(inputs(), calculate_grad=True)
print(f"Time Jax SumSquaredError w/ grad: {time.time() - start_time:.4f}") which produces the following on my M3 Pro Macbook:
|
Description
This PR adds the jaxified IDAKLU solver to enable autodiff for the cost and likelihood classes. At the moment the IDAKLU solver is limited to first order sensitivity information and as such we are limited to gradient information from the autodiff cost/likelihood classes.
As an example of how to use the jaxified IDAKLU, an
experimental
subdirectory is added with theJaxSumSquaredError
andJaxLogNormalLikelihood
classes. These classes only required theevaluate
method to be defined, with jax'svalue_and_grad
method to capture the gradient information. Currently, this solver matches thecasadi fast with events
solver in most cases, with greatly improved performance in computing sensitivities. This performance is expected to improve even more with the next PyBaMM release.This also opens up future functionality for gradient based optimisers in design optimisation of non-geometric parameters, as autodiff can provide gradients for any constructed cost/likelihood/design function.
To Do
Issue reference
Fixes # (issue-number)
Review
Before you mark your PR as ready for review, please ensure that you've considered the following:
Type of change
Key checklist:
$ pre-commit run
(or$ nox -s pre-commit
) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)$ nox -s tests
$ nox -s doctest
You can run integration tests, unit tests, and doctests together at once, using
$ nox -s quick
.Further checks:
Thank you for contributing to our project! Your efforts help us to deliver great software.