Skip to content

Commit

Permalink
[numerics] Simplify TabulatedFunction constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
ischoegl authored and speth committed Apr 5, 2020
1 parent 1faa762 commit 8a0ac7e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 63 deletions.
67 changes: 16 additions & 51 deletions interfaces/cython/cantera/func1.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -106,68 +106,33 @@ cdef class Func1:
cdef class TabulatedFunction(Func1):
"""
A `TabulatedFunction` object representing a tabulated function is defined by
sample points and corresponding function values. Inputs are specified either
by two iterable objects containing sample point location and function
values, or a single array that concatenates those inputs in two rows or
columns. Between sample points, values are evaluated based on the optional
argument ``method``, which has to be supplied as a keyword; options are
``'linear'`` (linear interpolation, default) or ``'previous'`` (nearest
previous value). Outside the sample interval, the value at the closest end
point is returned.
Examples for `TabulatedFunction` objects using a single (two-dimensional)
array as input are::
>>> t1 = TabulatedFunction([[0, 2], [1, 1], [2, 0]])
sample points and corresponding function values. Inputs are specified by
two iterable objects containing sample point location and function values.
Between sample points, values are evaluated based on the optional argument
``method``; options are ``'linear'`` (linear interpolation, default) or
``'previous'`` (nearest previous value). Outside the sample interval, the
value at the closest end point is returned.
Examples for `TabulatedFunction` objects are::
>>> t1 = TabulatedFunction([0, 1, 2], [2, 1, 0])
>>> [t1(v) for v in [-0.5, 0, 0.5, 1.5, 2, 2.5]]
[2.0, 2.0, 1.5, 0.5, 0.0, 0.0]
>>> t2 = TabulatedFunction(np.array([0, 1, 2]), (2, 1, 0))
>>> t2 = TabulatedFunction(np.array([0, 1, 2]), np.array([2, 1, 0]))
>>> [t2(v) for v in [-0.5, 0, 0.5, 1.5, 2, 2.5]]
[2.0, 2.0, 1.5, 0.5, 0.0, 0.0]
where the optional ``method`` keyword argument changes the type of
interpolation from the ``'linear'`` default to ``'previous'``::
The optional ``method`` keyword argument changes the type of interpolation
from the ``'linear'`` default to ``'previous'``::
>>> t3 = TabulatedFunction([[0, 2], [1, 1], [2, 0]], method='previous')
>>> t3 = TabulatedFunction([0, 1, 2], [2, 1, 0], method='previous')
>>> [t3(v) for v in [-0.5, 0, 0.5, 1.5, 2, 2.5]]
[2.0, 2.0, 2.0, 1.0, 0.0, 0.0]
Alternatively, a `TabulatedFunction` can be defined using two input arrays::
>>> t4 = TabulatedFunction([0, 1, 2], [2, 1, 0])
>>> [t4(v) for v in [-0.5, 0, 0.5, 1.5, 2, 2.5]]
[2.0, 2.0, 1.5, 0.5, 0.0, 0.0]
"""

def __init__(self, *args, method='linear'):
if len(args) == 1:
# tabulated function (single argument)
arr = np.array(args[0])
if arr.ndim == 2:
if arr.shape[1] == 2:
time = arr[:, 0]
fval = arr[:, 1]
elif arr.shape[0] == 2:
time = arr[0, :]
fval = arr[1, :]
else:
raise ValueError("Invalid dimensions: specification of "
"tabulated function with a single array "
"requires two rows or columns")
self._set_tables(time, fval, stringify(method))
else:
raise TypeError("'TabulatedFunction' must be constructed from "
"a numeric array with two dimensions")

elif len(args) == 2:
# tabulated function (two arrays mimic C++ interface)
time, fval = args
self._set_tables(time, fval, stringify(method))

else:
raise ValueError("Invalid number of arguments (one or two "
"arguments containing tabulated values)")
def __init__(self, time, fval, method='linear'):
self._set_tables(time, fval, stringify(method))

cpdef void _set_tables(self, time, fval, string method) except *:
tt = np.asarray(time, dtype=np.double)
Expand Down
22 changes: 10 additions & 12 deletions interfaces/cython/cantera/test/test_func1.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def test_tabulated2(self):
self.assertNear(f, fcn(t))

def test_tabulated3(self):
time = 0, 1, 2,
fval = 2, 1, 0,
fcn = ct.TabulatedFunction(time, fval)
self.assertNear(fcn(-1), fval[0])
self.assertNear(fcn(3), fval[-1])

def test_tabulated4(self):
time = np.array([0, 1, 2])
fval = np.array([2, 1, 0])
fcn = ct.TabulatedFunction(time, fval)
Expand All @@ -95,23 +102,14 @@ def test_tabulated3(self):
for t, f in zip(tt, ff):
self.assertNear(f, fcn(t))

def test_tabulated4(self):
time = 0, 1, 2,
fval = 2, 1, 0,
fcn = ct.TabulatedFunction(time, fval)
self.assertNear(fcn(-1), fval[0])
self.assertNear(fcn(3), fval[-1])

def test_tabulated5(self):
fcn = ct.TabulatedFunction([[0, 2], [1, 1], [2, 0]], method='previous')
time = [0, 1, 2]
fval = [2, 1, 0]
fcn = ct.TabulatedFunction(time, fval, method='previous')
val = np.array([fcn(v) for v in [-0.5, 0, 0.5, 1.5, 2, 2.5]])
self.assertArrayNear(val, np.array([2.0, 2.0, 2.0, 1.0, 0.0, 0.0]))

def test_tabulated_failures(self):
with self.assertRaisesRegex(ValueError, 'Invalid number of arguments'):
ct.TabulatedFunction(1, 2, 3)
with self.assertRaisesRegex(ValueError, 'Invalid dimensions'):
ct.TabulatedFunction(np.zeros((3, 3)))
with self.assertRaisesRegex(ValueError, 'do not match'):
ct.TabulatedFunction(range(2), range(3))
with self.assertRaisesRegex(ValueError, 'must not be empty'):
Expand Down

0 comments on commit 8a0ac7e

Please sign in to comment.