Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1108 change t_eval list to linspace #1113

Merged
merged 4 commits into from
Jul 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@

## Features

- Added support for sensitivity calculations to the casadi solver ([#1109](https://github.com/pybamm-team/PyBaMM/pull/1109))
- Added support for index 1 semi-explicit dae equations and sensitivity calculations to JAX BDF solver ([#1107](https://github.com/pybamm-team/PyBaMM/pull/1107))
- Allowed keyword arguments to be passed to `Simulation.plot()` ([#1099](https://github.com/pybamm-team/PyBaMM/pull/1099))

## Optimizations

## Bug fixes

- `t_eval` now gets changed to a `linspace` if a list of length 2 is passed ([#1113](https://github.com/pybamm-team/PyBaMM/pull/1113))
- Fixed bug when setting a function with an `InputParameter` ([#1111](https://github.com/pybamm-team/PyBaMM/pull/1111))

## Breaking changes

- Renamed `quick_plot_vars` to `output_variables` in `Simulation` to be consistent with `QuickPlot`. Passing `quick_plot_vars` to `Simulation.plot()` has been deprecated and `output_variables` should be passed instead ([#1099](https://github.com/pybamm-team/PyBaMM/pull/1099))
Expand Down
13 changes: 1 addition & 12 deletions pybamm/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,17 +330,6 @@ def solve(
solver = self.solver

if self.operating_mode in ["without experiment", "drive cycle"]:
# If t_eval is provided as [t0, tf] return the solution at 100 points
if isinstance(t_eval, list):
if len(t_eval) != 2:
raise pybamm.SolverError(
"'t_eval' can be provided as an array of times at which to "
"return the solution, or as a list [t0, tf] where t0 is the "
"initial time and tf is the final time, but has been provided "
"as a list of length {}.".format(len(t_eval))
)
else:
t_eval = np.linspace(t_eval[0], t_eval[-1], 100)

if self.operating_mode == "without experiment":
if t_eval is None:
Expand Down Expand Up @@ -403,13 +392,13 @@ def solve(
pybamm.SolverWarning,
)

self.t_eval = t_eval
self._solution = solver.solve(
self.built_model,
t_eval,
external_variables=external_variables,
inputs=inputs,
)
self.t_eval = self._solution.t * self.model.timescale.evaluate()

elif self.operating_mode == "with experiment":
if t_eval is not None:
Expand Down
13 changes: 13 additions & 0 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,19 @@ def solve(self, model, t_eval=None, external_variables=None, inputs=None):
t_eval = np.array([0])
else:
raise ValueError("t_eval cannot be None")
# If t_eval is provided as [t0, tf] return the solution at 100 points
elif isinstance(t_eval, list):
if len(t_eval) == 1 and self.algebraic_solver is True:
pass
elif len(t_eval) != 2:
raise pybamm.SolverError(
"'t_eval' can be provided as an array of times at which to "
"return the solution, or as a list [t0, tf] where t0 is the "
"initial time and tf is the final time, but has been provided "
"as a list of length {}.".format(len(t_eval))
)
else:
t_eval = np.linspace(t_eval[0], t_eval[-1], 100)

# Make sure t_eval is monotonic
if (np.diff(t_eval) < 0).any():
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def test_t_eval(self):

# tets list gets turned into np.linspace(t0, tf, 100)
sim.solve(t_eval=[0, 10])
np.testing.assert_array_equal(sim.t_eval, np.linspace(0, 10, 100))
np.testing.assert_array_almost_equal(sim.t_eval, np.linspace(0, 10, 100))


if __name__ == "__main__":
Expand Down