Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sensitivities for Experiments #4415

Merged
merged 32 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
7e7f702
test: sensitivities with simulation class
martinjrobins Aug 7, 2024
04e1901
restrict input params for experiments
martinjrobins Aug 8, 2024
548f0c3
Merge branch 'develop' into i3834-sens-exp
martinjrobins Aug 8, 2024
550f8e9
hreinit of solver at hold stage is failing with sens
martinjrobins Aug 9, 2024
a448fab
a full experiment run works with idaklusolver
martinjrobins Sep 4, 2024
adec86e
move test to test_simulation_with_experiment
martinjrobins Sep 4, 2024
aa571c7
remove print stmts
martinjrobins Sep 4, 2024
6befcd9
fix Solution __add__ for empty dict and False
martinjrobins Sep 4, 2024
9fc2156
fix has_sensitivities
martinjrobins Sep 4, 2024
8eca323
merge develop
martinjrobins Sep 4, 2024
a9c50d3
fix sensitivities experiment solution test
martinjrobins Sep 4, 2024
e023868
Merge branch 'develop' into i3834-sens-exp
martinjrobins Sep 4, 2024
61fd5a7
fix sensitivity tests
martinjrobins Sep 6, 2024
106982b
fix casadi sens test
martinjrobins Sep 6, 2024
52fcc32
refactor solution to all_sensitivities to match other attributes
martinjrobins Sep 6, 2024
38f926d
fix processed variable for all_sensitivities
martinjrobins Sep 6, 2024
04af8f3
fix test on processed variable
martinjrobins Sep 6, 2024
cbbe610
coverage for same models within experiment
martinjrobins Sep 9, 2024
e692881
remove some unnecessary stuff in set_sens_initial_conditions_from
martinjrobins Sep 9, 2024
03b4e98
coverage for pybamm.Solution
martinjrobins Sep 9, 2024
95bcfa6
Merge branch 'develop' into i3834-sens-exp
martinjrobins Sep 9, 2024
a83882f
fix addition of timers in solution
martinjrobins Sep 9, 2024
72d542e
cover a few remaining lines
martinjrobins Sep 9, 2024
f8c384a
final line coverage
martinjrobins Sep 9, 2024
5f59e07
remove unneeded line
martinjrobins Sep 9, 2024
a519a54
add changelog
martinjrobins Sep 9, 2024
cba6b62
move changelog entry to unreleased
martinjrobins Sep 11, 2024
5528cd9
Merge branch 'develop' into i3834-sens-exp
kratman Sep 11, 2024
1c9438c
remove print and refactor try-except block
martinjrobins Sep 13, 2024
133cc8c
add warning for experiments with non-time-based step conditions
martinjrobins Sep 13, 2024
4a51241
Merge branch 'develop' into i3834-sens-exp
martinjrobins Sep 13, 2024
2e06da1
fix simulation test for pytest
martinjrobins Sep 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# [Unreleased](https://github.com/pybamm-team/PyBaMM/)

## Features
- Added sensitivity calculation support for `pybamm.Simulation` and `pybamm.Experiment` ([#4415](https://github.com/pybamm-team/PyBaMM/pull/4415))

## Optimizations
- Removed the `start_step_offset` setting and disabled minimum `dt` warnings for drive cycles with the (`IDAKLUSolver`). ([#4416](https://github.com/pybamm-team/PyBaMM/pull/4416))

Expand Down
21 changes: 21 additions & 0 deletions src/pybamm/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,27 @@ def set_up_and_parameterise_experiment(self):
reduces simulation time since the model formulation is efficient.
"""
parameter_values = self._parameter_values.copy()

# some parameters are used to control the experiment, and should not be
# input parameters
restrict_list = {"Initial temperature [K]", "Ambient temperature [K]"}
for step in self.experiment.steps:
if issubclass(step.__class__, pybamm.experiment.step.BaseStepImplicit):
restrict_list.update(step.get_parameter_values([]).keys())
elif issubclass(step.__class__, pybamm.experiment.step.BaseStepExplicit):
restrict_list.update(["Current function [A]"])
for key in restrict_list:
try:
param = parameter_values[key]
if isinstance(param, pybamm.InputParameter):
raise pybamm.ModelError(
f"Cannot use '{key}' as an input parameter in this experiment. "
f"This experiment is controlled via the following parameters: {restrict_list}. "
f"None of these parameters are able to be input parameters."
)
except KeyError:
pass
martinjrobins marked this conversation as resolved.
Show resolved Hide resolved

# Set the initial temperature to be the temperature of the first step
# We can set this globally for all steps since any subsequent steps will either
# start at the temperature at the end of the previous step (if non-isothermal
Expand Down
149 changes: 130 additions & 19 deletions src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,33 @@ def calculate_consistent_state(self, model, time=0, inputs=None):
y0 = root_sol.all_ys[0]
return y0

def _solve_process_calculate_sensitivities_arg(
inputs, model, calculate_sensitivities
):
# get a list-only version of calculate_sensitivities
if isinstance(calculate_sensitivities, bool):
if calculate_sensitivities:
calculate_sensitivities_list = [p for p in inputs.keys()]
else:
calculate_sensitivities_list = []
else:
calculate_sensitivities_list = calculate_sensitivities

calculate_sensitivities_list.sort()
if not hasattr(model, "calculate_sensitivities"):
model.calculate_sensitivities = []

# Check that calculate_sensitivites have not been updated
sensitivities_have_changed = (
calculate_sensitivities_list != model.calculate_sensitivities
)

# save sensitivity parameters so we can identify them later on
# (FYI: this is used in the Solution class)
model.calculate_sensitivities = calculate_sensitivities_list

return calculate_sensitivities_list, sensitivities_have_changed

def solve(
self,
model,
Expand Down Expand Up @@ -722,15 +749,6 @@ def solve(
"""
pybamm.logger.info(f"Start solving {model.name} with {self.name}")

# get a list-only version of calculate_sensitivities
if isinstance(calculate_sensitivities, bool):
if calculate_sensitivities:
calculate_sensitivities_list = [p for p in inputs.keys()]
else:
calculate_sensitivities_list = []
else:
calculate_sensitivities_list = calculate_sensitivities

# Make sure model isn't empty
self._check_empty_model(model)

Expand Down Expand Up @@ -772,6 +790,12 @@ def solve(
self._set_up_model_inputs(model, inputs) for inputs in inputs_list
]

calculate_sensitivities_list, sensitivities_have_changed = (
BaseSolver._solve_process_calculate_sensitivities_arg(
model_inputs_list[0], model, calculate_sensitivities
)
)

# (Re-)calculate consistent initialization
# Assuming initial conditions do not depend on input parameters
# when len(inputs_list) > 1, only `model_inputs_list[0]`
Expand All @@ -792,13 +816,8 @@ def solve(
"for initial conditions."
)

# Check that calculate_sensitivites have not been updated
calculate_sensitivities_list.sort()
if hasattr(model, "calculate_sensitivities"):
model.calculate_sensitivities.sort()
else:
model.calculate_sensitivities = []
if calculate_sensitivities_list != model.calculate_sensitivities:
# if any setup configuration has changed, we need to re-set up
if sensitivities_have_changed:
self._model_set_up.pop(model, None)
# CasadiSolver caches its integrators using model, so delete this too
if isinstance(self, pybamm.CasadiSolver):
Expand Down Expand Up @@ -1066,6 +1085,58 @@ def _check_events_with_initialization(t_eval, model, inputs_dict):
f"Events {event_names} are non-positive at initial conditions"
)

def _set_sens_initial_conditions_from(
self, solution: pybamm.Solution, model: pybamm.BaseModel
) -> tuple:
"""
A restricted version of BaseModel.set_initial_conditions_from that only extracts the
sensitivities from a solution object, and only for a model that has been descretised.
This is used when setting the initial conditions for a sensitivity model.

Parameters
----------
solution : :class:`pybamm.Solution`
The solution to use to initialize the model

model: :class:`pybamm.BaseModel`
The model whose sensitivities to set

Returns
-------

initial_conditions : tuple of ndarray
The initial conditions for the sensitivities, each element of the tuple
corresponds to an input parameter
"""

ninputs = len(model.calculate_sensitivities)
initial_conditions = tuple([] for _ in range(ninputs))
solution = solution.last_state
for var in model.initial_conditions:
final_state = solution[var.name]
final_state = final_state.sensitivities
final_state_eval = tuple(
final_state[key] for key in model.calculate_sensitivities
)

scale, reference = var.scale.value, var.reference.value
for i in range(ninputs):
scaled_final_state_eval = (final_state_eval[i] - reference) / scale
initial_conditions[i].append(scaled_final_state_eval)

# Also update the concatenated initial conditions if the model is already
# discretised
# Unpack slices for sorting
y_slices = {var: slce for var, slce in model.y_slices.items()}
slices = [y_slices[symbol][0] for symbol in model.initial_conditions.keys()]

# sort equations according to slices
concatenated_initial_conditions = [
casadi.vertcat(*[eq for _, eq in sorted(zip(slices, init))])
for init in initial_conditions
]
return concatenated_initial_conditions

def process_t_interp(self, t_interp):
# set a variable for this
no_interp = (not self.supports_interp) and (
Expand All @@ -1092,6 +1163,7 @@ def step(
npts=None,
inputs=None,
save=True,
calculate_sensitivities=False,
t_interp=None,
):
"""
Expand All @@ -1117,6 +1189,11 @@ def step(
Any input parameters to pass to the model when solving
save : bool, optional
Save solution with all previous timesteps. Defaults to True.
calculate_sensitivities : list of str or bool, optional
Whether the solver calculates sensitivities of all input parameters. Defaults to False.
If only a subset of sensitivities are required, can also pass a
list of input parameter names

t_interp : None, list or ndarray, optional
The times (in seconds) at which to interpolate the solution. Defaults to None.
Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`).
Expand Down Expand Up @@ -1188,8 +1265,15 @@ def step(
# Set up inputs
model_inputs = self._set_up_model_inputs(model, inputs)

# process calculate_sensitivities argument
calculate_sensitivities_list, sensitivities_have_changed = (
BaseSolver._solve_process_calculate_sensitivities_arg(
model_inputs, model, calculate_sensitivities
)
)

first_step_this_model = model not in self._model_set_up
if first_step_this_model:
if first_step_this_model or sensitivities_have_changed:
if len(self._model_set_up) > 0:
existing_model = next(iter(self._model_set_up))
raise RuntimeError(
Expand All @@ -1208,18 +1292,45 @@ def step(
):
pybamm.logger.verbose(f"Start stepping {model.name} with {self.name}")

using_sensitivities = len(model.calculate_sensitivities) > 0

if isinstance(old_solution, pybamm.EmptySolution):
if not first_step_this_model:
# reset y0 to original initial conditions
self.set_up(model, model_inputs, ics_only=True)
elif old_solution.all_models[-1] == model:
# initialize with old solution
model.y0 = old_solution.all_ys[-1][:, -1]
last_state = old_solution.last_state
model.y0 = last_state.all_ys[0]
if using_sensitivities and isinstance(last_state._all_sensitivities, dict):
full_sens = last_state._all_sensitivities["all"][0]
model.y0S = tuple(full_sens[:, i] for i in range(full_sens.shape[1]))

else:
_, concatenated_initial_conditions = model.set_initial_conditions_from(
old_solution, return_type="ics"
)
model.y0 = concatenated_initial_conditions.evaluate(0, inputs=model_inputs)
if using_sensitivities:
model.y0S = self._set_sens_initial_conditions_from(old_solution, model)

# hopefully we'll get rid of explicit sensitivities soon so we can remove this
explicit_sensitivities = model.len_rhs_sens > 0 or model.len_alg_sens > 0
if (
explicit_sensitivities
and using_sensitivities
and not isinstance(old_solution, pybamm.EmptySolution)
and not old_solution.all_models[-1] == model
):
y0_list = []
if model.len_rhs > 0:
y0_list.append(model.y0[: model.len_rhs])
for s in model.y0S:
y0_list.append(s[: model.len_rhs])
if model.len_alg > 0:
y0_list.append(model.y0[model.len_rhs :])
for s in model.y0S:
y0_list.append(s[model.len_rhs :])
model.y0 = casadi.vertcat(*y0_list)

set_up_time = timer.time()

Expand Down
6 changes: 3 additions & 3 deletions src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ Solution IDAKLUSolverOpenMP<ExprSet>::solve(
}

if (sensitivity) {
CheckErrors(IDAGetSens(ida_mem, &t_val, yyS));
CheckErrors(IDAGetSensDky(ida_mem, t_val, 0, yyS));
}

// Store Consistent initialization
Expand Down Expand Up @@ -478,7 +478,7 @@ Solution IDAKLUSolverOpenMP<ExprSet>::solve(
bool hit_adaptive = save_adaptive_steps && retval == IDA_SUCCESS;

if (sensitivity) {
CheckErrors(IDAGetSens(ida_mem, &t_val, yyS));
CheckErrors(IDAGetSensDky(ida_mem, t_val, 0, yyS));
}

if (hit_tinterp) {
Expand All @@ -499,7 +499,7 @@ Solution IDAKLUSolverOpenMP<ExprSet>::solve(
// Reset the states and sensitivities at t = t_val
CheckErrors(IDAGetDky(ida_mem, t_val, 0, yy));
if (sensitivity) {
CheckErrors(IDAGetSens(ida_mem, &t_val, yyS));
CheckErrors(IDAGetSensDky(ida_mem, t_val, 0, yyS));
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/solvers/casadi_algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None):
model,
inputs_dict,
termination="final time",
sensitivities=explicit_sensitivities,
all_sensitivities=explicit_sensitivities,
)
sol.integration_time = integration_time
return sol
8 changes: 4 additions & 4 deletions src/pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None):
y0,
model,
inputs_dict,
sensitivities=False,
all_sensitivities=False,
)
solution.solve_time = 0
solution.integration_time = 0
Expand Down Expand Up @@ -478,7 +478,7 @@ def integer_bisect():
np.array([t_event]),
y_event[:, np.newaxis],
"event",
sensitivities=bool(model.calculate_sensitivities),
all_sensitivities=False,
)
solution.integration_time = (
coarse_solution.integration_time + dense_step_sol.integration_time
Expand Down Expand Up @@ -696,7 +696,7 @@ def _run_integrator(
y_sol,
model,
inputs_dict,
sensitivities=extract_sensitivities_in_solution,
all_sensitivities=extract_sensitivities_in_solution,
check_solution=False,
)
sol.integration_time = integration_time
Expand Down Expand Up @@ -736,7 +736,7 @@ def _run_integrator(
y_sol,
model,
inputs_dict,
sensitivities=extract_sensitivities_in_solution,
all_sensitivities=extract_sensitivities_in_solution,
check_solution=False,
)
sol.integration_time = integration_time
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None):
np.array([sol.t[-1]]),
np.transpose(y_event)[:, np.newaxis],
termination,
sensitivities=yS_out,
all_sensitivities=yS_out,
)
newsol.integration_time = integration_time
if not self.output_variables:
Expand Down
Loading
Loading