Skip to content

Commit

Permalink
#920 raise error if timescale depends on inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
rtimms committed May 29, 2020
1 parent b494776 commit 7572a09
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 12 deletions.
8 changes: 7 additions & 1 deletion pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,13 @@ def set_up(self, model, inputs=None):
inputs = inputs or {}

# Set model timescale
model.timescale_eval = model.timescale.evaluate(inputs=inputs)
try:
model.timescale_eval = model.timescale.evaluate()
except KeyError as e:
raise pybamm.SolverError(
"The model timescale is a function of an input parameter "
"(original error: {})".format(e)
)

if (
isinstance(self, (pybamm.CasadiSolver, pybamm.CasadiAlgebraicSolver))
Expand Down
10 changes: 3 additions & 7 deletions pybamm/solvers/processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,8 @@ def __init__(self, base_variable, solution, known_evals=None, warn=True):
self.known_evals = known_evals
self.warn = warn

# Set timescale -- used evaluated timescale if available (to account
# for inputs set during solve)
try:
self.timescale = solution.model.timescale_eval
except AttributeError:
self.timescale = solution.model.timescale.evaluate()
# Set timescale
self.timescale = solution.model.timescale.evaluate()
self.t_pts = self.t_sol * self.timescale

# Store spatial variables to get scales
Expand Down Expand Up @@ -498,7 +494,7 @@ def get_spatial_scale(self, name, domain=None):
if self.warn:
pybamm.logger.warning(
"No scale set for spatial variable {}. "
"Using default of 1 [m]".format(name)
"Using default of 1 [m].".format(name)
)
scale = 1
return scale
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_solvers/test_external_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ def test_on_dfn(self):
model = pybamm.lithium_ion.DFN()
geometry = model.default_geometry
param = model.default_parameter_values
param.update({"Electrode height [m]": "[input]"})
param.update({"Negative electrode conductivity [S.m-1]": "[input]"})
param.process_model(model)
param.process_geometry(geometry)
inputs = {"Electrode height [m]": e_height}
inputs = {"Negative electrode conductivity [S.m-1]": e_height}
var = pybamm.standard_spatial_vars
var_pts = {var.x_n: 5, var.x_s: 5, var.x_p: 5, var.r_n: 10, var.r_p: 10}
spatial_methods = model.default_spatial_methods
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,18 @@ def test_convert_to_casadi_format(self):
self.assertEqual(model.convert_to_format, "casadi")
pybamm.set_logging_level("WARNING")

def test_timescale_input_fail(self):
# Make sure timescale can't depend on inputs
model = pybamm.BaseModel()
v = pybamm.Variable("v")
model.rhs = {v: -1}
model.initial_conditions = {v: 1}
a = pybamm.InputParameter("a")
model.timescale = a
solver = pybamm.BaseSolver()
with self.assertRaisesRegex(pybamm.SolverError, "The model timescale"):
solver.set_up(model, inputs={"a": 10})


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_solvers/test_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_solution_evals_with_inputs(self):
model = pybamm.lithium_ion.SPM()
geometry = model.default_geometry
param = model.default_parameter_values
param.update({"Electrode height [m]": "[input]"})
param.update({"Negative electrode conductivity [S.m-1]": "[input]"})
param.process_model(model)
param.process_geometry(geometry)
var = pybamm.standard_spatial_vars
Expand All @@ -163,7 +163,7 @@ def test_solution_evals_with_inputs(self):
spatial_methods=spatial_methods,
solver=solver,
)
inputs = {"Electrode height [m]": 0.1}
inputs = {"Negative electrode conductivity [S.m-1]": 0.1}
sim.solve(t_eval=np.linspace(0, 10, 10), inputs=inputs)
time = sim.solution["Time [h]"](sim.solution.t)
self.assertEqual(len(time), 10)
Expand Down

0 comments on commit 7572a09

Please sign in to comment.