Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds JAX IDAKLU solver integration #481

Draft
wants to merge 9 commits into
base: develop
Choose a base branch
from

Conversation

BradyPlanden
Copy link
Member

@BradyPlanden BradyPlanden commented Sep 2, 2024

Description

This PR adds the jaxified IDAKLU solver to enable autodiff for the cost and likelihood classes. At the moment the IDAKLU solver is limited to first order sensitivity information and as such we are limited to gradient information from the autodiff cost/likelihood classes.

As an example of how to use the jaxified IDAKLU, an experimental subdirectory is added with the JaxSumSquaredError and JaxLogNormalLikelihood classes. These classes only required the evaluate method to be defined, with jax's value_and_grad method to capture the gradient information. Currently, this solver matches the casadi fast with events solver in most cases, with greatly improved performance in computing sensitivities. This performance is expected to improve even more with the next PyBaMM release.

This also opens up future functionality for gradient based optimisers in design optimisation of non-geometric parameters, as autodiff can provide gradients for any constructed cost/likelihood/design function.

To Do

  • Add Tests

Issue reference

Fixes # (issue-number)

Review

Before you mark your PR as ready for review, please ensure that you've considered the following:

  • Updated the CHANGELOG.md in reverse chronological order (newest at the top) with a concise description of the changes, including the PR number.
  • Noted any breaking changes, including details on how it might impact existing functionality.

Type of change

  • New Feature: A non-breaking change that adds new functionality.
  • Optimization: A code change that improves performance.
  • Examples: A change to existing or additional examples.
  • Bug Fix: A non-breaking change that addresses an issue.
  • Documentation: Updates to documentation or new documentation for new features.
  • Refactoring: Non-functional changes that improve the codebase.
  • Style: Non-functional changes related to code style (formatting, naming, etc).
  • Testing: Additional tests to improve coverage or confirm functionality.
  • Other: (Insert description of change)

Key checklist:

  • No style issues: $ pre-commit run (or $ nox -s pre-commit) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
  • All unit tests pass: $ nox -s tests
  • The documentation builds: $ nox -s doctest

You can run integration tests, unit tests, and doctests together at once, using $ nox -s quick.

Further checks:

  • Code is well-commented, especially in complex or unclear areas.
  • Added tests that prove my fix is effective or that my feature works.
  • Checked that coverage remains or improves, and added tests if necessary to maintain or increase coverage.

Thank you for contributing to our project! Your efforts help us to deliver great software.

Copy link

codecov bot commented Sep 2, 2024

Codecov Report

Attention: Patch coverage is 48.14815% with 42 lines in your changes missing coverage. Please review.

Project coverage is 97.70%. Comparing base (7a140bb) to head (92265b6).
Report is 1 commits behind head on develop.

Files with missing lines Patch % Lines
pybop/experimental/jax_costs.py 37.50% 30 Missing ⚠️
pybop/models/base_model.py 40.00% 9 Missing ⚠️
pybop/problems/base_problem.py 66.66% 2 Missing ⚠️
pybop/problems/fitting_problem.py 90.00% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop     #481      +/-   ##
===========================================
- Coverage    98.82%   97.70%   -1.13%     
===========================================
  Files           52       53       +1     
  Lines         3582     3656      +74     
===========================================
+ Hits          3540     3572      +32     
- Misses          42       84      +42     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@BradyPlanden
Copy link
Member Author

Here's the benchmark script for the solvers in this PR. A version of it is also in the PR.

import time
import numpy as np
import pybamm

import pybop

n = 50  # Number of solves
solvers = [
    pybamm.CasadiSolver(mode="fast with events", atol=1e-6, rtol=1e-6),
    pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6),
]

# Parameter set and model definition
parameter_set = pybop.ParameterSet.pybamm("Chen2020")
model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=solvers[0])

# Fitting parameters
parameters = pybop.Parameters(
    pybop.Parameter(
        "Negative electrode active material volume fraction", initial_value=0.55
    ),
    pybop.Parameter(
        "Positive electrode active material volume fraction", initial_value=0.55
    ),
)

# Define test protocol and generate data
t_eval = np.linspace(0, 100, 1000)
values = model.predict(
    initial_state={"Initial open-circuit voltage [V]": 4.2}, t_eval=t_eval
)

# Form dataset
dataset = pybop.Dataset(
    {
        "Time [s]": values["Time [s]"].data,
        "Current function [A]": values["Current [A]"].data,
        "Voltage [V]": values["Voltage [V]"].data,
    }
)


# Create inputs function for benchmarking
def inputs():
    return {
        "Negative electrode active material volume fraction": 0.55
        + np.random.normal(0, 0.01),
        "Positive electrode active material volume fraction": 0.55
        + np.random.normal(0, 0.01),
    }


# Iterate over the solvers and print benchmarks
for solver in solvers:
    # Setup Fitting Problem
    model.solver = solver
    problem = pybop.FittingProblem(model, parameters, dataset)
    cost = pybop.SumSquaredError(problem)

    start_time = time.time()
    for _i in range(n):
        out = problem.model.simulate(inputs=inputs(), t_eval=t_eval)
    print(f"({solver.name}) Time model.simulate: {time.time() - start_time:.4f}")

    start_time = time.time()
    for _i in range(n):
        out = problem.model.simulateS1(inputs=inputs(), t_eval=t_eval)
    print(f"({solver.name}) Time model.SimulateS1: {time.time() - start_time:.4f}")

    start_time = time.time()
    for _i in range(n):
        out = problem.evaluate(inputs=inputs())
    print(f"({solver.name}) Time problem.evaluate: {time.time() - start_time:.4f}")

    start_time = time.time()
    for _i in range(n):
        out = problem.evaluateS1(inputs=inputs())
    print(f"({solver.name}) Time Problem.EvaluateS1: {time.time() - start_time:.4f}")

    start_time = time.time()
    for _i in range(n):
        out = cost(inputs(), calculate_grad=False)
    print(f"({solver.name}) Time PyBOP Cost w/o grad: {time.time() - start_time:.4f}")

    start_time = time.time()
    for _i in range(n):
        out = cost(inputs(), calculate_grad=True)
    print(f"({solver.name}) Time PyBOP Cost w/grad: {time.time() - start_time:.4f}")

# Recreate for Jax IDAKLU solver
ida_solver =pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6)
model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=ida_solver, jax=True)
problem = pybop.FittingProblem(model, parameters, dataset)
cost = pybop.JaxSumSquaredError(problem)

start_time = time.time()
for _i in range(n):
    out = cost(inputs(), calculate_grad=False)
print(f"Time Jax SumSquaredError w/o grad: {time.time() - start_time:.4f}")

start_time = time.time()
for _i in range(n):
    out = cost(inputs(), calculate_grad=True)
print(f"Time Jax SumSquaredError w/ grad: {time.time() - start_time:.4f}")

which produces the following on my M3 Pro Macbook:

(CasADi solver with 'fast with events' mode) Time model.simulate: 3.2579
(CasADi solver with 'fast with events' mode) Time model.SimulateS1: 13.5679
(CasADi solver with 'fast with events' mode) Time problem.evaluate: 6.8836
(CasADi solver with 'fast with events' mode) Time Problem.EvaluateS1: 152.7627
(CasADi solver with 'fast with events' mode) Time PyBOP Cost w/o grad: 7.1857
(CasADi solver with 'fast with events' mode) Time PyBOP Cost w/grad: 155.6699
(IDA KLU solver) Time model.simulate: 6.5524
(IDA KLU solver) Time model.SimulateS1: 17.9455
(IDA KLU solver) Time problem.evaluate: 6.6003
(IDA KLU solver) Time Problem.EvaluateS1: 18.0940
(IDA KLU solver) Time PyBOP Cost w/o grad: 6.5335
(IDA KLU solver) Time PyBOP Cost w/grad: 18.1650
Time Jax SumSquaredError w/o grad: 6.9650
Time Jax SumSquaredError w/ grad: 19.5255

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Status: In Progress
Development

Successfully merging this pull request may close these issues.

1 participant