Skip to content

Commit

Permalink
Merge pull request #962 from pybamm-team/issue-961-interpolants
Browse files Browse the repository at this point in the history
#961 fix interpolant ids
  • Loading branch information
valentinsulzer authored Apr 17, 2020
2 parents 9b72b1e + c93b6c8 commit 2382672
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

## Bug fixes

- Fixed `Interpolant` ids to allow processing ([#962](https://github.com/pybamm-team/PyBaMM/pull/962)
- Changed simulation attributes to assign copies rather than the objects themselves ([#952](https://github.com/pybamm-team/PyBaMM/pull/952)
- Added default values to base model so that it works with the `Simulation` class ([#952](https://github.com/pybamm-team/PyBaMM/pull/952)
- Fixed solver to recompute initial conditions when inputs are changed ([#951](https://github.com/pybamm-team/PyBaMM/pull/951)
Expand Down
35 changes: 32 additions & 3 deletions pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ class Interpolant(pybamm.Function):
"""

def __init__(
self, data, child, name=None, interpolator="cubic spline", extrapolate=True
self,
data,
child,
name=None,
interpolator="cubic spline",
extrapolate=True,
entries_string=None,
):
if data.ndim != 2 or data.shape[1] != 2:
raise ValueError(
Expand All @@ -56,22 +62,45 @@ def __init__(
name = "interpolating function ({})".format(name)
else:
name = "interpolating function"
self.data = data
self.entries_string = entries_string
super().__init__(
interpolating_function, child, name=name, derivative="derivative"
)
# Store information as attributes
self.data = data
self.x = data[:, 0]
self.y = data[:, 1]
self.interpolator = interpolator
self.extrapolate = extrapolate

@property
def entries_string(self):
return self._entries_string

@entries_string.setter
def entries_string(self, value):
# We must include the entries in the hash, since different arrays can be
# indistinguishable by class, name and domain alone
# Slightly different syntax for sparse and non-sparse matrices
if value is not None:
self._entries_string = value
else:
entries = self.data
self._entries_string = entries.tostring()

def set_id(self):
""" See :meth:`pybamm.Symbol.set_id()`. """
self._id = hash(
(self.__class__, self.name, self.entries_string) + tuple(self.domain)
)

def _function_new_copy(self, children):
""" See :meth:`Function._function_new_copy()` """
return pybamm.Interpolant(
self.data,
*children,
name=self.name,
interpolator=self.interpolator,
extrapolate=self.extrapolate
extrapolate=self.extrapolate,
entries_string=self.entries_string
)
14 changes: 12 additions & 2 deletions tests/unit/test_parameters/test_parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,11 @@ def test_process_interpolant(self):
x = np.linspace(0, 10)[:, np.newaxis]
data = np.hstack([x, 2 * x])
parameter_values = pybamm.ParameterValues(
{"a": 3.01, "Diffusivity": ("times two", data)}
{"a": 3.01, "Times two": ("times two", data)}
)

a = pybamm.Parameter("a")
func = pybamm.FunctionParameter("Diffusivity", {"a": a})
func = pybamm.FunctionParameter("Times two", {"a": a})

processed_func = parameter_values.process_symbol(func)
self.assertIsInstance(processed_func, pybamm.Interpolant)
Expand All @@ -346,6 +346,16 @@ def test_process_interpolant(self):
processed_diff_func = parameter_values.process_symbol(diff_func)
self.assertEqual(processed_diff_func.evaluate(), 2)

# interpolant defined up front
interp2 = pybamm.Interpolant(data, a)
processed_interp2 = parameter_values.process_symbol(interp2)
self.assertEqual(processed_interp2.evaluate(), 6.02)

data3 = np.hstack([x, 3 * x])
interp3 = pybamm.Interpolant(data3, a)
processed_interp3 = parameter_values.process_symbol(interp3)
self.assertEqual(processed_interp3.evaluate(), 9.03)

def test_interpolant_against_function(self):
parameter_values = pybamm.ParameterValues({})
parameter_values.update(
Expand Down

0 comments on commit 2382672

Please sign in to comment.