Skip to content

Commit

Permalink
#600 add more functionality to Interpolant class
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Oct 13, 2019
1 parent f5f7e42 commit 2916272
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 13 deletions.
10 changes: 5 additions & 5 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def __init__(self, child):
super().__init__(np.cosh, child)

def _diff(self, children):
""" See :meth:`pybamm.Symbol._diff()`. """
""" See :meth:`pybamm.Function._diff()`. """
return Sinh(children[0])


Expand All @@ -249,7 +249,7 @@ def __init__(self, child):
super().__init__(np.exp, child)

def _diff(self, children):
""" See :meth:`pybamm.Symbol._diff()`. """
""" See :meth:`pybamm.Function._diff()`. """
return Exponential(children[0])


Expand All @@ -265,7 +265,7 @@ def __init__(self, child):
super().__init__(np.log, child)

def _diff(self, children):
""" See :meth:`pybamm.Symbol._diff()`. """
""" See :meth:`pybamm.Function._diff()`. """
return 1 / children[0]


Expand All @@ -291,7 +291,7 @@ def __init__(self, child):
super().__init__(np.sin, child)

def _diff(self, children):
""" See :meth:`pybamm.Symbol._diff()`. """
""" See :meth:`pybamm.Function._diff()`. """
return Cos(children[0])


Expand All @@ -307,7 +307,7 @@ def __init__(self, child):
super().__init__(np.sinh, child)

def _diff(self, children):
""" See :meth:`pybamm.Symbol._diff()`. """
""" See :meth:`pybamm.Function._diff()`. """
return Cosh(children[0])


Expand Down
63 changes: 55 additions & 8 deletions pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,65 @@
# Interpolating class
#
import pybamm
import numpy as np
from scipy import interpolate


class Interpolant(pybamm.Function):
def __init__(self, data, child):
interpolating_function = interpolate.CubicSpline(
data[:, 0], data[:, 1], extrapolate=True
)
"""
Interpolate data in 1D.
Parameters
----------
data : :class:`numpy.ndarray`
child : :class:`pybamm.Symbol`
name : str, optional
interpolator : str, optional
Which interpolator to use ("linear", "pchip" or "cubic spline"). Default is
"pchip".
extrapolate : bool, optional
Whether to extrapolate for points that are outside of the parametrisation
range, or return NaN (following default behaviour from scipy). Default is True.
"""

def __init__(self, data, child, name=None, interpolator="pchip", extrapolate=True):
if data.ndim != 2 or data.shape[1] != 2:
raise ValueError(
"data should have exactly two columns (x and y) but has shape {}".format(
data.shape
)
)
if interpolator == "linear":
if extrapolate is True:
fill_value = "extrapolate"
else:
fill_value = np.nan
interpolating_function = interpolate.interp1d(
data[:, 0], data[:, 1], fill_value=fill_value
)
elif interpolator == "pchip":
interpolating_function = interpolate.PchipInterpolator(
data[:, 0], data[:, 1], extrapolate=extrapolate
)
elif interpolator == "cubic spline":
interpolating_function = interpolate.CubicSpline(
data[:, 0], data[:, 1], extrapolate=extrapolate
)
else:
raise ValueError("interpolator '{}' not recognised".format(interpolator))
super().__init__(interpolating_function, child)
# Overwrite name if given
if name is not None:
self.name = "interpolating function ({})".format(name)
# Store information as attributes
self.interpolator = interpolator
self.extrapolate = extrapolate

def _diff(self, variable):
""" See :meth:`pybamm.Function._diff()`. """
return pybamm.Function(
self._interpolating_function.derivative(), *self.children
)
"""
Overwrite the base Function `_diff` to use `.derivative` directly instead of
autograd.
See :meth:`pybamm.Function._diff()`.
"""
interpolating_function = self.function
return pybamm.Function(interpolating_function.derivative(), *self.children)
59 changes: 59 additions & 0 deletions tests/unit/test_expression_tree/test_interpolant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#
# Tests for the Function classes
#
import pybamm

import unittest
import numpy as np
import autograd.numpy as auto_np
from scipy.interpolate import interp1d


class TestInterpolant(unittest.TestCase):
def test_errors(self):
with self.assertRaisesRegex(ValueError, "data should have exactly two columns"):
pybamm.Interpolant(np.ones(10), None)
with self.assertRaisesRegex(ValueError, "interpolator 'bla' not recognised"):
pybamm.Interpolant(np.ones((10, 2)), None, interpolator="bla")

def test_interpolation(self):
x = np.linspace(0, 1)[:, np.newaxis]
y = pybamm.StateVector(slice(0, 2))
# linear
linear = np.hstack([x, 2 * x])
for interpolator in ["linear", "pchip", "cubic spline"]:
interp = pybamm.Interpolant(linear, y, interpolator=interpolator)
np.testing.assert_array_almost_equal(
interp.evaluate(y=np.array([0.397, 1.5]))[:, 0], np.array([0.794, 3])
)
# square
square = np.hstack([x, x ** 2])
y = pybamm.StateVector(slice(0, 1))
for interpolator in ["linear", "pchip", "cubic spline"]:
interp = pybamm.Interpolant(square, y)
np.testing.assert_array_almost_equal(
interp.evaluate(y=np.array([0.397]))[:, 0], np.array([0.397 ** 2])
)

# with extrapolation set to False
for interpolator in ["linear", "pchip", "cubic spline"]:
interp = pybamm.Interpolant(square, y, extrapolate=False)
np.testing.assert_array_equal(
interp.evaluate(y=np.array([2]))[:, 0], np.array([np.nan])
)

def test_name(self):
a = pybamm.Symbol("a")
x = np.linspace(0, 1)[:, np.newaxis]
interp = pybamm.Interpolant(np.hstack([x, x]), a, "name")
self.assertEqual(interp.name, "interpolating function (name)")


if __name__ == "__main__":
print("Add -v for more debug output")
import sys

if "-v" in sys.argv:
debug = True
pybamm.settings.debug_mode = True
unittest.main()

0 comments on commit 2916272

Please sign in to comment.