-
-
Notifications
You must be signed in to change notification settings - Fork 529
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#600 add more functionality to Interpolant class
- Loading branch information
1 parent
f5f7e42
commit 2916272
Showing
3 changed files
with
119 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# | ||
# Tests for the Function classes | ||
# | ||
import pybamm | ||
|
||
import unittest | ||
import numpy as np | ||
import autograd.numpy as auto_np | ||
from scipy.interpolate import interp1d | ||
|
||
|
||
class TestInterpolant(unittest.TestCase): | ||
def test_errors(self): | ||
with self.assertRaisesRegex(ValueError, "data should have exactly two columns"): | ||
pybamm.Interpolant(np.ones(10), None) | ||
with self.assertRaisesRegex(ValueError, "interpolator 'bla' not recognised"): | ||
pybamm.Interpolant(np.ones((10, 2)), None, interpolator="bla") | ||
|
||
def test_interpolation(self): | ||
x = np.linspace(0, 1)[:, np.newaxis] | ||
y = pybamm.StateVector(slice(0, 2)) | ||
# linear | ||
linear = np.hstack([x, 2 * x]) | ||
for interpolator in ["linear", "pchip", "cubic spline"]: | ||
interp = pybamm.Interpolant(linear, y, interpolator=interpolator) | ||
np.testing.assert_array_almost_equal( | ||
interp.evaluate(y=np.array([0.397, 1.5]))[:, 0], np.array([0.794, 3]) | ||
) | ||
# square | ||
square = np.hstack([x, x ** 2]) | ||
y = pybamm.StateVector(slice(0, 1)) | ||
for interpolator in ["linear", "pchip", "cubic spline"]: | ||
interp = pybamm.Interpolant(square, y) | ||
np.testing.assert_array_almost_equal( | ||
interp.evaluate(y=np.array([0.397]))[:, 0], np.array([0.397 ** 2]) | ||
) | ||
|
||
# with extrapolation set to False | ||
for interpolator in ["linear", "pchip", "cubic spline"]: | ||
interp = pybamm.Interpolant(square, y, extrapolate=False) | ||
np.testing.assert_array_equal( | ||
interp.evaluate(y=np.array([2]))[:, 0], np.array([np.nan]) | ||
) | ||
|
||
def test_name(self): | ||
a = pybamm.Symbol("a") | ||
x = np.linspace(0, 1)[:, np.newaxis] | ||
interp = pybamm.Interpolant(np.hstack([x, x]), a, "name") | ||
self.assertEqual(interp.name, "interpolating function (name)") | ||
|
||
|
||
if __name__ == "__main__": | ||
print("Add -v for more debug output") | ||
import sys | ||
|
||
if "-v" in sys.argv: | ||
debug = True | ||
pybamm.settings.debug_mode = True | ||
unittest.main() |