Skip to content

Commit

Permalink
[numerics] Use Const1 objects in Python
Browse files Browse the repository at this point in the history
- replace Python lambda function defining constant Func1 variants by
  pre-existing Const1 class defined in C++
  • Loading branch information
ischoegl committed Jan 23, 2020
1 parent e765b71 commit 1c0e123
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
9 changes: 7 additions & 2 deletions interfaces/cython/cantera/_cantera.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ cdef extern from "cantera/numerics/Func1.h":
CxxTabulated1(vector[double]&, vector[double]&) except +translate_exception
double eval(double) except +translate_exception

cdef cppclass CxxConst1 "Cantera::Const1":
CxxConst1(double) except +translate_exception
double eval(double) except +translate_exception

cdef extern from "cantera/base/xml.h" namespace "Cantera":
cdef cppclass XML_Node:
XML_Node* findByName(string)
Expand Down Expand Up @@ -1015,8 +1019,9 @@ cdef class Func1:
cdef CxxFunc1* func
cdef object callable
cdef object exception
cpdef void __set_callback(self, object) except *
cpdef void __set_tables(self, object, object) except *
cpdef void _set_callback(self, object) except *
cpdef void _set_const(self, double) except *
cpdef void _set_tables(self, object, object) except *

cdef class ReactorBase:
cdef CxxReactorBase* rbase
Expand Down
17 changes: 10 additions & 7 deletions interfaces/cython/cantera/func1.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -87,24 +87,24 @@ cdef class Func1:
c = args[0]
if hasattr(c, '__call__'):
# callback function
self.__set_callback(c)
self._set_callback(c)
else:
arr = np.array(c)
try:
if arr.ndim == 0:
# handle constants or unsized numpy arrays
k = float(c)
self.__set_callback(lambda t: k)
self._set_const(k)
elif arr.size == 1:
# handle lists, tuples or numpy arrays with a single element
k = float(c[0])
self.__set_callback(lambda t: k)
self._set_const(k)
elif arr.ndim == 2:
# tabulated function (single argument)
if arr.shape[1] == 2:
time = arr[:, 0]
fval = arr[:, 1]
self.__set_tables(time, fval)
self._set_tables(time, fval)
else:
raise ValueError(
"Invalid dimensions: specification of "
Expand All @@ -121,17 +121,20 @@ cdef class Func1:
elif len(args) == 2:
# tabulated function (two arguments mimic C++ interface)
time, fval = args
self.__set_tables(time, fval)
self._set_tables(time, fval)

else:
raise ValueError("Invalid number of arguments")


cpdef void __set_callback(self, c) except *:
cpdef void _set_callback(self, c) except *:
self.callable = c
self.func = new CxxFunc1(func_callback, <void*>self)

cpdef void __set_tables(self, time, fval) except *:
cpdef void _set_const(self, double c) except *:
self.func = <CxxFunc1*>(new CxxConst1(c))

cpdef void _set_tables(self, time, fval) except *:
cdef vector[double] tvec, fvec
for t in time:
tvec.push_back(t)
Expand Down

0 comments on commit 1c0e123

Please sign in to comment.