-
Notifications
You must be signed in to change notification settings - Fork 0
/
cvode.py
97 lines (76 loc) · 3.43 KB
/
cvode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from ctypes import *
import numpy as np
from sundials import *
# Handle to the CVODE shared library (SUNDIALS 6 ABI).  `sundialsCDLL` and
# `memory_create_check` are provided by the project's `sundials` module.
sundials = sundialsCDLL("libsundials_cvode.so.6")

# ctypes prototypes for the CVODE entry points used below.  Opaque handles
# (solver memory, SUNContext, N_Vector, SUNLinearSolver, SUNMatrix) and
# function pointers are all passed as c_void_p.

# void *CVodeCreate(int lmm, SUNContext sunctx)
sundials.CVodeCreate.argtypes = [c_int, c_void_p]
sundials.CVodeCreate.restype = c_void_p
sundials.CVodeCreate.errcheck = memory_create_check

# int CVodeInit(void *cvode_mem, CVRhsFn f, realtype t0, N_Vector y0)
sundials.CVodeInit.argtypes = [c_void_p, c_void_p, c_double, c_void_p]
sundials.CVodeInit.restype = c_int

# int CVodeSStolerances(void *cvode_mem, realtype reltol, realtype abstol)
sundials.CVodeSStolerances.argtypes = [c_void_p, c_double, c_double]
sundials.CVodeSStolerances.restype = c_int

# int CVodeSVtolerances(void *cvode_mem, realtype reltol, N_Vector abstol)
sundials.CVodeSVtolerances.argtypes = [c_void_p, c_double, c_void_p]
# Declared explicitly so every int-returning binding is spelled out the same
# way (ctypes would otherwise fall back to its implicit default).
sundials.CVodeSVtolerances.restype = c_int

# int CVodeSetLinearSolver(void *cvode_mem, SUNLinearSolver LS, SUNMatrix A)
sundials.CVodeSetLinearSolver.argtypes = [c_void_p, c_void_p, c_void_p]
sundials.CVodeSetLinearSolver.restype = c_int

# int CVodeSetJacTimes(void *cvode_mem, CVLsJacTimesSetupFn jtsetup,
#                      CVLsJacTimesVecFn jtimes)
# Used by CVODE.__init__ when a Jacobian-vector product is supplied.
sundials.CVodeSetJacTimes.argtypes = [c_void_p, c_void_p, c_void_p]
sundials.CVodeSetJacTimes.restype = c_int

# int CVode(void *cvode_mem, realtype tout, N_Vector yout, realtype *tret,
#           int itask)
sundials.CVode.argtypes = [c_void_p, c_double, c_void_p, c_void_p, c_int]
sundials.CVode.restype = c_int

# int CVodeReInit(void *cvode_mem, realtype t0, N_Vector y0)
sundials.CVodeReInit.argtypes = [c_void_p, c_double, c_void_p]
sundials.CVodeReInit.restype = c_int

# void CVodeFree(void **cvode_mem)
sundials.CVodeFree.argtypes = [c_void_p]
sundials.CVodeFree.restype = None
class CVODE(SundialsSolver):
    """ctypes wrapper around one SUNDIALS CVODE integrator instance.

    The base class ``SundialsSolver`` is not visible in this file; this
    class relies on it to provide ``self.u`` (state N_Vector), ``self.LS``
    (linear solver), ``self.ctx`` (SUNContext), ``self.N`` (number of
    float64 entries in the state), ``self.dtype``, ``to_sundials_function``
    and ``free`` -- NOTE(review): confirm against the base class.
    """

    def __init__(self, ode_type, y0, f, t0, rtol, atol, jac=None, **kwargs):
        """Create and configure the CVODE solver.

        Args:
            ode_type: ``'CV_ADAMS'`` selects the Adams method; any other
                value selects BDF.
            y0: Initial state, forwarded to ``SundialsSolver.__init__``.
            f: Right-hand side; adapted via ``self.to_sundials_function``.
            t0: Initial time passed to ``CVodeInit``.
            rtol: Scalar relative tolerance.
            atol: Absolute tolerance. A scalar selects ``CVodeSStolerances``;
                an array-like with at least ``self.N`` float64 entries
                selects ``CVodeSVtolerances``.
            jac: Optional callable ``jac(t, y, v)`` returning the
                Jacobian-vector product as an array viewable as float64.
            **kwargs: Accepted for interface compatibility; unused here.

        Raises:
            ValueError: If a vector ``atol`` has fewer than ``self.N``
                entries.
        """
        super().__init__(y0)

        # CV_ADAMS == 1 and CV_BDF == 2 in cvode.h.
        lmm = 1 if ode_type == 'CV_ADAMS' else 2
        cvode_mem = c_void_p(sundials.CVodeCreate(lmm, self.ctx))
        self.cvode_mem = cvode_mem

        # CVRhsFn returns int, so the callback is declared with c_int.
        # The callback object must outlive the solver, hence self.CF.
        f_cv = self.to_sundials_function(f)
        self.CF = CFUNCTYPE(c_int, c_double, c_void_p, c_void_p, c_void_p)(f_cv)
        sundials.CVodeInit(cvode_mem, self.CF, t0, self.u)

        if np.ndim(atol) == 0:
            sundials.CVodeSStolerances(cvode_mem, rtol, float(atol))
        else:
            # N_VMake_Serial is handed a raw pointer into this buffer, so it
            # must be contiguous float64 and must stay alive as long as the
            # N_Vector does; keep a reference on the instance.
            atol_data = np.ascontiguousarray(atol, dtype=np.float64)
            if atol_data.size < self.N:
                raise ValueError(
                    "atol must provide at least %d entries, got %d"
                    % (self.N, atol_data.size)
                )
            self._atol_data = atol_data
            atol_v = nvector.N_VMake_Serial(
                self.N, atol_data.ctypes.data_as(POINTER(c_double)), self.ctx
            )
            sundials.CVodeSVtolerances(cvode_mem, rtol, atol_v)
            self.atol_v = atol_v

        # No SUNMatrix is attached (third argument is NULL).
        sundials.CVodeSetLinearSolver(cvode_mem, self.LS, None)

        if jac is not None:
            def jac_c(v, Jv, tt, yy, fy, data, tmp):
                # Adapter with the CVLsJacTimesVecFn argument order:
                # (v, Jv, t, y, fy, user_data, tmp).
                y = self._read_nvector(yy)
                vv = self._read_nvector(v)
                result = jac(tt, y, vv).view(np.float64)
                out = nvector.N_VGetArrayPointer_Serial(Jv)
                for i in range(self.N):
                    out[i] = result[i]
                return 0

            # Wrap the adapter (not the user's callable) and keep it alive.
            self.JF = CFUNCTYPE(
                c_int, c_void_p, c_void_p, c_double,
                c_void_p, c_void_p, c_void_p, c_void_p,
            )(jac_c)
            # Must follow CVodeSetLinearSolver; no setup routine is supplied.
            sundials.CVodeSetJacTimes(cvode_mem, None, self.JF)

    def _read_nvector(self, nv):
        """Copy the first ``self.N`` float64 entries of a serial N_Vector
        into a new NumPy array viewed as ``self.dtype``."""
        data = nvector.N_VGetArrayPointer_Serial(nv)
        return np.fromiter(data, dtype=np.float64, count=self.N).view(self.dtype)

    def solve(self, y0, t0, tf, J=None):
        """Integrate from ``t0`` to ``tf`` starting at ``y0``.

        Args:
            y0: Initial state; its float64 view is copied into ``self.u``.
            t0: Start time passed to ``CVodeReInit``.
            tf: Target time passed to ``CVode``.
            J: Stored on ``self.J``; not otherwise used in this method.

        Returns:
            A new array with the state at the end of the step, viewed as
            ``self.dtype``. On failure, prints a message and returns
            ``y0 * np.nan``.
        """
        self.J = J

        yi = nvector.N_VGetArrayPointer_Serial(self.u)
        yn = y0.view(dtype=np.float64)
        for i in range(self.N):
            yi[i] = yn[i]

        sundials.CVodeReInit(self.cvode_mem, t0, self.u)
        t_out = c_double(0.0)
        try:
            # itask == 1 is CV_NORMAL.
            flag = sundials.CVode(self.cvode_mem, tf, self.u, byref(t_out), 1)
        except Exception:
            # Presumably raised by an errcheck hook installed by the library
            # wrapper -- not visible in this file.
            failed = True
        else:
            # CVODE signals failure with a negative return flag; honour it
            # even when no exception was raised.
            failed = isinstance(flag, int) and flag < 0

        if failed:
            print("CVODE integration failed")
            return y0 * np.nan
        return self._read_nvector(self.u)

    def free(self):
        """Release the CVODE memory via the base-class ``free``.

        Raises:
            AssertionError: If the solver memory was already freed.
        """
        # Explicit raise instead of ``assert`` so the double-free guard is
        # still active under ``python -O``.
        if self.cvode_mem is None:
            raise AssertionError("CVODE memory has already been freed")
        super().free(self.cvode_mem, sundials.CVodeFree)
        self.cvode_mem = None