diff --git a/CHANGELOG.md b/CHANGELOG.md index 57e0aad007..48d0d7ba32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ ## Optimizations +- Improved adaptive time-stepping performance of the (`IDAKLUSolver`). ([#4351](https://github.com/pybamm-team/PyBaMM/pull/4351)) - Improved performance and reliability of DAE consistent initialization. ([#4301](https://github.com/pybamm-team/PyBaMM/pull/4301)) - Replaced rounded Faraday constant with its exact value in `bpx.py` for better comparison between different tools. ([#4290](https://github.com/pybamm-team/PyBaMM/pull/4290)) diff --git a/CMakeLists.txt b/CMakeLists.txt index 42ab10ee69..8b3a2adfe5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -91,6 +91,7 @@ pybind11_add_module(idaklu src/pybamm/solvers/c_solvers/idaklu/IdakluJax.cpp src/pybamm/solvers/c_solvers/idaklu/IdakluJax.hpp src/pybamm/solvers/c_solvers/idaklu/common.hpp + src/pybamm/solvers/c_solvers/idaklu/common.cpp src/pybamm/solvers/c_solvers/idaklu/python.hpp src/pybamm/solvers/c_solvers/idaklu/python.cpp src/pybamm/solvers/c_solvers/idaklu/Solution.cpp diff --git a/setup.py b/setup.py index 95108454ed..6ceb049b31 100644 --- a/setup.py +++ b/setup.py @@ -322,6 +322,7 @@ def compile_KLU(): "src/pybamm/solvers/c_solvers/idaklu/IdakluJax.cpp", "src/pybamm/solvers/c_solvers/idaklu/IdakluJax.hpp", "src/pybamm/solvers/c_solvers/idaklu/common.hpp", + "src/pybamm/solvers/c_solvers/idaklu/common.cpp", "src/pybamm/solvers/c_solvers/idaklu/python.hpp", "src/pybamm/solvers/c_solvers/idaklu/python.cpp", "src/pybamm/solvers/c_solvers/idaklu/Solution.cpp", diff --git a/src/pybamm/batch_study.py b/src/pybamm/batch_study.py index e854c94e00..ffa5a83530 100644 --- a/src/pybamm/batch_study.py +++ b/src/pybamm/batch_study.py @@ -106,6 +106,7 @@ def solve( calc_esoh=True, starting_solution=None, initial_soc=None, + t_interp=None, **kwargs, ): """ diff --git a/src/pybamm/experiment/experiment.py b/src/pybamm/experiment/experiment.py index 39c49780e4..fb20a0180e 100644 --- a/src/pybamm/experiment/experiment.py +++ b/src/pybamm/experiment/experiment.py @@ -40,7 +40,7 @@ class Experiment: def __init__( self, operating_conditions: list[str | tuple[str]], - period: str = "1 minute", + period: str | None = None, temperature: float | None = None, termination: list[str] | None = None, ): diff --git a/src/pybamm/experiment/step/base_step.py b/src/pybamm/experiment/step/base_step.py index 6b77bed2cf..a7dfa9c9ba 100644 --- a/src/pybamm/experiment/step/base_step.py +++ b/src/pybamm/experiment/step/base_step.py @@ -266,6 +266,84 @@ def default_duration(self, value): else: return 24 * 3600 # one day in seconds + @staticmethod + def default_period(): + return 60.0 # seconds + + def default_time_vector(self, tf, t0=0): + if self.period is None: + period = self.default_period() + else: + period = self.period + npts = max(int(round(np.abs(tf - t0) / period)) + 1, 2) + + return np.linspace(t0, tf, npts) + + def setup_timestepping(self, solver, tf, t_interp=None): + """ + Setup timestepping for the model. + + Parameters + ---------- + solver: :class`pybamm.BaseSolver` + The solver + tf: float + The final time + t_interp: np.array | None + The time points at which to interpolate the solution + """ + if solver.supports_interp: + return self._setup_timestepping(solver, tf, t_interp) + else: + return self._setup_timestepping_dense_t_eval(solver, tf, t_interp) + + def _setup_timestepping(self, solver, tf, t_interp): + """ + Setup timestepping for the model. This returns a t_eval vector that stops + only at the first and last time points. If t_interp and the period are + unspecified, then the solver will use adaptive time-stepping. For a given + period, t_interp will be set to return the solution at the end of each period + and at the final time. + + Parameters + ---------- + solver: :class`pybamm.BaseSolver` + The solver + tf: float + The final time + t_interp: np.array | None + The time points at which to interpolate the solution + """ + t_eval = np.array([0, tf]) + if t_interp is None: + if self.period is not None: + t_interp = self.default_time_vector(tf) + else: + t_interp = solver.process_t_interp(t_interp) + + return t_eval, t_interp + + def _setup_timestepping_dense_t_eval(self, solver, tf, t_interp): + """ + Setup timestepping for the model. By default, this returns a dense t_eval which + stops the solver at each point in the t_eval vector. This method is for solvers + that do not support intra-solve interpolation for the solution. + + Parameters + ---------- + solver: :class`pybamm.BaseSolver` + The solver + tf: float + The final time + t_interp: np.array | None + The time points at which to interpolate the solution + """ + t_eval = self.default_time_vector(tf) + + t_interp = solver.process_t_interp(t_interp) + + return t_eval, t_interp + def process_model(self, model, parameter_values): new_model = model.new_copy() new_parameter_values = parameter_values.copy() diff --git a/src/pybamm/simulation.py b/src/pybamm/simulation.py index a54c76ec7d..5b999d6c83 100644 --- a/src/pybamm/simulation.py +++ b/src/pybamm/simulation.py @@ -352,6 +352,7 @@ def solve( callbacks=None, showprogress=False, inputs=None, + t_interp=None, **kwargs, ): """ @@ -361,11 +362,14 @@ def solve( Parameters ---------- t_eval : numeric type, optional - The times (in seconds) at which to compute the solution. Can be - provided as an array of times at which to return the solution, or as a - list `[t0, tf]` where `t0` is the initial time and `tf` is the final time. - If provided as a list the solution is returned at 100 points within the - interval `[t0, tf]`. + The times at which to stop the integration due to a discontinuity in time. + Can be provided as an array of times at which to return the solution, or as + a list `[t0, tf]` where `t0` is the initial time and `tf` is the final + time. If the solver does not support intra-solve interpolation, providing + `t_eval` as a list returns the solution at 100 points within the interval + `[t0, tf]`. Otherwise, the solution is returned at the times specified in + `t_interp` or as a result of the adaptive time-stepping solution. See the + `t_interp` argument for more details. If not using an experiment or running a drive cycle simulation (current provided as data) `t_eval` *must* be provided. @@ -400,6 +404,9 @@ def solve( Whether to show a progress bar for cycling. If true, shows a progress bar for cycles. Has no effect when not used with an experiment. Default is False. + t_interp : None, list or ndarray, optional + The times (in seconds) at which to interpolate the solution. Defaults to None. + Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`). **kwargs Additional key-word arguments passed to `solver.solve`. See :meth:`pybamm.BaseSolver.solve`. @@ -486,7 +493,7 @@ def solve( ) self._solution = solver.solve( - self._built_model, t_eval, inputs=inputs, **kwargs + self._built_model, t_eval, inputs=inputs, t_interp=t_interp, **kwargs ) elif self.operating_mode == "with experiment": @@ -687,13 +694,17 @@ def solve( "start time": start_time, } # Make sure we take at least 2 timesteps - npts = max(int(round(dt / step.period)) + 1, 2) + t_eval, t_interp_processed = step.setup_timestepping( + solver, dt, t_interp + ) + try: step_solution = solver.step( current_solution, model, dt, - t_eval=np.linspace(0, dt, npts), + t_eval, + t_interp=t_interp_processed, save=False, inputs=inputs, **kwargs, diff --git a/src/pybamm/solvers/algebraic_solver.py b/src/pybamm/solvers/algebraic_solver.py index 5811e3b16d..9b6663d007 100644 --- a/src/pybamm/solvers/algebraic_solver.py +++ b/src/pybamm/solvers/algebraic_solver.py @@ -36,7 +36,7 @@ def __init__(self, method="lm", tol=1e-6, extra_options=None): self.tol = tol self.extra_options = extra_options or {} self.name = f"Algebraic solver ({method})" - self.algebraic_solver = True + self._algebraic_solver = True pybamm.citations.register("Virtanen2020") @property @@ -47,7 +47,7 @@ def tol(self): def tol(self, value): self._tol = value - def _integrate(self, model, t_eval, inputs_dict=None): + def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): """ Calculate the solution of the algebraic equations through root-finding diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index b4ff3a5774..9027bd51c4 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -63,12 +63,25 @@ def __init__( # Defaults, can be overwritten by specific solver self.name = "Base solver" - self.ode_solver = False - self.algebraic_solver = False + self._ode_solver = False + self._algebraic_solver = False + self._supports_interp = False self._on_extrapolation = "warn" self.computed_var_fcns = {} self._mp_context = self.get_platform_context(platform.system()) + @property + def ode_solver(self): + return self._ode_solver + + @property + def algebraic_solver(self): + return self._algebraic_solver + + @property + def supports_interp(self): + return self._supports_interp + @property def root_method(self): return self._root_method @@ -107,7 +120,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): inputs : dict, optional Any input parameters to pass to the model when solving t_eval : numeric type, optional - The times (in seconds) at which to compute the solution + The times at which to stop the integration due to a discontinuity in time. """ inputs = inputs or {} @@ -321,7 +334,7 @@ def _check_and_prepare_model_inplace(self, model, inputs, ics_only): f"Cannot use ODE solver '{self.name}' to solve DAE model" ) # Check model.rhs for algebraic solvers - if self.algebraic_solver is True and len(model.rhs) > 0: + if self._algebraic_solver is True and len(model.rhs) > 0: raise pybamm.SolverError( """Cannot use algebraic solver to solve model with time derivatives""" ) @@ -614,7 +627,7 @@ def _set_consistent_initialization(self, model, time, inputs_dict): """ - if self.algebraic_solver or model.len_alg == 0: + if self._algebraic_solver or model.len_alg == 0: # Don't update model.y0 return @@ -664,6 +677,7 @@ def solve( inputs=None, nproc=None, calculate_sensitivities=False, + t_interp=None, ): """ Execute the solver setup and calculate the solution of the model at @@ -687,6 +701,9 @@ def solve( Whether the solver calculates sensitivities of all input parameters. Defaults to False. If only a subset of sensitivities are required, can also pass a list of input parameter names + t_interp : None, list or ndarray, optional + The times (in seconds) at which to interpolate the solution. Defaults to None. + Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`). Returns ------- @@ -720,13 +737,13 @@ def solve( # t_eval can only be None if the solver is an algebraic solver. In that case # set it to 0 if t_eval is None: - if self.algebraic_solver is False: + if self._algebraic_solver is False: raise ValueError("t_eval cannot be None") t_eval = np.array([0]) # If t_eval is provided as [t0, tf] return the solution at 100 points elif isinstance(t_eval, list): - if len(t_eval) == 1 and self.algebraic_solver is True: + if len(t_eval) == 1 and self._algebraic_solver is True: t_eval = np.array(t_eval) elif len(t_eval) != 2: raise pybamm.SolverError( @@ -735,13 +752,15 @@ def solve( "initial time and tf is the final time, but has been provided " f"as a list of length {len(t_eval)}." ) - else: + elif not self.supports_interp: t_eval = np.linspace(t_eval[0], t_eval[-1], 100) # Make sure t_eval is monotonic if (np.diff(t_eval) < 0).any(): raise pybamm.SolverError("t_eval must increase monotonically") + t_interp = self.process_t_interp(t_interp) + # Set up inputs # # Argument "inputs" can be either a list of input dicts or @@ -813,7 +832,7 @@ def solve( self._model_set_up[model]["initial conditions"] != model.concatenated_initial_conditions ): - if self.algebraic_solver: + if self._algebraic_solver: # For an algebraic solver, we don't need to set up the initial # conditions function and we can just evaluate # model.concatenated_initial_conditions @@ -860,6 +879,7 @@ def solve( model, t_eval[start_index:end_index], model_inputs_list[0], + t_interp=t_interp, ) new_solutions = [new_solution] elif model.convert_to_format == "jax": @@ -868,6 +888,7 @@ def solve( model, t_eval[start_index:end_index], model_inputs_list, + t_interp, ) else: with mp.get_context(self._mp_context).Pool(processes=nproc) as p: @@ -877,6 +898,7 @@ def solve( [model] * ninputs, [t_eval[start_index:end_index]] * ninputs, model_inputs_list, + [t_interp] * ninputs, ), ) p.close() @@ -940,7 +962,7 @@ def solve( # Raise error if solutions[0] only contains one timestep (except for algebraic # solvers, where we may only expect one time in the solution) if ( - self.algebraic_solver is False + self._algebraic_solver is False and len(solutions[0].all_ts) == 1 and len(solutions[0].all_ts[0]) == 1 ): @@ -1044,6 +1066,23 @@ def _check_events_with_initialization(t_eval, model, inputs_dict): f"Events {event_names} are non-positive at initial conditions" ) + def process_t_interp(self, t_interp): + # set a variable for this + no_interp = (not self.supports_interp) and ( + t_interp is not None and len(t_interp) != 0 + ) + if no_interp: + warnings.warn( + f"Explicit interpolation times not implemented for {self.name}", + pybamm.SolverWarning, + stacklevel=2, + ) + + if no_interp or t_interp is None: + t_interp = np.empty(0) + + return t_interp + def step( self, old_solution, @@ -1053,6 +1092,7 @@ def step( npts=None, inputs=None, save=True, + t_interp=None, ): """ Step the solution of the model forward by a given time increment. The @@ -1069,7 +1109,7 @@ def step( dt : numeric type The timestep (in seconds) over which to step the solution t_eval : list or numpy.ndarray, optional - An array of times at which to return the solution during the step + An array of times at which to stop the simulation and return the solution during the step (Note: t_eval is the time measured from the start of the step, so should start at 0 and end at dt). By default, the solution is returned at t0 and t0 + dt. npts : deprecated @@ -1077,6 +1117,9 @@ def step( Any input parameters to pass to the model when solving save : bool, optional Save solution with all previous timesteps. Defaults to True. + t_interp : None, list or ndarray, optional + The times (in seconds) at which to interpolate the solution. Defaults to None. + Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`). Raises ------ :class:`pybamm.ModelError` @@ -1123,8 +1166,11 @@ def step( else: pass + t_interp = self.process_t_interp(t_interp) + t_start = old_solution.t[-1] t_eval = t_start + t_eval + t_interp = t_start + t_interp t_end = t_start + dt if t_start == 0: @@ -1136,6 +1182,8 @@ def step( # the start of the next step t_start_shifted = t_start + step_start_offset t_eval[0] = t_start_shifted + if t_interp.size > 0 and t_interp[0] == t_start: + t_interp[0] = t_start_shifted # Set timer timer = pybamm.Timer() @@ -1187,7 +1235,7 @@ def step( # Step pybamm.logger.verbose(f"Stepping for {t_start_shifted:.0f} < t < {t_end:.0f}") timer.reset() - solution = self._integrate(model, t_eval, model_inputs) + solution = self._integrate(model, t_eval, model_inputs, t_interp) solution.solve_time = timer.time() # Check if extrapolation occurred diff --git a/src/pybamm/solvers/c_solvers/idaklu.cpp b/src/pybamm/solvers/c_solvers/idaklu.cpp index 3427c01853..bb9466d40b 100644 --- a/src/pybamm/solvers/c_solvers/idaklu.cpp +++ b/src/pybamm/solvers/c_solvers/idaklu.cpp @@ -59,7 +59,8 @@ PYBIND11_MODULE(idaklu, m) py::class_(m, "IDAKLUSolver") .def("solve", &IDAKLUSolver::solve, "perform a solve", - py::arg("t"), + py::arg("t_eval"), + py::arg("t_interp"), py::arg("y0"), py::arg("yp0"), py::arg("inputs"), diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp index 26e587e424..29b451e6d3 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp @@ -27,7 +27,8 @@ class IDAKLUSolver * @brief Abstract solver method that returns a Solution class */ virtual Solution solve( - np_array t_np, + np_array t_eval_np, + np_array t_interp_np, np_array y0_np, np_array yp0_np, np_array_dense inputs) = 0; diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp index 98148a3c9f..ca710fbff6 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp @@ -3,6 +3,9 @@ #include "IDAKLUSolver.hpp" #include "common.hpp" +#include +using std::vector; + #include "Options.hpp" #include "Solution.hpp" #include "sundials_legacy_wrapper.hpp" @@ -46,26 +49,32 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver void *ida_mem = nullptr; np_array atol_np; np_array rhs_alg_id; - int number_of_states; // cppcheck-suppress unusedStructMember - int number_of_parameters; // cppcheck-suppress unusedStructMember - int number_of_events; // cppcheck-suppress unusedStructMember + int const number_of_states; // cppcheck-suppress unusedStructMember + int const number_of_parameters; // cppcheck-suppress unusedStructMember + int const number_of_events; // cppcheck-suppress unusedStructMember int precon_type; // cppcheck-suppress unusedStructMember N_Vector yy, yp, avtol; // y, y', and absolute tolerance N_Vector *yyS; // cppcheck-suppress unusedStructMember N_Vector *ypS; // cppcheck-suppress unusedStructMember N_Vector id; // rhs_alg_id realtype rtol; - const int jac_times_cjmass_nnz; // cppcheck-suppress unusedStructMember - int jac_bandwidth_lower; // cppcheck-suppress unusedStructMember - int jac_bandwidth_upper; // cppcheck-suppress unusedStructMember + int const jac_times_cjmass_nnz; // cppcheck-suppress unusedStructMember + int const jac_bandwidth_lower; // cppcheck-suppress unusedStructMember + int const jac_bandwidth_upper; // cppcheck-suppress unusedStructMember SUNMatrix J; SUNLinearSolver LS = nullptr; std::unique_ptr functions; - std::vector res; - std::vector res_dvar_dy; - std::vector res_dvar_dp; - SetupOptions setup_opts; - SolverOptions solver_opts; + vector res; + vector res_dvar_dy; + vector res_dvar_dp; + bool const sensitivity; // cppcheck-suppress unusedStructMember + bool const save_outputs_only; // cppcheck-suppress unusedStructMember + int length_of_return_vector; // cppcheck-suppress unusedStructMember + vector t; // cppcheck-suppress unusedStructMember + vector> y; // cppcheck-suppress unusedStructMember + vector>> yS; // cppcheck-suppress unusedStructMember + SetupOptions const setup_opts; + SolverOptions const solver_opts; #if SUNDIALS_VERSION_MAJOR >= 6 SUNContext sunctx; @@ -94,36 +103,12 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver */ ~IDAKLUSolverOpenMP(); - /** - * Evaluate functions (including sensitivies) for each requested - * variable and store - * @brief Evaluate functions - */ - void CalcVars( - realtype *y_return, - size_t length_of_return_vector, - size_t t_i, - realtype *tret, - realtype *yval, - const std::vector& ySval, - realtype *yS_return, - size_t *ySk); - - /** - * @brief Evaluate functions for sensitivities - */ - void CalcVarsSensitivities( - realtype *tret, - realtype *yval, - const std::vector& ySval, - realtype *yS_return, - size_t *ySk); - /** * @brief The main solve method that solves for each variable and time step */ Solution solve( - np_array t_np, + np_array t_eval_np, + np_array t_interp_np, np_array y0_np, np_array yp0_np, np_array_dense inputs) override; @@ -143,6 +128,16 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver */ void SetMatrix(); + /** + * @brief Get the length of the return vector + */ + int ReturnVectorLength(); + + /** + * @brief Initialize the storage for the solution + */ + void InitializeStorage(int const N); + /** * @brief Apply user-configurable IDA options */ @@ -152,6 +147,82 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver * @brief Check the return flag for errors */ void CheckErrors(int const & flag); + + /** + * @brief Print the solver statistics + */ + void PrintStats(); + + /** + * @brief Extend the adaptive arrays by 1 + */ + void ExtendAdaptiveArrays(); + + /** + * @brief Set the step values + */ + void SetStep( + realtype &t_val, + realtype *y_val, + vector const &yS_val, + int &i_save + ); + + /** + * @brief Save the interpolated step values + */ + void SetStepInterp( + int &i_interp, + realtype &t_interp_next, + vector const &t_interp, + realtype &t_val, + realtype &t_prev, + realtype const &t_next, + realtype *y_val, + vector const &yS_val, + int &i_save + ); + + /** + * @brief Save y and yS at the current time + */ + void SetStepFull( + realtype &t_val, + realtype *y_val, + vector const &yS_val, + int &i_save + ); + + /** + * @brief Save yS at the current time + */ + void SetStepFullSensitivities( + realtype &t_val, + realtype *y_val, + vector const &yS_val, + int &i_save + ); + + /** + * @brief Save the output function results at the requested time + */ + void SetStepOutput( + realtype &t_val, + realtype *y_val, + const vector &yS_val, + int &i_save + ); + + /** + * @brief Save the output function sensitivities at the requested time + */ + void SetStepOutputSensitivities( + realtype &t_val, + realtype *y_val, + const vector &yS_val, + int &i_save + ); + }; #include "IDAKLUSolverOpenMP.inl" diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl index b309fd6028..de6c43466a 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl @@ -1,5 +1,8 @@ #include "Expressions/Expressions.hpp" #include "sundials_functions.hpp" +#include + +#include "common.hpp" template IDAKLUSolverOpenMP::IDAKLUSolverOpenMP( @@ -24,6 +27,8 @@ IDAKLUSolverOpenMP::IDAKLUSolverOpenMP( jac_bandwidth_lower(jac_bandwidth_lower_input), jac_bandwidth_upper(jac_bandwidth_upper_input), functions(std::move(functions_arg)), + sensitivity(number_of_parameters > 0), + save_outputs_only(functions->var_fcns.size() > 0), setup_opts(setup_input), solver_opts(solver_input) { @@ -40,7 +45,7 @@ IDAKLUSolverOpenMP::IDAKLUSolverOpenMP( // create the vector of initial values AllocateVectors(); - if (number_of_parameters > 0) { + if (sensitivity) { yyS = N_VCloneVectorArray(number_of_parameters, yy); ypS = N_VCloneVectorArray(number_of_parameters, yp); } @@ -88,6 +93,55 @@ void IDAKLUSolverOpenMP::AllocateVectors() { id = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); } +template +void IDAKLUSolverOpenMP::InitializeStorage(int const N) { + length_of_return_vector = ReturnVectorLength(); + + t = vector(N, 0.0); + + y = vector>( + N, + vector(length_of_return_vector, 0.0) + ); + + yS = vector>>( + N, + vector>( + number_of_parameters, + vector(length_of_return_vector, 0.0) + ) + ); +} + +template +int IDAKLUSolverOpenMP::ReturnVectorLength() { + if (!save_outputs_only) { + return number_of_states; + } + + // set return vectors + int length_of_return_vector = 0; + size_t max_res_size = 0; // maximum result size (for common result buffer) + size_t max_res_dvar_dy = 0, max_res_dvar_dp = 0; + // return only the requested variables list after computation + for (auto& var_fcn : functions->var_fcns) { + max_res_size = std::max(max_res_size, size_t(var_fcn->out_shape(0))); + length_of_return_vector += var_fcn->nnz_out(); + for (auto& dvar_fcn : functions->dvar_dy_fcns) { + max_res_dvar_dy = std::max(max_res_dvar_dy, size_t(dvar_fcn->out_shape(0))); + } + for (auto& dvar_fcn : functions->dvar_dp_fcns) { + max_res_dvar_dp = std::max(max_res_dvar_dp, size_t(dvar_fcn->out_shape(0))); + } + + res.resize(max_res_size); + res_dvar_dy.resize(max_res_dvar_dy); + res_dvar_dp.resize(max_res_dvar_dp); + } + + return length_of_return_vector; +} + template void IDAKLUSolverOpenMP::SetSolverOptions() { // Maximum order of the linear multistep method @@ -153,8 +207,6 @@ void IDAKLUSolverOpenMP::SetSolverOptions() { } } - - template void IDAKLUSolverOpenMP::SetMatrix() { // Create Matrix object @@ -215,7 +267,7 @@ void IDAKLUSolverOpenMP::Initialize() { CheckErrors(IDASetJacFn(ida_mem, jacobian_eval)); } - if (number_of_parameters > 0) { + if (sensitivity) { CheckErrors(IDASensInit(ida_mem, number_of_parameters, IDA_SIMULTANEOUS, sensitivities_eval, yyS, ypS)); CheckErrors(IDASensEEtolerances(ida_mem)); @@ -238,7 +290,6 @@ void IDAKLUSolverOpenMP::Initialize() { template IDAKLUSolverOpenMP::~IDAKLUSolverOpenMP() { - bool sensitivity = number_of_parameters > 0; // Free memory if (sensitivity) { IDASensFree(ida_mem); @@ -261,70 +312,10 @@ IDAKLUSolverOpenMP::~IDAKLUSolverOpenMP() { SUNContext_Free(&sunctx); } -template -void IDAKLUSolverOpenMP::CalcVars( - realtype *y_return, - size_t length_of_return_vector, - size_t t_i, - realtype *tret, - realtype *yval, - const std::vector& ySval, - realtype *yS_return, - size_t *ySk -) { - DEBUG("IDAKLUSolver::CalcVars"); - // Evaluate functions for each requested variable and store - size_t j = 0; - for (auto& var_fcn : functions->var_fcns) { - (*var_fcn)({tret, yval, functions->inputs.data()}, {&res[0]}); - // store in return vector - for (size_t jj=0; jjnnz_out(); jj++) { - y_return[t_i*length_of_return_vector + j++] = res[jj]; - } - } - // calculate sensitivities - CalcVarsSensitivities(tret, yval, ySval, yS_return, ySk); -} - -template -void IDAKLUSolverOpenMP::CalcVarsSensitivities( - realtype *tret, - realtype *yval, - const std::vector& ySval, - realtype *yS_return, - size_t *ySk -) { - DEBUG("IDAKLUSolver::CalcVarsSensitivities"); - // Calculate sensitivities - std::vector dens_dvar_dp = std::vector(number_of_parameters, 0); - for (size_t dvar_k = 0; dvar_k < functions->dvar_dy_fcns.size(); dvar_k++) { - // Isolate functions - Expression* dvar_dy = functions->dvar_dy_fcns[dvar_k]; - Expression* dvar_dp = functions->dvar_dp_fcns[dvar_k]; - // Calculate dvar/dy - (*dvar_dy)({tret, yval, functions->inputs.data()}, {&res_dvar_dy[0]}); - // Calculate dvar/dp and convert to dense array for indexing - (*dvar_dp)({tret, yval, functions->inputs.data()}, {&res_dvar_dp[0]}); - for (int k=0; knnz_out(); k++) { - dens_dvar_dp[dvar_dp->get_row()[k]] = res_dvar_dp[k]; - } - // Calculate sensitivities - for (int paramk = 0; paramk < number_of_parameters; paramk++) { - yS_return[*ySk] = dens_dvar_dp[paramk]; - for (int spk = 0; spk < dvar_dy->nnz_out(); spk++) { - yS_return[*ySk] += res_dvar_dy[spk] * ySval[paramk][dvar_dy->get_col()[spk]]; - } - (*ySk)++; - } - } -} - template Solution IDAKLUSolverOpenMP::solve( - np_array t_np, + np_array t_eval_np, + np_array t_interp_np, np_array y0_np, np_array yp0_np, np_array_dense inputs @@ -332,21 +323,78 @@ Solution IDAKLUSolverOpenMP::solve( { DEBUG("IDAKLUSolver::solve"); - int number_of_timesteps = t_np.request().size; - auto t = t_np.unchecked<1>(); - realtype t0 = RCONST(t(0)); + // If t_interp is empty, save all adaptive steps + bool save_adaptive_steps = t_interp_np.unchecked<1>().size() == 0; + + // Process the time inputs + // 1. Get the sorted and unique t_eval vector + auto const t_eval = makeSortedUnique(t_eval_np); + + // 2.1. Get the sorted and unique t_interp vector + auto const t_interp_unique_sorted = makeSortedUnique(t_interp_np); + + // 2.2 Remove the t_eval values from t_interp + auto const t_interp_setdiff = setDiff(t_interp_unique_sorted, t_eval); + + // 2.3 Finally, get the sorted and unique t_interp vector with t_eval values removed + auto const t_interp = makeSortedUnique(t_interp_setdiff); + + int const number_of_evals = t_eval.size(); + int const number_of_interps = t_interp.size(); + + // setDiff removes entries of t_interp that overlap with + // t_eval, so we need to check if we need to interpolate any unique points. + // This is not the same as save_adaptive_steps since some entries of t_interp + // may be removed by setDiff + bool save_interp_steps = number_of_interps > 0; + + // 3. Check if the timestepping entries are valid + if (number_of_evals < 2) { + throw std::invalid_argument( + "t_eval must have at least 2 entries" + ); + } else if (save_interp_steps) { + if (t_interp.front() < t_eval.front()) { + throw std::invalid_argument( + "t_interp values must be greater than the smallest t_eval value: " + + std::to_string(t_eval.front()) + ); + } else if (t_interp.back() > t_eval.back()) { + throw std::invalid_argument( + "t_interp values must be less than the greatest t_eval value: " + + std::to_string(t_eval.back()) + ); + } + } + + // Initialize length_of_return_vector, t, y, and yS + InitializeStorage(number_of_evals + number_of_interps); + + int i_save = 0; + + realtype t0 = t_eval.front(); + realtype tf = t_eval.back(); + + realtype t_val = t0; + realtype t_prev = t0; + int i_eval = 0; + + realtype t_interp_next; + int i_interp = 0; + // If t_interp is empty, save all adaptive steps + if (save_interp_steps) { + t_interp_next = t_interp[0]; + } + auto y0 = y0_np.unchecked<1>(); auto yp0 = yp0_np.unchecked<1>(); auto n_coeffs = number_of_states + number_of_parameters * number_of_states; - bool const sensitivity = number_of_parameters > 0; if (y0.size() != n_coeffs) { throw std::domain_error( "y0 has wrong size. Expected " + std::to_string(n_coeffs) + " but got " + std::to_string(y0.size())); - } - - if (yp0.size() != n_coeffs) { + } else if (yp0.size() != n_coeffs) { throw std::domain_error( "yp0 has wrong size. Expected " + std::to_string(n_coeffs) + " but got " + std::to_string(yp0.size())); @@ -358,23 +406,23 @@ Solution IDAKLUSolverOpenMP::solve( functions->inputs[i] = p_inputs(i, 0); } - // set initial conditions - realtype *yval = N_VGetArrayPointer(yy); - realtype *ypval = N_VGetArrayPointer(yp); - std::vector ySval(number_of_parameters); - std::vector ypSval(number_of_parameters); + // Setup consistent initialization + realtype *y_val = N_VGetArrayPointer(yy); + realtype *yp_val = N_VGetArrayPointer(yp); + vector yS_val(number_of_parameters); + vector ypS_val(number_of_parameters); for (int p = 0 ; p < number_of_parameters; p++) { - ySval[p] = N_VGetArrayPointer(yyS[p]); - ypSval[p] = N_VGetArrayPointer(ypS[p]); + yS_val[p] = N_VGetArrayPointer(yyS[p]); + ypS_val[p] = N_VGetArrayPointer(ypS[p]); for (int i = 0; i < number_of_states; i++) { - ySval[p][i] = y0[i + (p + 1) * number_of_states]; - ypSval[p][i] = yp0[i + (p + 1) * number_of_states]; + yS_val[p][i] = y0[i + (p + 1) * number_of_states]; + ypS_val[p][i] = yp0[i + (p + 1) * number_of_states]; } } for (int i = 0; i < number_of_states; i++) { - yval[i] = y0[i]; - ypval[i] = yp0[i]; + y_val[i] = y0[i]; + yp_val[i] = yp0[i]; } SetSolverOptions(); @@ -384,54 +432,127 @@ Solution IDAKLUSolverOpenMP::solve( CheckErrors(IDASensReInit(ida_mem, IDA_SIMULTANEOUS, yyS, ypS)); } - // correct initial values + // Prepare first time step + i_eval = 1; + realtype t_eval_next = t_eval[i_eval]; + + // Consistent initialization int const init_type = solver_opts.init_all_y_ic ? IDA_Y_INIT : IDA_YA_YDP_INIT; if (solver_opts.calc_ic) { DEBUG("IDACalcIC"); // IDACalcIC will throw a warning if it fails to find initial conditions - IDACalcIC(ida_mem, init_type, t(1)); + IDACalcIC(ida_mem, init_type, t_eval_next); } if (sensitivity) { - CheckErrors(IDAGetSens(ida_mem, &t0, yyS)); + CheckErrors(IDAGetSens(ida_mem, &t_val, yyS)); } - realtype tret; - realtype t_final = t(number_of_timesteps - 1); + // Store Consistent initialization + SetStep(t0, y_val, yS_val, i_save); - // set return vectors - int length_of_return_vector = 0; - int length_of_final_sv_slice = 0; - size_t max_res_size = 0; // maximum result size (for common result buffer) - size_t max_res_dvar_dy = 0, max_res_dvar_dp = 0; - if (functions->var_fcns.size() > 0) { - // return only the requested variables list after computation - for (auto& var_fcn : functions->var_fcns) { - max_res_size = std::max(max_res_size, size_t(var_fcn->out_shape(0))); - length_of_return_vector += var_fcn->nnz_out(); - for (auto& dvar_fcn : functions->dvar_dy_fcns) { - max_res_dvar_dy = std::max(max_res_dvar_dy, size_t(dvar_fcn->out_shape(0))); + // Set the initial stop time + IDASetStopTime(ida_mem, t_eval_next); + + // Solve the system + int retval; + DEBUG("IDASolve"); + while (true) { + // Progress one step + retval = IDASolve(ida_mem, tf, &t_val, yy, yp, IDA_ONE_STEP); + + if (retval < 0) { + // failed + break; + } else if (t_prev == t_val) { + // IDA sometimes returns an identical time point twice + // instead of erroring. Assign a retval and break + retval = IDA_ERR_FAIL; + break; + } + + bool hit_tinterp = save_interp_steps && t_interp_next >= t_prev; + bool hit_teval = retval == IDA_TSTOP_RETURN; + bool hit_final_time = t_val >= tf || (hit_teval && i_eval == number_of_evals); + bool hit_event = retval == IDA_ROOT_RETURN; + bool hit_adaptive = save_adaptive_steps && retval == IDA_SUCCESS; + + if (sensitivity) { + CheckErrors(IDAGetSens(ida_mem, &t_val, yyS)); + } + + if (hit_tinterp) { + // Save the interpolated state at t_prev < t < t_val, for all t in t_interp + SetStepInterp(i_interp, + t_interp_next, + t_interp, + t_val, + t_prev, + t_eval_next, + y_val, + yS_val, + i_save); + } + + if (hit_adaptive || hit_teval || hit_event) { + if (hit_tinterp) { + // Reset the states and sensitivities at t = t_val + CheckErrors(IDAGetDky(ida_mem, t_val, 0, yy)); + if (sensitivity) { + CheckErrors(IDAGetSens(ida_mem, &t_val, yyS)); + } } - for (auto& dvar_fcn : functions->dvar_dp_fcns) { - max_res_dvar_dp = std::max(max_res_dvar_dp, size_t(dvar_fcn->out_shape(0))); + + // Save the current state at t_val + if (hit_adaptive) { + // Dynamically allocate memory for the adaptive step + ExtendAdaptiveArrays(); } - length_of_final_sv_slice = number_of_states; + SetStep(t_val, y_val, yS_val, i_save); } - } else { - // Return full y state-vector - length_of_return_vector = number_of_states; + + if (hit_final_time || hit_event) { + // Successful simulation. Exit the while loop + break; + } else if (hit_teval) { + // Set the next stop time + i_eval += 1; + t_eval_next = t_eval[i_eval]; + CheckErrors(IDASetStopTime(ida_mem, t_eval_next)); + + // Reinitialize the solver to deal with the discontinuity at t = t_val. + // We must reinitialize the algebraic terms, so do not use init_type. + IDACalcIC(ida_mem, IDA_YA_YDP_INIT, t_eval_next); + CheckErrors(IDAReInit(ida_mem, t_val, yy, yp)); + if (sensitivity) { + CheckErrors(IDASensReInit(ida_mem, IDA_SIMULTANEOUS, yyS, ypS)); + } + } + + t_prev = t_val; } - realtype *t_return = new realtype[number_of_timesteps]; - realtype *y_return = new realtype[number_of_timesteps * - length_of_return_vector]; - realtype *yS_return = new realtype[number_of_parameters * - number_of_timesteps * - length_of_return_vector]; + + int const length_of_final_sv_slice = save_outputs_only ? number_of_states : 0; realtype *yterm_return = new realtype[length_of_final_sv_slice]; + if (save_outputs_only) { + // store final state slice if outout variables are specified + yterm_return = y_val; + } + + if (solver_opts.print_stats) { + PrintStats(); + } + + int const number_of_timesteps = i_save; + int count; - res.resize(max_res_size); - res_dvar_dy.resize(max_res_dvar_dy); - res_dvar_dp.resize(max_res_dvar_dp); + // Copy the data to return as numpy arrays + + // Time, t + realtype *t_return = new realtype[number_of_timesteps]; + for (size_t i = 0; i < number_of_timesteps; i++) { + t_return[i] = t[i]; + } py::capsule free_t_when_done( t_return, @@ -440,6 +561,23 @@ Solution IDAKLUSolverOpenMP::solve( delete[] vect; } ); + + np_array t_ret = np_array( + number_of_timesteps, + &t_return[0], + free_t_when_done + ); + + // States, y + realtype *y_return = new realtype[number_of_timesteps * length_of_return_vector]; + count = 0; + for (size_t i = 0; i < number_of_timesteps; i++) { + for (size_t j = 0; j < length_of_return_vector; j++) { + y_return[count] = y[i][j]; + count++; + } + } + py::capsule free_y_when_done( y_return, [](void *f) { @@ -447,6 +585,35 @@ Solution IDAKLUSolverOpenMP::solve( delete[] vect; } ); + + np_array y_ret = np_array( + number_of_timesteps * length_of_return_vector, + &y_return[0], + free_y_when_done + ); + + // Sensitivity states, yS + // Note: Ordering of vector is different if computing outputs vs returning + // the complete state vector + auto const arg_sens0 = (save_outputs_only ? number_of_timesteps : number_of_parameters); + auto const arg_sens1 = (save_outputs_only ? length_of_return_vector : number_of_timesteps); + auto const arg_sens2 = (save_outputs_only ? number_of_parameters : length_of_return_vector); + + realtype *yS_return = new realtype[arg_sens0 * arg_sens1 * arg_sens2]; + count = 0; + for (size_t idx0 = 0; idx0 < arg_sens0; idx0++) { + for (size_t idx1 = 0; idx1 < arg_sens1; idx1++) { + for (size_t idx2 = 0; idx2 < arg_sens2; idx2++) { + auto i = (save_outputs_only ? idx0 : idx1); + auto j = (save_outputs_only ? idx1 : idx2); + auto k = (save_outputs_only ? idx2 : idx0); + + yS_return[count] = yS[i][k][j]; + count++; + } + } + } + py::capsule free_yS_when_done( yS_return, [](void *f) { @@ -454,6 +621,18 @@ Solution IDAKLUSolverOpenMP::solve( delete[] vect; } ); + + np_array yS_ret = np_array( + vector { + arg_sens0, + arg_sens1, + arg_sens2 + }, + &yS_return[0], + free_yS_when_done + ); + + // Final state slice, yterm py::capsule free_yterm_when_done( yterm_return, [](void *f) { @@ -462,167 +641,188 @@ Solution IDAKLUSolverOpenMP::solve( } ); - // Initial state (t_i=0) - int t_i = 0; - size_t ySk = 0; - t_return[t_i] = t(t_i); - if (functions->var_fcns.size() > 0) { - // Evaluate functions for each requested variable and store - CalcVars(y_return, length_of_return_vector, t_i, - &tret, yval, ySval, yS_return, &ySk); + np_array y_term = np_array( + length_of_final_sv_slice, + &yterm_return[0], + free_yterm_when_done + ); + + // Store the solution + Solution sol(retval, t_ret, y_ret, yS_ret, y_term); + + return sol; +} + +template +void IDAKLUSolverOpenMP::ExtendAdaptiveArrays() { + DEBUG("IDAKLUSolver::ExtendAdaptiveArrays"); + // Time + t.emplace_back(0.0); + + // States + y.emplace_back(length_of_return_vector, 0.0); + + // Sensitivity + if (sensitivity) { + yS.emplace_back(number_of_parameters, vector(length_of_return_vector, 0.0)); + } +} + +template +void IDAKLUSolverOpenMP::SetStep( + realtype &tval, + realtype *y_val, + vector const &yS_val, + int &i_save +) { + // Set adaptive step results for y and yS + DEBUG("IDAKLUSolver::SetStep"); + + // Time + t[i_save] = tval; + + if (save_outputs_only) { + SetStepOutput(tval, y_val, yS_val, i_save); } else { - // Retain complete copy of the state vector - for (int j = 0; j < number_of_states; j++) { - y_return[j] = yval[j]; - } - for (int j = 0; j < number_of_parameters; j++) { - const int base_index = j * number_of_timesteps * number_of_states; - for (int k = 0; k < number_of_states; k++) { - yS_return[base_index + k] = ySval[j][k]; - } - } + SetStepFull(tval, y_val, yS_val, i_save); } - // Subsequent states (t_i>0) - int retval; - t_i = 1; - while (true) { - realtype t_next = t(t_i); - IDASetStopTime(ida_mem, t_next); - DEBUG("IDASolve"); - retval = IDASolve(ida_mem, t_final, &tret, yy, yp, IDA_NORMAL); - - if (!(retval == IDA_TSTOP_RETURN || - retval == IDA_SUCCESS || - retval == IDA_ROOT_RETURN)) { - // failed - break; - } + i_save++; +} + +template +void IDAKLUSolverOpenMP::SetStepInterp( + int &i_interp, + realtype &t_interp_next, + vector const &t_interp, + realtype &t_val, + realtype &t_prev, + realtype const &t_eval_next, + realtype *y_val, + vector const &yS_val, + int &i_save + ) { + // Save the state at the requested time + DEBUG("IDAKLUSolver::SetStepInterp"); + + while (i_interp <= (t_interp.size()-1) && t_interp_next <= t_val) { + CheckErrors(IDAGetDky(ida_mem, t_interp_next, 0, yy)); if (sensitivity) { - CheckErrors(IDAGetSens(ida_mem, &tret, yyS)); + CheckErrors(IDAGetSensDky(ida_mem, t_interp_next, 0, yyS)); } - // Evaluate and store results for the time step - t_return[t_i] = tret; - if (functions->var_fcns.size() > 0) { - // Evaluate functions for each requested variable and store - // NOTE: Indexing of yS_return is (time:var:param) - CalcVars(y_return, length_of_return_vector, t_i, - &tret, yval, ySval, yS_return, &ySk); - } else { - // Retain complete copy of the state vector - for (int j = 0; j < number_of_states; j++) { - y_return[t_i * number_of_states + j] = yval[j]; - } - for (int j = 0; j < number_of_parameters; j++) { - const int base_index = - j * number_of_timesteps * number_of_states + - t_i * number_of_states; - for (int k = 0; k < number_of_states; k++) { - // NOTE: Indexing of yS_return is (time:param:yvec) - yS_return[base_index + k] = ySval[j][k]; - } - } - } - t_i += 1; + // Memory is already allocated for the interpolated values + SetStep(t_interp_next, y_val, yS_val, i_save); - if (retval == IDA_SUCCESS || retval == IDA_ROOT_RETURN) { - if (functions->var_fcns.size() > 0) { - // store final state slice if outout variables are specified - yterm_return = yval; - } + i_interp++; + if (i_interp == (t_interp.size())) { + // Reached the final t_interp value break; } + t_interp_next = t_interp[i_interp]; } +} - np_array t_ret = np_array( - t_i, - &t_return[0], - free_t_when_done - ); - np_array y_ret = np_array( - t_i * length_of_return_vector, - &y_return[0], - free_y_when_done - ); - // Note: Ordering of vector is different if computing variables vs returning - // the complete state vector - np_array yS_ret; - if (functions->var_fcns.size() > 0) { - yS_ret = np_array( - std::vector { - number_of_timesteps, - length_of_return_vector, - number_of_parameters - }, - &yS_return[0], - free_yS_when_done - ); - } else { - yS_ret = np_array( - std::vector { - number_of_parameters, - number_of_timesteps, - length_of_return_vector - }, - &yS_return[0], - free_yS_when_done - ); +template +void IDAKLUSolverOpenMP::SetStepFull( + realtype &tval, + realtype *y_val, + vector const &yS_val, + int &i_save +) { + // Set adaptive step results for y and yS + DEBUG("IDAKLUSolver::SetStepFull"); + + // States + auto &y_back = y[i_save]; + for (size_t j = 0; j < number_of_states; ++j) { + y_back[j] = y_val[j]; } - np_array y_term = np_array( - length_of_final_sv_slice, - &yterm_return[0], - free_yterm_when_done - ); - Solution sol(retval, t_ret, y_ret, yS_ret, y_term); + // Sensitivity + if (sensitivity) { + SetStepFullSensitivities(tval, y_val, yS_val, i_save); + } +} - if (solver_opts.print_stats) { - long nsteps, nrevals, nlinsetups, netfails; - int klast, kcur; - realtype hinused, hlast, hcur, tcur; - - CheckErrors(IDAGetIntegratorStats( - ida_mem, - &nsteps, - &nrevals, - &nlinsetups, - &netfails, - &klast, - &kcur, - &hinused, - &hlast, - &hcur, - &tcur - )); - - long nniters, nncfails; - CheckErrors(IDAGetNonlinSolvStats(ida_mem, &nniters, &nncfails)); - - long int ngevalsBBDP = 0; - if (setup_opts.using_iterative_solver) { - CheckErrors(IDABBDPrecGetNumGfnEvals(ida_mem, &ngevalsBBDP)); +template +void IDAKLUSolverOpenMP::SetStepFullSensitivities( + realtype &tval, + realtype *y_val, + vector const &yS_val, + int &i_save +) { + DEBUG("IDAKLUSolver::SetStepFullSensitivities"); + + // Calculate sensitivities for the full yS array + for (size_t j = 0; j < number_of_parameters; ++j) { + auto &yS_back_j = yS[i_save][j]; + auto &ySval_j = yS_val[j]; + for (size_t k = 0; k < number_of_states; ++k) { + yS_back_j[k] = ySval_j[k]; } + } +} - py::print("Solver Stats:"); - py::print("\tNumber of steps =", nsteps); - py::print("\tNumber of calls to residual function =", nrevals); - py::print("\tNumber of calls to residual function in preconditioner =", - ngevalsBBDP); - py::print("\tNumber of linear solver setup calls =", nlinsetups); - py::print("\tNumber of error test failures =", netfails); - py::print("\tMethod order used on last step =", klast); - py::print("\tMethod order used on next step =", kcur); - py::print("\tInitial step size =", hinused); - py::print("\tStep size on last step =", hlast); - py::print("\tStep size on next step =", hcur); - py::print("\tCurrent internal time reached =", tcur); - py::print("\tNumber of nonlinear iterations performed =", nniters); - py::print("\tNumber of nonlinear convergence failures =", nncfails); +template +void IDAKLUSolverOpenMP::SetStepOutput( + realtype &tval, + realtype *y_val, + const vector& yS_val, + int &i_save +) { + DEBUG("IDAKLUSolver::SetStepOutput"); + // Evaluate functions for each requested variable and store + + size_t j = 0; + for (auto& var_fcn : functions->var_fcns) { + (*var_fcn)({&tval, y_val, functions->inputs.data()}, {&res[0]}); + // store in return vector + for (size_t jj=0; jjnnz_out(); jj++) { + y[i_save][j++] = res[jj]; + } + } + // calculate sensitivities + if (sensitivity) { + SetStepOutputSensitivities(tval, y_val, yS_val, i_save); } +} - return sol; +template +void IDAKLUSolverOpenMP::SetStepOutputSensitivities( + realtype &tval, + realtype *y_val, + const vector& yS_val, + int &i_save + ) { + DEBUG("IDAKLUSolver::SetStepOutputSensitivities"); + // Calculate sensitivities + vector dens_dvar_dp = vector(number_of_parameters, 0); + for (size_t dvar_k=0; dvar_kdvar_dy_fcns.size(); dvar_k++) { + // Isolate functions + Expression* dvar_dy = functions->dvar_dy_fcns[dvar_k]; + Expression* dvar_dp = functions->dvar_dp_fcns[dvar_k]; + // Calculate dvar/dy + (*dvar_dy)({&tval, y_val, functions->inputs.data()}, {&res_dvar_dy[0]}); + // Calculate dvar/dp and convert to dense array for indexing + (*dvar_dp)({&tval, y_val, functions->inputs.data()}, {&res_dvar_dp[0]}); + for (int k=0; knnz_out(); k++) { + dens_dvar_dp[dvar_dp->get_row()[k]] = res_dvar_dp[k]; + } + // Calculate sensitivities + for (int paramk=0; paramknnz_out(); spk++) { + yS_back_paramk[dvar_k] += res_dvar_dy[spk] * yS_val[paramk][dvar_dy->get_col()[spk]]; + } + } + } } template @@ -633,3 +833,48 @@ void IDAKLUSolverOpenMP::CheckErrors(int const & flag) { throw py::error_already_set(); } } + +template +void IDAKLUSolverOpenMP::PrintStats() { + long nsteps, nrevals, nlinsetups, netfails; + int klast, kcur; + realtype hinused, hlast, hcur, tcur; + + CheckErrors(IDAGetIntegratorStats( + ida_mem, + &nsteps, + &nrevals, + &nlinsetups, + &netfails, + &klast, + &kcur, + &hinused, + &hlast, + &hcur, + &tcur + )); + + long nniters, nncfails; + CheckErrors(IDAGetNonlinSolvStats(ida_mem, &nniters, &nncfails)); + + long int ngevalsBBDP = 0; + if (setup_opts.using_iterative_solver) { + CheckErrors(IDABBDPrecGetNumGfnEvals(ida_mem, &ngevalsBBDP)); + } + + py::print("Solver Stats:"); + py::print("\tNumber of steps =", nsteps); + py::print("\tNumber of calls to residual function =", nrevals); + py::print("\tNumber of calls to residual function in preconditioner =", + ngevalsBBDP); + py::print("\tNumber of linear solver setup calls =", nlinsetups); + py::print("\tNumber of error test failures =", netfails); + py::print("\tMethod order used on last step =", klast); + py::print("\tMethod order used on next step =", kcur); + py::print("\tInitial step size =", hinused); + py::print("\tStep size on last step =", hlast); + py::print("\tStep size on next step =", hcur); + py::print("\tCurrent internal time reached =", tcur); + py::print("\tNumber of nonlinear iterations performed =", nniters); + py::print("\tNumber of nonlinear convergence failures =", nncfails); +} diff --git a/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp b/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp index ad0ea06762..72d48fa644 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp @@ -12,7 +12,7 @@ class Solution /** * @brief Constructor */ - Solution(int retval, np_array t_np, np_array y_np, np_array yS_np, np_array y_term_np) + Solution(int &retval, np_array &t_np, np_array &y_np, np_array &yS_np, np_array &y_term_np) : flag(retval), t(t_np), y(y_np), yS(yS_np), y_term(y_term_np) { } diff --git a/src/pybamm/solvers/c_solvers/idaklu/common.cpp b/src/pybamm/solvers/c_solvers/idaklu/common.cpp new file mode 100644 index 0000000000..bf38acc56a --- /dev/null +++ b/src/pybamm/solvers/c_solvers/idaklu/common.cpp @@ -0,0 +1,31 @@ +#include "common.hpp" + +std::vector numpy2realtype(const np_array& input_np) { + std::vector output(input_np.request().size); + + auto const inputData = input_np.unchecked<1>(); + for (int i = 0; i < output.size(); i++) { + output[i] = inputData[i]; + } + + return output; +} + +std::vector setDiff(const std::vector& A, const std::vector& B) { + std::vector result; + if (!(A.empty())) { + std::set_difference(A.begin(), A.end(), B.begin(), B.end(), std::back_inserter(result)); + } + return result; +} + +std::vector makeSortedUnique(const std::vector& input) { + std::unordered_set uniqueSet(input.begin(), input.end()); // Remove duplicates + std::vector uniqueVector(uniqueSet.begin(), uniqueSet.end()); // Convert to vector + std::sort(uniqueVector.begin(), uniqueVector.end()); // Sort the vector + return uniqueVector; +} + +std::vector makeSortedUnique(const np_array& input_np) { + return makeSortedUnique(numpy2realtype(input_np)); +} diff --git a/src/pybamm/solvers/c_solvers/idaklu/common.hpp b/src/pybamm/solvers/c_solvers/idaklu/common.hpp index 0ef7ee60a0..3289326541 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/common.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/common.hpp @@ -74,6 +74,24 @@ void csc_csr(const realtype f[], const T1 c[], const T1 r[], realtype nf[], T2 n } } + +/** + * @brief Utility function to convert numpy array to std::vector + */ +std::vector numpy2realtype(const np_array& input_np); + +/** + * @brief Utility function to compute the set difference of two vectors + */ +std::vector setDiff(const std::vector& A, const std::vector& B); + +/** + * @brief Utility function to make a sorted and unique vector + */ +std::vector makeSortedUnique(const std::vector& input); + +std::vector makeSortedUnique(const np_array& input_np); + #ifdef NDEBUG #define DEBUG_VECTOR(vector) #define DEBUG_VECTORn(vector, N) diff --git a/src/pybamm/solvers/casadi_algebraic_solver.py b/src/pybamm/solvers/casadi_algebraic_solver.py index 635adb5d34..cf44912952 100644 --- a/src/pybamm/solvers/casadi_algebraic_solver.py +++ b/src/pybamm/solvers/casadi_algebraic_solver.py @@ -25,7 +25,7 @@ def __init__(self, tol=1e-6, extra_options=None): super().__init__() self.tol = tol self.name = "CasADi algebraic solver" - self.algebraic_solver = True + self._algebraic_solver = True self.extra_options = extra_options or {} pybamm.citations.register("Andersson2019") @@ -37,7 +37,7 @@ def tol(self): def tol(self, value): self._tol = value - def _integrate(self, model, t_eval, inputs_dict=None): + def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): """ Calculate the solution of the algebraic equations through root-finding diff --git a/src/pybamm/solvers/casadi_solver.py b/src/pybamm/solvers/casadi_solver.py index f67a2decb4..b4ac9d1561 100644 --- a/src/pybamm/solvers/casadi_solver.py +++ b/src/pybamm/solvers/casadi_solver.py @@ -133,7 +133,7 @@ def __init__( pybamm.citations.register("Andersson2019") - def _integrate(self, model, t_eval, inputs_dict=None): + def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): """ Solve a DAE model defined by residuals with initial conditions y0. diff --git a/src/pybamm/solvers/dummy_solver.py b/src/pybamm/solvers/dummy_solver.py index c98667293d..45a6a332b1 100644 --- a/src/pybamm/solvers/dummy_solver.py +++ b/src/pybamm/solvers/dummy_solver.py @@ -12,7 +12,7 @@ def __init__(self): super().__init__() self.name = "Dummy solver" - def _integrate(self, model, t_eval, inputs_dict=None): + def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): """ Solve an empty model. diff --git a/src/pybamm/solvers/idaklu_jax.py b/src/pybamm/solvers/idaklu_jax.py index ba1c80c1c4..5a73d42c6e 100644 --- a/src/pybamm/solvers/idaklu_jax.py +++ b/src/pybamm/solvers/idaklu_jax.py @@ -55,6 +55,7 @@ def __init__( t_eval, output_variables=None, calculate_sensitivities=True, + t_interp=None, ): if not pybamm.have_jax(): raise ModuleNotFoundError( @@ -77,6 +78,7 @@ def __init__( t_eval, output_variables=output_variables, calculate_sensitivities=calculate_sensitivities, + t_interp=t_interp, ) def get_jaxpr(self): @@ -355,16 +357,15 @@ def _jaxify_solve(self, t, invar, *inputs_values): fo reuse. """ # Reconstruct dictionary of inputs - if self.jax_inputs is None: - d = self._hashabledict() - else: + d = self._hashabledict() + if self.jax_inputs is not None: # Use hashable dictionaries for caching the solve - d = self._hashabledict() for key, value in zip(self.jax_inputs.keys(), inputs_values): d[key] = value # Solver logger.debug("_jaxify_solve:") logger.debug(f" t_eval: {self.jax_t_eval}") + logger.debug(f" t_interp: {self.jax_t_interp}") logger.debug(f" t: {t}") logger.debug(f" invar: {invar}") logger.debug(f" inputs: {dict(d)}") @@ -375,6 +376,7 @@ def _jaxify_solve(self, t, invar, *inputs_values): tuple(self.jax_t_eval), inputs=self._hashabledict(d), calculate_sensitivities=self.jax_calculate_sensitivities, + t_interp=tuple(self.jax_t_interp), ) if invar is not None: if isinstance(invar, numbers.Number): @@ -549,6 +551,7 @@ def jaxify( *, output_variables=None, calculate_sensitivities=True, + t_interp=None, ): """JAXify the model and solver @@ -560,12 +563,14 @@ def jaxify( model : :class:`pybamm.BaseModel` The model to be solved t_eval : numeric type, optional - The times at which to compute the solution. If None, the times in the model - are used. + The times at which to stop the integration due to a discontinuity in time. output_variables : list of str, optional The variables to be returned. If None, the variables in the model are used. calculate_sensitivities : bool, optional Whether to calculate sensitivities. Default is True. + t_interp : None, list or ndarray, optional + The times (in seconds) at which to interpolate the solution. Defaults to None. + Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`). """ if self.jaxpr is not None: warnings.warn( @@ -579,6 +584,7 @@ def jaxify( t_eval, output_variables=output_variables, calculate_sensitivities=calculate_sensitivities, + t_interp=t_interp, ) return self.jaxpr @@ -589,11 +595,15 @@ def _jaxify( *, output_variables=None, calculate_sensitivities=True, + t_interp=None, ): """JAXify the model and solver""" self.jax_model = model self.jax_t_eval = t_eval + if t_interp is None: + t_interp = np.empty(0) + self.jax_t_interp = t_interp self.jax_output_variables = ( output_variables if output_variables else self.solver.output_variables ) diff --git a/src/pybamm/solvers/idaklu_solver.py b/src/pybamm/solvers/idaklu_solver.py index 34f7c1abaa..85731f4e12 100644 --- a/src/pybamm/solvers/idaklu_solver.py +++ b/src/pybamm/solvers/idaklu_solver.py @@ -13,6 +13,7 @@ import importlib import warnings + if pybamm.have_jax(): import jax from jax import numpy as jnp @@ -232,6 +233,7 @@ def __init__( output_variables, ) self.name = "IDA KLU solver" + self._supports_interp = True pybamm.citations.register("Hindmarsh2000") pybamm.citations.register("Hindmarsh2005") @@ -828,7 +830,7 @@ def _check_mlir_conversion(self, name, mlir: str): def _demote_64_to_32(self, x: pybamm.EvaluatorJax): return pybamm.EvaluatorJax._demote_64_to_32(x) - def _integrate(self, model, t_eval, inputs_dict=None): + def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): """ Solve a DAE model defined by residuals with initial conditions y0. @@ -837,9 +839,12 @@ def _integrate(self, model, t_eval, inputs_dict=None): model : :class:`pybamm.BaseModel` The model whose solution to calculate. t_eval : numeric type - The times at which to compute the solution + The times at which to stop the integration due to a discontinuity in time. inputs_dict : dict, optional Any input parameters to pass to the model when solving. + t_interp : None, list or ndarray, optional + The times (in seconds) at which to interpolate the solution. Defaults to `None`, + which returns the adaptive time-stepping times. """ inputs_dict = inputs_dict or {} # stack inputs @@ -863,6 +868,7 @@ def _integrate(self, model, t_eval, inputs_dict=None): ): sol = self._setup["solver"].solve( t_eval, + t_interp, y0full, ydot0full, inputs, @@ -919,13 +925,13 @@ def _integrate(self, model, t_eval, inputs_dict=None): yS_out = False # 0 = solved for all t_eval - if sol.flag == 0: - termination = "final time" # 2 = found root(s) - elif sol.flag == 2: + if sol.flag == 2: termination = "event" + elif sol.flag >= 0: + termination = "final time" else: - raise pybamm.SolverError("idaklu solver failed") + raise pybamm.SolverError(f"FAILURE {self._solver_flag(sol.flag)}") newsol = pybamm.Solution( sol.t, @@ -939,6 +945,7 @@ def _integrate(self, model, t_eval, inputs_dict=None): ) newsol.integration_time = integration_time if not self.output_variables: + # print((newsol.y).shape) return newsol # Populate variables and sensititivies dictionaries directly @@ -1138,6 +1145,7 @@ def jaxify( *, output_variables=None, calculate_sensitivities=True, + t_interp=None, ): """JAXify the solver object @@ -1149,12 +1157,14 @@ def jaxify( model : :class:`pybamm.BaseModel` The model to be solved t_eval : numeric type, optional - The times at which to compute the solution. If None, the times in the model - are used. + The times at which to stop the integration due to a discontinuity in time. output_variables : list of str, optional The variables to be returned. If None, all variables in the model are used. calculate_sensitivities : bool, optional Whether to calculate sensitivities. Default is True. + t_interp : None, list or ndarray, optional + The times (in seconds) at which to interpolate the solution. Defaults to `None`, + which returns the adaptive time-stepping times. """ obj = pybamm.IDAKLUJax( self, # IDAKLU solver instance @@ -1162,5 +1172,43 @@ def jaxify( t_eval, output_variables=output_variables, calculate_sensitivities=calculate_sensitivities, + t_interp=t_interp, ) return obj + + @staticmethod + def _solver_flag(flag): + flags = { + 99: "IDA_WARNING: IDASolve succeeded but an unusual situation occurred.", + 2: "IDA_ROOT_RETURN: IDASolve succeeded and found one or more roots.", + 1: "IDA_TSTOP_RETURN: IDASolve succeeded by reaching the specified stopping point.", + 0: "IDA_SUCCESS: Successful function return.", + -1: "IDA_TOO_MUCH_WORK: The solver took mxstep internal steps but could not reach tout.", + -2: "IDA_TOO_MUCH_ACC: The solver could not satisfy the accuracy demanded by the user for some internal step.", + -3: "IDA_ERR_FAIL: Error test failures occurred too many times during one internal time step or minimum step size was reached.", + -4: "IDA_CONV_FAIL: Convergence test failures occurred too many times during one internal time step or minimum step size was reached.", + -5: "IDA_LINIT_FAIL: The linear solver's initialization function failed.", + -6: "IDA_LSETUP_FAIL: The linear solver's setup function failed in an unrecoverable manner.", + -7: "IDA_LSOLVE_FAIL: The linear solver's solve function failed in an unrecoverable manner.", + -8: "IDA_RES_FAIL: The user-provided residual function failed in an unrecoverable manner.", + -9: "IDA_REP_RES_FAIL: The user-provided residual function repeatedly returned a recoverable error flag, but the solver was unable to recover.", + -10: "IDA_RTFUNC_FAIL: The rootfinding function failed in an unrecoverable manner.", + -11: "IDA_CONSTR_FAIL: The inequality constraints were violated and the solver was unable to recover.", + -12: "IDA_FIRST_RES_FAIL: The user-provided residual function failed recoverably on the first call.", + -13: "IDA_LINESEARCH_FAIL: The line search failed.", + -14: "IDA_NO_RECOVERY: The residual function, linear solver setup function, or linear solver solve function had a recoverable failure, but IDACalcIC could not recover.", + -15: "IDA_NLS_INIT_FAIL: The nonlinear solver's init routine failed.", + -16: "IDA_NLS_SETUP_FAIL: The nonlinear solver's setup routine failed.", + -20: "IDA_MEM_NULL: The ida mem argument was NULL.", + -21: "IDA_MEM_FAIL: A memory allocation failed.", + -22: "IDA_ILL_INPUT: One of the function inputs is illegal.", + -23: "IDA_NO_MALLOC: The ida memory was not allocated by a call to IDAInit.", + -24: "IDA_BAD_EWT: Zero value of some error weight component.", + -25: "IDA_BAD_K: The k-th derivative is not available.", + -26: "IDA_BAD_T: The time t is outside the last step taken.", + -27: "IDA_BAD_DKY: The vector argument where derivative should be stored is NULL.", + } + + flag_unknown = "Unknown IDA flag." + + return flags.get(flag, flag_unknown) diff --git a/src/pybamm/solvers/jax_solver.py b/src/pybamm/solvers/jax_solver.py index fbe047b3cc..26a069e0fe 100644 --- a/src/pybamm/solvers/jax_solver.py +++ b/src/pybamm/solvers/jax_solver.py @@ -72,9 +72,7 @@ def __init__( method_options = ["RK45", "BDF"] if method not in method_options: raise ValueError(f"method must be one of {method_options}") - self.ode_solver = False - if method == "RK45": - self.ode_solver = True + self._ode_solver = method == "RK45" self.extra_options = extra_options or {} self.name = f"JAX solver ({method})" self._cached_solves = dict() @@ -187,7 +185,7 @@ def solve_model_bdf(inputs): else: return jax.jit(solve_model_bdf) - def _integrate(self, model, t_eval, inputs=None): + def _integrate(self, model, t_eval, inputs=None, t_interp=None): """ Solve a model defined by dydt with initial conditions y0. diff --git a/src/pybamm/solvers/processed_variable.py b/src/pybamm/solvers/processed_variable.py index 38314ca5c2..8c1190c2f4 100644 --- a/src/pybamm/solvers/processed_variable.py +++ b/src/pybamm/solvers/processed_variable.py @@ -75,43 +75,42 @@ def __init__( self.base_eval_shape = self.base_variables[0].shape self.base_eval_size = self.base_variables[0].size + # xr_data_array is initialized + self._xr_data_array = None + # handle 2D (in space) finite element variables differently if ( self.mesh and "current collector" in self.domain and isinstance(self.mesh, pybamm.ScikitSubMesh2D) ): - self.initialise_2D_scikit_fem() + return self.initialise_2D_scikit_fem() # check variable shape - else: - if len(self.base_eval_shape) == 0 or self.base_eval_shape[0] == 1: - self.initialise_0D() - else: - n = self.mesh.npts - base_shape = self.base_eval_shape[0] - # Try some shapes that could make the variable a 1D variable - if base_shape in [n, n + 1]: - self.initialise_1D() - else: - # Try some shapes that could make the variable a 2D variable - first_dim_nodes = self.mesh.nodes - first_dim_edges = self.mesh.edges - second_dim_pts = self.base_variables[0].secondary_mesh.nodes - if self.base_eval_size // len(second_dim_pts) in [ - len(first_dim_nodes), - len(first_dim_edges), - ]: - self.initialise_2D() - else: - # Raise error for 3D variable - raise NotImplementedError( - f"Shape not recognized for {base_variables[0]}" - + "(note processing of 3D variables is not yet implemented)" - ) - - # xr_data_array is initialized when needed - self._xr_data_array = None + if len(self.base_eval_shape) == 0 or self.base_eval_shape[0] == 1: + return self.initialise_0D() + + n = self.mesh.npts + base_shape = self.base_eval_shape[0] + # Try some shapes that could make the variable a 1D variable + if base_shape in [n, n + 1]: + return self.initialise_1D() + + # Try some shapes that could make the variable a 2D variable + first_dim_nodes = self.mesh.nodes + first_dim_edges = self.mesh.edges + second_dim_pts = self.base_variables[0].secondary_mesh.nodes + if self.base_eval_size // len(second_dim_pts) in [ + len(first_dim_nodes), + len(first_dim_edges), + ]: + return self.initialise_2D() + + # Raise error for 3D variable + raise NotImplementedError( + f"Shape not recognized for {base_variables[0]}" + + "(note processing of 3D variables is not yet implemented)" + ) def initialise_0D(self): # initialise empty array of the correct size diff --git a/src/pybamm/solvers/scipy_solver.py b/src/pybamm/solvers/scipy_solver.py index 9a66f5bc01..226b096887 100644 --- a/src/pybamm/solvers/scipy_solver.py +++ b/src/pybamm/solvers/scipy_solver.py @@ -42,12 +42,12 @@ def __init__( atol=atol, extrap_tol=extrap_tol, ) - self.ode_solver = True + self._ode_solver = True self.extra_options = extra_options or {} self.name = f"Scipy solver ({method})" pybamm.citations.register("Virtanen2020") - def _integrate(self, model, t_eval, inputs_dict=None): + def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): """ Solve a model defined by dydt with initial conditions y0. diff --git a/tests/integration/test_models/standard_model_tests.py b/tests/integration/test_models/standard_model_tests.py index 3f9cb56354..a962894c44 100644 --- a/tests/integration/test_models/standard_model_tests.py +++ b/tests/integration/test_models/standard_model_tests.py @@ -104,7 +104,8 @@ def test_sensitivities(self, param_name, param_value, output_name="Voltage [V]") self.parameter_values["Current function [A]"] / self.parameter_values["Nominal cell capacity [A.h]"] ) - t_eval = np.linspace(0, 3600 / Crate, 100) + t_interp = np.linspace(0, 3600 / Crate, 100) + t_eval = np.array([t_interp[0], t_interp[-1]]) # make param_name an input self.parameter_values.update({param_name: "[input]"}) @@ -118,7 +119,11 @@ def test_sensitivities(self, param_name, param_value, output_name="Voltage [V]") self.solver.atol = 1e-8 self.solution = self.solver.solve( - self.model, t_eval, inputs=inputs, calculate_sensitivities=True + self.model, + t_eval, + inputs=inputs, + calculate_sensitivities=True, + t_interp=t_interp, ) output_sens = self.solution[output_name].sensitivities[param_name] @@ -126,10 +131,14 @@ def test_sensitivities(self, param_name, param_value, output_name="Voltage [V]") h = 1e-2 * param_value inputs_plus = {param_name: (param_value + 0.5 * h)} inputs_neg = {param_name: (param_value - 0.5 * h)} - sol_plus = self.solver.solve(self.model, t_eval, inputs=inputs_plus) - output_plus = sol_plus[output_name](t=t_eval) - sol_neg = self.solver.solve(self.model, t_eval, inputs=inputs_neg) - output_neg = sol_neg[output_name](t=t_eval) + sol_plus = self.solver.solve( + self.model, t_eval, inputs=inputs_plus, t_interp=t_interp + ) + output_plus = sol_plus[output_name].data + sol_neg = self.solver.solve( + self.model, t_eval, inputs=inputs_neg, t_interp=t_interp + ) + output_neg = sol_neg[output_name].data fd = (np.array(output_plus) - np.array(output_neg)) / h fd = fd.transpose().reshape(-1, 1) np.testing.assert_allclose( diff --git a/tests/integration/test_solvers/test_idaklu.py b/tests/integration/test_solvers/test_idaklu.py index 31083319cf..abc7741c0c 100644 --- a/tests/integration/test_solvers/test_idaklu.py +++ b/tests/integration/test_solvers/test_idaklu.py @@ -31,13 +31,15 @@ def test_on_spme_sensitivities(self): mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts) disc = pybamm.Discretisation(mesh, model.default_spatial_methods) disc.process_model(model) - t_eval = np.linspace(0, 3500, 100) + t_interp = np.linspace(0, 3500, 100) + t_eval = np.array([t_interp[0], t_interp[-1]]) solver = pybamm.IDAKLUSolver(rtol=1e-10, atol=1e-10) solution = solver.solve( model, t_eval, inputs=inputs, calculate_sensitivities=True, + t_interp=t_interp, ) np.testing.assert_array_less(1, solution.t.size) @@ -47,10 +49,10 @@ def test_on_spme_sensitivities(self): # evaluate the sensitivities using finite difference h = 1e-5 sol_plus = solver.solve( - model, t_eval, inputs={param_name: param_value + 0.5 * h} + model, t_eval, inputs={param_name: param_value + 0.5 * h}, t_interp=t_interp ) sol_neg = solver.solve( - model, t_eval, inputs={param_name: param_value - 0.5 * h} + model, t_eval, inputs={param_name: param_value - 0.5 * h}, t_interp=t_interp ) dyda_fd = (sol_plus.y - sol_neg.y) / h dyda_fd = dyda_fd.transpose().reshape(-1, 1) @@ -87,3 +89,66 @@ def test_changing_grid(self): # solve solver.solve(model_disc, t_eval) + + def test_interpolation(self): + model = pybamm.BaseModel() + u1 = pybamm.Variable("u1") + u2 = pybamm.Variable("u2") + u3 = pybamm.Variable("u3") + v = pybamm.Variable("v") + a = pybamm.InputParameter("a") + b = pybamm.InputParameter("b", expected_size=2) + model.rhs = {u1: a * v, u2: pybamm.Index(b, 0), u3: pybamm.Index(b, 1)} + model.algebraic = {v: 1 - v} + model.initial_conditions = {u1: 0, u2: 0, u3: 0, v: 1} + + disc = pybamm.Discretisation() + model_disc = disc.process_model(model, inplace=False) + + a_value = 0.1 + b_value = np.array([[0.2], [0.3]]) + inputs = {"a": a_value, "b": b_value} + + # Calculate time for each solver and each number of grid points + t0 = 0 + tf = 3600 + t_eval_dense = np.linspace(t0, tf, 1000) + t_eval_sparse = [t0, tf] + + t_interp_dense = np.linspace(t0, tf, 800) + t_interp_sparse = [t0, tf] + solver = pybamm.IDAKLUSolver() + + # solve + # 1. dense t_eval + adaptive time stepping + sol1 = solver.solve(model_disc, t_eval_dense, inputs=inputs) + np.testing.assert_array_less(len(t_eval_dense), len(sol1.t)) + + # 2. sparse t_eval + adaptive time stepping + sol2 = solver.solve(model_disc, t_eval_sparse, inputs=inputs) + np.testing.assert_array_less(len(sol2.t), len(sol1.t)) + + # 3. dense t_eval + dense t_interp + sol3 = solver.solve( + model_disc, t_eval_dense, t_interp=t_interp_dense, inputs=inputs + ) + t_combined = np.concatenate((sol3.t, t_interp_dense)) + t_combined = np.unique(t_combined) + t_combined.sort() + np.testing.assert_array_almost_equal(sol3.t, t_combined) + + # 4. sparse t_eval + sparse t_interp + sol4 = solver.solve( + model_disc, t_eval_sparse, t_interp=t_interp_sparse, inputs=inputs + ) + np.testing.assert_array_almost_equal(sol4.t, np.array([t0, tf])) + + sols = [sol1, sol2, sol3, sol4] + for sol in sols: + # test that y[0] = to true solution + true_solution = a_value * sol.t + np.testing.assert_array_almost_equal(sol.y[0], true_solution) + + # test that y[1:3] = to true solution + true_solution = b_value * sol.t + np.testing.assert_array_almost_equal(sol.y[1:3], true_solution) diff --git a/tests/unit/test_experiments/test_experiment.py b/tests/unit/test_experiments/test_experiment.py index e445b05e31..2eda78e1e0 100644 --- a/tests/unit/test_experiments/test_experiment.py +++ b/tests/unit/test_experiments/test_experiment.py @@ -21,7 +21,7 @@ def test_cycle_unpacking(self): "value": 0.05, "type": "CRate", "duration": 1800.0, - "period": 60.0, + "period": None, "temperature": None, "description": "Discharge at C/20 for 0.5 hours", "termination": [], @@ -32,7 +32,7 @@ def test_cycle_unpacking(self): "value": -0.2, "type": "CRate", "duration": 2700.0, - "period": 60.0, + "period": None, "temperature": None, "description": "Charge at C/5 for 45 minutes", "termination": [], @@ -43,7 +43,7 @@ def test_cycle_unpacking(self): "value": 0.05, "type": "CRate", "duration": 1800.0, - "period": 60.0, + "period": None, "temperature": None, "description": "Discharge at C/20 for 0.5 hours", "termination": [], @@ -54,7 +54,7 @@ def test_cycle_unpacking(self): "value": -0.2, "type": "CRate", "duration": 2700.0, - "period": 60.0, + "period": None, "temperature": None, "description": "Charge at C/5 for 45 minutes", "termination": [], diff --git a/tests/unit/test_experiments/test_simulation_with_experiment.py b/tests/unit/test_experiments/test_simulation_with_experiment.py index defca33b00..c4f55889a1 100644 --- a/tests/unit/test_experiments/test_simulation_with_experiment.py +++ b/tests/unit/test_experiments/test_simulation_with_experiment.py @@ -191,12 +191,12 @@ def test_run_experiment_cccv_solvers(self): np.testing.assert_array_almost_equal( solutions[0]["Voltage [V]"].data, - solutions[1]["Voltage [V]"].data, + solutions[1]["Voltage [V]"](solutions[0].t), decimal=1, ) np.testing.assert_array_almost_equal( solutions[0]["Current [A]"].data, - solutions[1]["Current [A]"].data, + solutions[1]["Current [A]"](solutions[0].t), decimal=0, ) self.assertEqual(solutions[1].termination, "final time") diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index 884c85f87f..a35b864a64 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -114,7 +114,7 @@ def test_block_symbolic_inputs(self): ): solver.solve(model, np.array([1, 2, 3])) - def test_ode_solver_fail_with_dae(self): + def testode_solver_fail_with_dae(self): model = pybamm.BaseModel() a = pybamm.Scalar(1) model.algebraic = {a: a} diff --git a/tests/unit/test_solvers/test_casadi_solver.py b/tests/unit/test_solvers/test_casadi_solver.py index 5ce29a365d..eaaeedf0d0 100644 --- a/tests/unit/test_solvers/test_casadi_solver.py +++ b/tests/unit/test_solvers/test_casadi_solver.py @@ -1078,6 +1078,30 @@ def test_solve_sensitivity_subset(self): ), ) + def test_solver_interpolation_warning(self): + # Create model + model = pybamm.BaseModel() + domain = ["negative electrode", "separator", "positive electrode"] + var = pybamm.Variable("var", domain=domain) + model.rhs = {var: 0.1 * var} + model.initial_conditions = {var: 1} + # create discretisation + mesh = get_mesh_for_testing() + spatial_methods = {"macroscale": pybamm.FiniteVolume()} + disc = pybamm.Discretisation(mesh, spatial_methods) + disc.process_model(model) + + solver = pybamm.CasadiSolver() + + # Check for warning with t_interp + t_eval = np.linspace(0, 1, 10) + t_interp = t_eval + with self.assertWarns( + pybamm.SolverWarning, + msg=f"Explicit interpolation times not implemented for {solver.name}", + ): + solver.solve(model, t_eval, t_interp=t_interp) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_solvers/test_idaklu_jax.py b/tests/unit/test_solvers/test_idaklu_jax.py index 7bae5d74e9..d985991929 100644 --- a/tests/unit/test_solvers/test_idaklu_jax.py +++ b/tests/unit/test_solvers/test_idaklu_jax.py @@ -40,6 +40,7 @@ t_eval, inputs=inputs, calculate_sensitivities=True, + t_interp=t_eval, ) # Get jax expressions for IDAKLU solver @@ -54,6 +55,7 @@ t_eval, output_variables=output_variables[:1], calculate_sensitivities=True, + t_interp=t_eval, ) f1 = idaklu_jax_solver1.get_jaxpr() # Multiple output variables @@ -62,6 +64,7 @@ t_eval, output_variables=output_variables, calculate_sensitivities=True, + t_interp=t_eval, ) f3 = idaklu_jax_solver3.get_jaxpr() @@ -151,11 +154,12 @@ def test_no_inputs(self): t_eval = np.linspace(0, 1, 100) idaklu_solver = pybamm.IDAKLUSolver(rtol=1e-6, atol=1e-6) # Regenerate surrogate data - sim = idaklu_solver.solve(model, t_eval) + sim = idaklu_solver.solve(model, t_eval, t_interp=t_eval) idaklu_jax_solver = idaklu_solver.jaxify( model, t_eval, output_variables=output_variables, + t_interp=t_eval, ) f = idaklu_jax_solver.get_jaxpr() # Check that evaluation can occur (and is correct) with no inputs @@ -423,7 +427,6 @@ def test_jacrev_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): @parameterized.expand(testcase, skip_on_empty=True) def test_jacrev_vector(self, output_variables, idaklu_jax_solver, f, wrapper): - out = wrapper(jax.jacrev(f, argnums=1))(t_eval[k], inputs) out = wrapper(jax.jacrev(f, argnums=1))(t_eval, inputs) flat_out, _ = tree_flatten(out) flat_out = np.concatenate(np.array([f for f in flat_out]), 1).T.flatten() @@ -847,6 +850,7 @@ def sse(t, inputs): t_eval, inputs=inputs_pred, calculate_sensitivities=True, + t_interp=t_eval, ) pred = sim_pred["v"] diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index 33e50eaa7d..5bc845d66c 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -91,11 +91,26 @@ def test_model_events(self): root_method=root_method, options={"jax_evaluator": "iree"} if form == "iree" else {}, ) - t_eval = np.linspace(0, 1, 100) - solution = solver.solve(model_disc, t_eval) - np.testing.assert_array_equal(solution.t, t_eval) + + if model.convert_to_format == "casadi" or ( + model.convert_to_format == "jax" + and solver._options["jax_evaluator"] == "iree" + ): + t_interp = np.linspace(0, 1, 100) + t_eval = np.array([t_interp[0], t_interp[-1]]) + else: + t_eval = np.linspace(0, 1, 100) + t_interp = t_eval + + solution = solver.solve(model_disc, t_eval, t_interp=t_interp) + np.testing.assert_array_equal( + solution.t, t_interp, err_msg=f"Failed for form {form}" + ) np.testing.assert_array_almost_equal( - solution.y[0], np.exp(0.1 * solution.t), decimal=5 + solution.y[0], + np.exp(0.1 * solution.t), + decimal=5, + err_msg=f"Failed for form {form}", ) # Check invalid atol type raises an error @@ -111,10 +126,13 @@ def test_model_events(self): root_method=root_method, options={"jax_evaluator": "iree"} if form == "iree" else {}, ) - solution = solver.solve(model_disc, t_eval) - np.testing.assert_array_equal(solution.t, t_eval) + solution = solver.solve(model_disc, t_eval, t_interp=t_interp) + np.testing.assert_array_equal(solution.t, t_interp) np.testing.assert_array_almost_equal( - solution.y[0], np.exp(0.1 * solution.t), decimal=5 + solution.y[0], + np.exp(0.1 * solution.t), + decimal=5, + err_msg=f"Failed for form {form}", ) # enforce events that will be triggered @@ -126,10 +144,13 @@ def test_model_events(self): root_method=root_method, options={"jax_evaluator": "iree"} if form == "iree" else {}, ) - solution = solver.solve(model_disc, t_eval) - self.assertLess(len(solution.t), len(t_eval)) + solution = solver.solve(model_disc, t_eval, t_interp=t_interp) + self.assertLess(len(solution.t), len(t_interp)) np.testing.assert_array_almost_equal( - solution.y[0], np.exp(0.1 * solution.t), decimal=5 + solution.y[0], + np.exp(0.1 * solution.t), + decimal=5, + err_msg=f"Failed for form {form}", ) # bigger dae model with multiple events @@ -153,17 +174,23 @@ def test_model_events(self): root_method=root_method, options={"jax_evaluator": "iree"} if form == "iree" else {}, ) - t_eval = np.linspace(0, 5, 100) + t_eval = np.array([0, 5]) solution = solver.solve(model, t_eval) np.testing.assert_array_less(solution.y[0, :-1], 1.5) np.testing.assert_array_less(solution.y[-1, :-1], 2.5) np.testing.assert_equal(solution.t_event[0], solution.t[-1]) np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1]) np.testing.assert_array_almost_equal( - solution.y[0], np.exp(0.1 * solution.t), decimal=5 + solution.y[0], + np.exp(0.1 * solution.t), + decimal=5, + err_msg=f"Failed for form {form}", ) np.testing.assert_array_almost_equal( - solution.y[-1], 2 * np.exp(0.1 * solution.t), decimal=5 + solution.y[-1], + 2 * np.exp(0.1 * solution.t), + decimal=5, + err_msg=f"Failed for form {form}", ) def test_input_params(self): @@ -208,15 +235,21 @@ def test_input_params(self): ) # test that y[3] remains constant - np.testing.assert_array_almost_equal(sol.y[3], np.ones(sol.t.shape)) + np.testing.assert_array_almost_equal( + sol.y[3], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" + ) # test that y[0] = to true solution true_solution = a_value * sol.t - np.testing.assert_array_almost_equal(sol.y[0], true_solution) + np.testing.assert_array_almost_equal( + sol.y[0], true_solution, err_msg=f"Failed for form {form}" + ) # test that y[1:3] = to true solution true_solution = b_value * sol.t - np.testing.assert_array_almost_equal(sol.y[1:3], true_solution) + np.testing.assert_array_almost_equal( + sol.y[1:3], true_solution, err_msg=f"Failed for form {form}" + ) def test_sensitivities_initial_condition(self): for form in ["casadi", "iree"]: @@ -249,17 +282,24 @@ def test_sensitivities_initial_condition(self): options={"jax_evaluator": "iree"} if form == "iree" else {}, ) - t_eval = np.linspace(0, 3, 100) + t_interp = np.linspace(0, 3, 100) + t_eval = np.array([t_interp[0], t_interp[-1]]) + a_value = 0.1 sol = solver.solve( - model, t_eval, inputs={"a": a_value}, calculate_sensitivities=True + model, + t_eval, + inputs={"a": a_value}, + calculate_sensitivities=True, + t_interp=t_interp, ) np.testing.assert_array_almost_equal( sol["2v"].sensitivities["a"].full().flatten(), np.exp(-sol.t) * 2, decimal=4, + err_msg=f"Failed for form {form}", ) def test_ida_roberts_klu_sensitivities(self): @@ -293,7 +333,8 @@ def test_ida_roberts_klu_sensitivities(self): options={"jax_evaluator": "iree"} if form == "iree" else {}, ) - t_eval = np.linspace(0, 3, 100) + t_interp = np.linspace(0, 3, 100) + t_eval = np.array([t_interp[0], t_interp[-1]]) a_value = 0.1 # solve first without sensitivities @@ -301,14 +342,19 @@ def test_ida_roberts_klu_sensitivities(self): model, t_eval, inputs={"a": a_value}, + t_interp=t_interp, ) # test that y[1] remains constant - np.testing.assert_array_almost_equal(sol.y[1, :], np.ones(sol.t.shape)) + np.testing.assert_array_almost_equal( + sol.y[1, :], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" + ) # test that y[0] = to true solution true_solution = a_value * sol.t - np.testing.assert_array_almost_equal(sol.y[0, :], true_solution) + np.testing.assert_array_almost_equal( + sol.y[0, :], true_solution, err_msg=f"Failed for form {form}" + ) # should be no sensitivities calculated with self.assertRaises(KeyError): @@ -316,35 +362,52 @@ def test_ida_roberts_klu_sensitivities(self): # now solve with sensitivities (this should cause set_up to be run again) sol = solver.solve( - model, t_eval, inputs={"a": a_value}, calculate_sensitivities=True + model, + t_eval, + inputs={"a": a_value}, + calculate_sensitivities=True, + t_interp=t_interp, ) # test that y[1] remains constant - np.testing.assert_array_almost_equal(sol.y[1, :], np.ones(sol.t.shape)) + np.testing.assert_array_almost_equal( + sol.y[1, :], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" + ) # test that y[0] = to true solution true_solution = a_value * sol.t - np.testing.assert_array_almost_equal(sol.y[0, :], true_solution) + np.testing.assert_array_almost_equal( + sol.y[0, :], true_solution, err_msg=f"Failed for form {form}" + ) # evaluate the sensitivities using idas dyda_ida = sol.sensitivities["a"] # evaluate the sensitivities using finite difference h = 1e-6 - sol_plus = solver.solve(model, t_eval, inputs={"a": a_value + 0.5 * h}) - sol_neg = solver.solve(model, t_eval, inputs={"a": a_value - 0.5 * h}) + sol_plus = solver.solve( + model, t_eval, inputs={"a": a_value + 0.5 * h}, t_interp=t_interp + ) + sol_neg = solver.solve( + model, t_eval, inputs={"a": a_value - 0.5 * h}, t_interp=t_interp + ) dyda_fd = (sol_plus.y - sol_neg.y) / h dyda_fd = dyda_fd.transpose().reshape(-1, 1) decimal = ( 2 if form == "iree" else 6 ) # iree currently operates with single precision - np.testing.assert_array_almost_equal(dyda_ida, dyda_fd, decimal=decimal) + np.testing.assert_array_almost_equal( + dyda_ida, dyda_fd, decimal=decimal, err_msg=f"Failed for form {form}" + ) # get the sensitivities for the variable d2uda = sol["2u"].sensitivities["a"] np.testing.assert_array_almost_equal( - 2 * dyda_ida[0:200:2], d2uda, decimal=decimal + 2 * dyda_ida[0:200:2], + d2uda, + decimal=decimal, + err_msg=f"Failed for form {form}", ) def test_ida_roberts_consistent_initialization(self): @@ -382,10 +445,14 @@ def test_ida_roberts_consistent_initialization(self): solver._set_consistent_initialization(model, t0, inputs_dict={}) # u(t0) = 0, v(t0) = 1 - np.testing.assert_array_almost_equal(model.y0full, [0, 1]) + np.testing.assert_array_almost_equal( + model.y0full, [0, 1], err_msg=f"Failed for form {form}" + ) # u'(t0) = 0.1 * v(t0) = 0.1 # Since v is algebraic, the initial derivative is set to 0 - np.testing.assert_array_almost_equal(model.ydot0full, [0.1, 0]) + np.testing.assert_array_almost_equal( + model.ydot0full, [0.1, 0], err_msg=f"Failed for form {form}" + ) def test_sensitivities_with_events(self): # this test implements a python version of the ida Roberts @@ -419,7 +486,9 @@ def test_sensitivities_with_events(self): options={"jax_evaluator": "iree"} if form == "iree" else {}, ) - t_eval = np.linspace(0, 3, 100) + t_interp = np.linspace(0, 3, 100) + t_eval = np.array([t_interp[0], t_interp[-1]]) + a_value = 0.1 b_value = 0.0 @@ -429,14 +498,19 @@ def test_sensitivities_with_events(self): t_eval, inputs={"a": a_value, "b": b_value}, calculate_sensitivities=True, + t_interp=t_interp, ) # test that y[1] remains constant - np.testing.assert_array_almost_equal(sol.y[1, :], np.ones(sol.t.shape)) + np.testing.assert_array_almost_equal( + sol.y[1, :], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" + ) # test that y[0] = to true solution true_solution = a_value * sol.t - np.testing.assert_array_almost_equal(sol.y[0, :], true_solution) + np.testing.assert_array_almost_equal( + sol.y[0, :], true_solution, err_msg=f"Failed for form {form}" + ) # evaluate the sensitivities using idas dyda_ida = sol.sensitivities["a"] @@ -445,10 +519,16 @@ def test_sensitivities_with_events(self): # evaluate the sensitivities using finite difference h = 1e-6 sol_plus = solver.solve( - model, t_eval, inputs={"a": a_value + 0.5 * h, "b": b_value} + model, + t_eval, + inputs={"a": a_value + 0.5 * h, "b": b_value}, + t_interp=t_interp, ) sol_neg = solver.solve( - model, t_eval, inputs={"a": a_value - 0.5 * h, "b": b_value} + model, + t_eval, + inputs={"a": a_value - 0.5 * h, "b": b_value}, + t_interp=t_interp, ) max_index = min(sol_plus.y.shape[1], sol_neg.y.shape[1]) - 1 dyda_fd = (sol_plus.y[:, :max_index] - sol_neg.y[:, :max_index]) / h @@ -458,21 +538,33 @@ def test_sensitivities_with_events(self): 2 if form == "iree" else 6 ) # iree currently operates with single precision np.testing.assert_array_almost_equal( - dyda_ida[: (2 * max_index), :], dyda_fd, decimal=decimal + dyda_ida[: (2 * max_index), :], + dyda_fd, + decimal=decimal, + err_msg=f"Failed for form {form}", ) sol_plus = solver.solve( - model, t_eval, inputs={"a": a_value, "b": b_value + 0.5 * h} + model, + t_eval, + inputs={"a": a_value, "b": b_value + 0.5 * h}, + t_interp=t_interp, ) sol_neg = solver.solve( - model, t_eval, inputs={"a": a_value, "b": b_value - 0.5 * h} + model, + t_eval, + inputs={"a": a_value, "b": b_value - 0.5 * h}, + t_interp=t_interp, ) max_index = min(sol_plus.y.shape[1], sol_neg.y.shape[1]) - 1 dydb_fd = (sol_plus.y[:, :max_index] - sol_neg.y[:, :max_index]) / h dydb_fd = dydb_fd.transpose().reshape(-1, 1) np.testing.assert_array_almost_equal( - dydb_ida[: (2 * max_index), :], dydb_fd, decimal=decimal + dydb_ida[: (2 * max_index), :], + dydb_fd, + decimal=decimal, + err_msg=f"Failed for form {form}", ) def test_failures(self): @@ -523,7 +615,7 @@ def test_failures(self): solver = pybamm.IDAKLUSolver() t_eval = np.linspace(0, 3, 100) - with self.assertRaisesRegex(pybamm.SolverError, "idaklu solver failed"): + with self.assertRaisesRegex(pybamm.SolverError, "FAILURE IDA"): solver.solve(model, t_eval) def test_dae_solver_algebraic_model(self): @@ -592,22 +684,23 @@ def test_setup_options(self): disc = pybamm.Discretisation() disc.process_model(model) - t_eval = np.linspace(0, 1) + t_interp = np.linspace(0, 1) + t_eval = np.array([t_interp[0], t_interp[-1]]) solver = pybamm.IDAKLUSolver() - soln_base = solver.solve(model, t_eval) + soln_base = solver.solve(model, t_eval, t_interp=t_interp) # test print_stats solver = pybamm.IDAKLUSolver(options={"print_stats": True}) f = io.StringIO() with redirect_stdout(f): - solver.solve(model, t_eval) + solver.solve(model, t_eval, t_interp=t_interp) s = f.getvalue() self.assertIn("Solver Stats", s) solver = pybamm.IDAKLUSolver(options={"print_stats": False}) f = io.StringIO() with redirect_stdout(f): - solver.solve(model, t_eval) + solver.solve(model, t_eval, t_interp=t_interp) s = f.getvalue() self.assertEqual(len(s), 0) @@ -634,7 +727,7 @@ def test_setup_options(self): rtol=1e-8, options=options, ) - if ( + works = ( jacobian == "none" and (linear_solver == "SUNLinSol_Dense") or jacobian == "dense" @@ -650,14 +743,11 @@ def test_setup_options(self): and linear_solver != "SUNLinSol_Dense" and linear_solver != "garbage" ) - ): - works = True - else: - works = False + ) if works: - soln = solver.solve(model, t_eval) - np.testing.assert_array_almost_equal(soln.y, soln_base.y, 5) + soln = solver.solve(model, t_eval, t_interp=t_interp) + np.testing.assert_array_almost_equal(soln.y, soln_base.y, 4) else: with self.assertRaises(ValueError): soln = solver.solve(model, t_eval) @@ -672,9 +762,10 @@ def test_solver_options(self): disc = pybamm.Discretisation() disc.process_model(model) - t_eval = np.linspace(0, 1) + t_interp = np.linspace(0, 1) + t_eval = np.array([t_interp[0], t_interp[-1]]) solver = pybamm.IDAKLUSolver() - soln_base = solver.solve(model, t_eval) + soln_base = solver.solve(model, t_eval, t_interp=t_interp) options_success = { "max_order_bdf": 4, @@ -704,9 +795,9 @@ def test_solver_options(self): for option in options_success: options = {option: options_success[option]} solver = pybamm.IDAKLUSolver(rtol=1e-6, atol=1e-6, options=options) - soln = solver.solve(model, t_eval) + soln = solver.solve(model, t_eval, t_interp=t_interp) - np.testing.assert_array_almost_equal(soln.y, soln_base.y, 5) + np.testing.assert_array_almost_equal(soln.y, soln_base.y, 4) options_fail = { "max_order_bdf": -1, @@ -817,7 +908,9 @@ def construct_model(): # Compare output to sol_all for varname in [*output_variables, *model_vars]: - self.assertTrue(np.allclose(sol[varname].data, sol_all[varname].data)) + np.testing.assert_array_almost_equal( + sol[varname](t_eval), sol_all[varname](t_eval), 3 + ) # Check that the missing variables are not available in the solution for varname in inaccessible_vars: @@ -860,12 +953,13 @@ def test_with_output_variables_and_sensitivities(self): disc = pybamm.Discretisation(mesh, model.default_spatial_methods) disc.process_model(model) - t_eval = np.linspace(0, 100, 100) + t_eval = np.linspace(0, 100, 5) options = { "linear_solver": "SUNLinSol_KLU", "jacobian": "sparse", "num_threads": 4, + "max_num_steps": 1000, } if form == "iree": options["jax_evaluator"] = "iree" @@ -912,7 +1006,10 @@ def test_with_output_variables_and_sensitivities(self): tol = 1e-5 if form != "iree" else 1e-2 # iree has reduced precision for varname in output_variables: np.testing.assert_array_almost_equal( - sol[varname].data, sol_all[varname].data, tol + sol[varname](t_eval), + sol_all[varname](t_eval), + tol, + err_msg=f"Failed for {varname} with form {form}", ) # Mock a 1D current collector and initialise (none in the model) @@ -943,7 +1040,7 @@ def test_with_output_variables_and_event_termination(self): parameter_values=parameter_values, solver=pybamm.IDAKLUSolver(output_variables=["Terminal voltage [V]"]), ) - sol = sim.solve(np.linspace(0, 3600, 1000)) + sol = sim.solve(np.linspace(0, 3600, 2)) self.assertEqual(sol.termination, "event: Minimum voltage [V]") # create an event that doesn't require the state vector @@ -961,9 +1058,45 @@ def test_with_output_variables_and_event_termination(self): parameter_values=parameter_values, solver=pybamm.IDAKLUSolver(output_variables=["Terminal voltage [V]"]), ) - sol3 = sim3.solve(np.linspace(0, 3600, 1000)) + sol3 = sim3.solve(np.linspace(0, 3600, 2)) self.assertEqual(sol3.termination, "event: Minimum voltage [V]") + def test_simulation_period(self): + model = pybamm.lithium_ion.DFN() + parameter_values = pybamm.ParameterValues("Chen2020") + solver = pybamm.IDAKLUSolver() + + experiment = pybamm.Experiment( + ["Charge at C/10 for 10 seconds"], period="0.1 seconds" + ) + + sim = pybamm.Simulation( + model, + parameter_values=parameter_values, + experiment=experiment, + solver=solver, + ) + sol = sim.solve() + + np.testing.assert_array_almost_equal(sol.t, np.arange(0, 10.1, 0.1), decimal=4) + + def test_interpolate_time_step_start_offset(self): + model = pybamm.lithium_ion.SPM() + experiment = pybamm.Experiment( + [ + "Discharge at C/10 for 10 seconds", + "Charge at C/10 for 10 seconds", + ], + period="1 seconds", + ) + solver = pybamm.IDAKLUSolver() + sim = pybamm.Simulation(model, experiment=experiment, solver=solver) + sol = sim.solve() + np.testing.assert_equal( + sol.sub_solutions[0].t[-1] + pybamm.settings.step_start_offset, + sol.sub_solutions[1].t[0], + ) + if __name__ == "__main__": print("Add -v for more debug output")