diff --git a/CHANGELOG.md b/CHANGELOG.md index d40f40642d..6e455c9611 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ ## Bug fixes +- Fixed bug raised if function returns a scalar ([#919](https://github.com/pybamm-team/PyBaMM/pull/919)) - Updated Getting started notebook 2 ([#903](https://github.com/pybamm-team/PyBaMM/pull/903)) - Reformatted external circuit submodels ([#879](https://github.com/pybamm-team/PyBaMM/pull/879)) - Some bug fixes to generalize specifying models that aren't battery models, see [#846](https://github.com/pybamm-team/PyBaMM/issues/846) diff --git a/pybamm/parameters/parameter_values.py b/pybamm/parameters/parameter_values.py index 81a6c6861b..7b2b1756b0 100644 --- a/pybamm/parameters/parameter_values.py +++ b/pybamm/parameters/parameter_values.py @@ -524,6 +524,9 @@ def _process_symbol(self, symbol): # return differentiated function new_diff_variable = self.process_symbol(symbol.diff_variable) function_out = function.diff(new_diff_variable) + # Convert possible float output to a pybamm scalar + if isinstance(function_out, numbers.Number): + return pybamm.Scalar(function_out) # Process again just to be sure return self.process_symbol(function_out) diff --git a/tests/unit/test_parameters/test_parameter_values.py b/tests/unit/test_parameters/test_parameter_values.py index 30a75ddd47..c9d81797d6 100644 --- a/tests/unit/test_parameters/test_parameter_values.py +++ b/tests/unit/test_parameters/test_parameter_values.py @@ -300,6 +300,7 @@ def test_process_function_parameter(self): "a": 3, "func": pybamm.load_function("process_symbol_test_function.py"), "const": 254, + "float_func": lambda x: 42, } ) a = pybamm.InputParameter("a") @@ -320,6 +321,11 @@ def test_process_function_parameter(self): processed_diff_func = parameter_values.process_symbol(diff_func) self.assertEqual(processed_diff_func.evaluate(u={"a": 3}), 123) + # function parameter that returns a python float + func = pybamm.FunctionParameter("float_func", a) + processed_func = parameter_values.process_symbol(func) + self.assertEqual(processed_func.evaluate(), 42) + # function itself as input (different to the variable being an input) parameter_values = pybamm.ParameterValues({"func": "[input]"}) a = pybamm.Scalar(3)