Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IDA adaptive time stepping #4351

Merged
merged 15 commits into from
Aug 23, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

## Optimizations

- Improved adaptive time-stepping performance of the (`IDAKLUSolver`). ([#4351](https://github.com/pybamm-team/PyBaMM/pull/4351))
- Improved performance and reliability of DAE consistent initialization. ([#4301](https://github.com/pybamm-team/PyBaMM/pull/4301))
- Replaced rounded Faraday constant with its exact value in `bpx.py` for better comparison between different tools. ([#4290](https://github.com/pybamm-team/PyBaMM/pull/4290))

Expand Down
1 change: 1 addition & 0 deletions src/pybamm/batch_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def solve(
calc_esoh=True,
starting_solution=None,
initial_soc=None,
t_interp=None,
**kwargs,
):
"""
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Experiment:
def __init__(
self,
operating_conditions: list[str | tuple[str]],
period: str = "1 minute",
period: str | None = None,
temperature: float | None = None,
termination: list[str] | None = None,
):
Expand Down
78 changes: 78 additions & 0 deletions src/pybamm/experiment/step/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,84 @@ def default_duration(self, value):
else:
return 24 * 3600 # one day in seconds

@staticmethod
def default_period():
return 60.0 # seconds

def default_time_vector(self, tf, t0=0):
if self.period is None:
period = self.default_period()
else:
period = self.period
npts = max(int(round(np.abs(tf - t0) / period)) + 1, 2)

return np.linspace(t0, tf, npts)

def setup_timestepping(self, solver, tf, t_interp=None):
"""
Setup timestepping for the model.

Parameters
----------
solver: :class`pybamm.BaseSolver`
The solver
tf: float
The final time
t_interp: np.array | None
The time points at which to interpolate the solution
"""
if solver._supports_interp:
MarcBerliner marked this conversation as resolved.
Show resolved Hide resolved
return self._setup_timestepping(solver, tf, t_interp)
else:
return self._setup_timestepping_dense_t_eval(solver, tf, t_interp)

def _setup_timestepping(self, solver, tf, t_interp):
"""
Setup timestepping for the model. This returns a t_eval vector that stops
only at the first and last time points. If t_interp and the period are
unspecified, then the solver will use adaptive time-stepping. For a given
period, t_interp willbe set to return the solution at the end of each period
MarcBerliner marked this conversation as resolved.
Show resolved Hide resolved
and at the final time.

Parameters
----------
solver: :class`pybamm.BaseSolver`
The solver
tf: float
The final time
t_interp: np.array | None
The time points at which to interpolate the solution
"""
t_eval = np.array([0, tf])
if t_interp is None:
if self.period is not None:
t_interp = self.default_time_vector(tf)
else:
t_interp = solver.process_t_interp(t_interp)

return t_eval, t_interp

def _setup_timestepping_dense_t_eval(self, solver, tf, t_interp):
"""
Setup timestepping for the model. By default, this returns a dense t_eval which
stops the solver at each point in the t_eval vector. This method is for solvers
that do not support intra-solve interpolation for the solution.

Parameters
----------
solver: :class`pybamm.BaseSolver`
The solver
tf: float
The final time
t_interp: np.array | None
The time points at which to interpolate the solution
"""
t_eval = self.default_time_vector(tf)

t_interp = solver.process_t_interp(t_interp)

return t_eval, t_interp

def process_model(self, model, parameter_values):
new_model = model.new_copy()
new_parameter_values = parameter_values.copy()
Expand Down
12 changes: 10 additions & 2 deletions src/pybamm/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def solve(
callbacks=None,
showprogress=False,
inputs=None,
t_interp=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -400,6 +401,9 @@ def solve(
Whether to show a progress bar for cycling. If true, shows a progress bar
for cycles. Has no effect when not used with an experiment.
Default is False.
t_interp : None, list or ndarray, optional
MarcBerliner marked this conversation as resolved.
Show resolved Hide resolved
The times (in seconds) at which to interpolate the solution. Defaults to None.
Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`).
**kwargs
Additional key-word arguments passed to `solver.solve`.
See :meth:`pybamm.BaseSolver.solve`.
Expand Down Expand Up @@ -687,13 +691,17 @@ def solve(
"start time": start_time,
}
# Make sure we take at least 2 timesteps
npts = max(int(round(dt / step.period)) + 1, 2)
t_eval, t_interp_processed = step.setup_timestepping(
solver, dt, t_interp
)

try:
step_solution = solver.step(
current_solution,
model,
dt,
t_eval=np.linspace(0, dt, npts),
t_eval,
t_interp=t_interp_processed,
save=False,
inputs=inputs,
**kwargs,
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/solvers/algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def tol(self):
def tol(self, value):
self._tol = value

def _integrate(self, model, t_eval, inputs_dict=None):
def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None):
"""
Calculate the solution of the algebraic equations through root-finding

Expand Down
40 changes: 37 additions & 3 deletions src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
self.name = "Base solver"
self.ode_solver = False
self.algebraic_solver = False
self._supports_interp = False
MarcBerliner marked this conversation as resolved.
Show resolved Hide resolved
self._on_extrapolation = "warn"
self.computed_var_fcns = {}
self._mp_context = self.get_platform_context(platform.system())
Expand Down Expand Up @@ -664,6 +665,7 @@ def solve(
inputs=None,
nproc=None,
calculate_sensitivities=False,
t_interp=None,
):
"""
Execute the solver setup and calculate the solution of the model at
Expand All @@ -687,6 +689,9 @@ def solve(
Whether the solver calculates sensitivities of all input parameters. Defaults to False.
If only a subset of sensitivities are required, can also pass a
list of input parameter names
t_interp : None, list or ndarray, optional
The times (in seconds) at which to interpolate the solution. Defaults to None.
Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`).

Returns
-------
Expand Down Expand Up @@ -735,13 +740,15 @@ def solve(
"initial time and tf is the final time, but has been provided "
f"as a list of length {len(t_eval)}."
)
else:
elif not self._supports_interp:
t_eval = np.linspace(t_eval[0], t_eval[-1], 100)

# Make sure t_eval is monotonic
if (np.diff(t_eval) < 0).any():
raise pybamm.SolverError("t_eval must increase monotonically")

t_interp = self.process_t_interp(t_interp)

# Set up inputs
#
# Argument "inputs" can be either a list of input dicts or
Expand Down Expand Up @@ -860,6 +867,7 @@ def solve(
model,
t_eval[start_index:end_index],
model_inputs_list[0],
t_interp=t_interp,
)
new_solutions = [new_solution]
elif model.convert_to_format == "jax":
Expand All @@ -868,6 +876,7 @@ def solve(
model,
t_eval[start_index:end_index],
model_inputs_list,
t_interp,
)
else:
with mp.get_context(self._mp_context).Pool(processes=nproc) as p:
Expand All @@ -877,6 +886,7 @@ def solve(
[model] * ninputs,
[t_eval[start_index:end_index]] * ninputs,
model_inputs_list,
[t_interp] * ninputs,
),
)
p.close()
Expand Down Expand Up @@ -1044,6 +1054,23 @@ def _check_events_with_initialization(t_eval, model, inputs_dict):
f"Events {event_names} are non-positive at initial conditions"
)

def process_t_interp(self, t_interp):
# set a variable for this
no_interp = (not self._supports_interp) and (
t_interp is not None and len(t_interp) != 0
)
if no_interp:
warnings.warn(
f"Explicit interpolation times not implemented for {self.name}",
pybamm.SolverWarning,
stacklevel=2,
)

if no_interp or t_interp is None:
t_interp = np.empty(0)

return t_interp

def step(
self,
old_solution,
Expand All @@ -1053,6 +1080,7 @@ def step(
npts=None,
inputs=None,
save=True,
t_interp=None,
):
"""
Step the solution of the model forward by a given time increment. The
Expand All @@ -1069,14 +1097,17 @@ def step(
dt : numeric type
The timestep (in seconds) over which to step the solution
t_eval : list or numpy.ndarray, optional
An array of times at which to return the solution during the step
An array of times at which to stop the simulation and return the solution during the step
(Note: t_eval is the time measured from the start of the step, so should start at 0 and end at dt).
By default, the solution is returned at t0 and t0 + dt.
npts : deprecated
inputs : dict, optional
Any input parameters to pass to the model when solving
save : bool, optional
Save solution with all previous timesteps. Defaults to True.
t_interp : None, list or ndarray, optional
The times (in seconds) at which to interpolate the solution. Defaults to None.
Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`).
Raises
------
:class:`pybamm.ModelError`
Expand Down Expand Up @@ -1123,8 +1154,11 @@ def step(
else:
pass

t_interp = self.process_t_interp(t_interp)

t_start = old_solution.t[-1]
t_eval = t_start + t_eval
t_interp = t_start + t_interp
t_end = t_start + dt

if t_start == 0:
Expand Down Expand Up @@ -1187,7 +1221,7 @@ def step(
# Step
pybamm.logger.verbose(f"Stepping for {t_start_shifted:.0f} < t < {t_end:.0f}")
timer.reset()
solution = self._integrate(model, t_eval, model_inputs)
solution = self._integrate(model, t_eval, model_inputs, t_interp)
solution.solve_time = timer.time()

# Check if extrapolation occurred
Expand Down
3 changes: 2 additions & 1 deletion src/pybamm/solvers/c_solvers/idaklu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ PYBIND11_MODULE(idaklu, m)
py::class_<IDAKLUSolver>(m, "IDAKLUSolver")
.def("solve", &IDAKLUSolver::solve,
"perform a solve",
py::arg("t"),
py::arg("t_eval"),
py::arg("t_interp"),
py::arg("y0"),
py::arg("yp0"),
py::arg("inputs"),
Expand Down
3 changes: 2 additions & 1 deletion src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class IDAKLUSolver
* @brief Abstract solver method that returns a Solution class
*/
virtual Solution solve(
np_array t_np,
np_array t_eval_np,
np_array t_interp_np,
np_array y0_np,
np_array yp0_np,
np_array_dense inputs) = 0;
Expand Down
Loading
Loading