From 983ac21a45fa6265636a9bdbf79aa0051e2c3639 Mon Sep 17 00:00:00 2001 From: DeLaVlag Date: Tue, 17 Sep 2024 22:22:58 +0200 Subject: [PATCH] signature for elementwise if/else, njit for helper funcs --- .../tvb/simulator/models/k_ion_exchange.py | 122 ++++++++++-------- 1 file changed, 66 insertions(+), 56 deletions(-) diff --git a/tvb_library/tvb/simulator/models/k_ion_exchange.py b/tvb_library/tvb/simulator/models/k_ion_exchange.py index 3c918547f..3940b4aca 100644 --- a/tvb_library/tvb/simulator/models/k_ion_exchange.py +++ b/tvb_library/tvb/simulator/models/k_ion_exchange.py @@ -31,6 +31,7 @@ Giovanni Rabuffo , Carmela Calabrese , Jan Fousek , + Michiel van der Vlag Under the "NextGen" Research Infrastructure Voucher SC3 associated to the HBP Flagship as a Partnering Project (PP) Project leader: Simona Olmi @@ -42,7 +43,7 @@ import numpy -from numba import guvectorize, float64 +from numba import guvectorize, float64, njit class KIonEx(Model): r""" @@ -348,13 +349,65 @@ def V_dot_form(I_Na,I_K,I_Cl,I_pump): def dfun(self, x, c, local_coupling=0.0): - x_ = x - c_ = c + local_coupling * x[0] + x_ = x.reshape(x.shape[:-1]).T + c_ = c.reshape(c.shape[:-1]).T + local_coupling * x[0] + deriv = _numba_dfun(x_, c_, self.E, self.K_bath, self.J, self.eta, self.Delta, self.c_minus, self.R_minus, self.c_plus, self.R_plus, self.Vstar, self.Cm, self.tau_n, self.gamma, self.epsilon) - return deriv + return deriv.T[..., numpy.newaxis] + # helper functions + +@njit +def m_inf(V): + Cmna = -24.0 # mV + DCmna = 12.0 # mV + return 1.0 / (1.0 + numpy.exp((Cmna - V) / DCmna)) + +@njit +def n_inf(V): + Cnk = -19.0 # mV + DCnk = 18.0 # mV #Ok in the paper + return 1.0 / (1.0 + numpy.exp((Cnk - V) / DCnk)) + +@njit +def h(n): + return 1.1 - 1.0 / (1.0 + numpy.exp(-8.0 * (n - 0.4))) + +@njit +def I_K_form(V, n, K_o, K_i): + g_K = 22.0 # nS # maximal potassium conductance + g_Kl = 0.12 # nS # potassium leak conductance + return (g_Kl + g_K * n) * (V - 26.64 * numpy.log(K_o / K_i)) + +@njit +def I_Na_form(V, Na_o, Na_i, n): + g_Na = 40.0 # nS # maximal sodiumconductance + g_Nal = 0.02 # nS # sodium leak conductance + return (g_Nal + g_Na * m_inf(V) * h(n)) * (V - 26.64 * numpy.log(Na_o / Na_i)) + +@njit +def I_Cl_form(V): + g_Cl = 7.5 # nS #Ok in the paper # chloride conductance + Cl_i0 = 5.0 # mMol/m**3 # initial concentration of intracellular Cl + Cl_o0 = 112.0 # mMol/m**3 # initial concentration of extracellular Cl + return g_Cl * (V + 26.64 * numpy.log(Cl_o0 / Cl_i0)) + +@njit +def I_pump_form(Na_i, K_o): + rho = 250. # 250.,#pA # maximal Na/K pump current + Cnap = 21.0 # mol.m**-3 + DCnap = 2.0 # mol.m**-3 + Ckp = 5.5 # mol.m**-3 + DCkp = 1.0 # mol.m**-3 + return rho * ( + 1.0 / (1.0 + numpy.exp((Cnap - Na_i) / DCnap)) * (1.0 / (1.0 + numpy.exp((Ckp - K_o) / DCkp)))) + +# @njit +# def V_dot_form(Cm, I_Na, I_K, I_Cl, I_pump): +# return (-1.0 / Cm) * (I_Na + I_K + I_Cl + I_pump) + -@guvectorize([(float64[:],) * 17], '(n),(m)' + ',()' * 14 + '->(n)', nopython=True) +@guvectorize([(float64[:],) * 17], '(n),(m)' + ',()' * 14 + '->(n)', target='parallel', nopython=True) def _numba_dfun(state_variables, coupling, E, K_bath, J, eta, Delta, c_minus, R_minus, c_plus, R_plus, Vstar, Cm, tau_n, gamma, epsilon, dx): r""" @@ -387,57 +440,15 @@ def _numba_dfun(state_variables, coupling, E, K_bath, J, eta, Delta, c_minus, R_ Coupling_Term = coupling[0] # This zero refers to the first element of cvar (trivial in this case) # Constants - Cnap = 21.0 # mol.m**-3 - DCnap = 2.0 # mol.m**-3 - Ckp = 5.5 # mol.m**-3 - DCkp = 1.0 # mol.m**-3 - Cmna = -24.0 # mV - DCmna = 12.0 # mV - Chn = 0.4 # dimensionless - DChn = -8.0 # dimensionless - Cnk = -19.0 # mV - DCnk = 18.0 # mV #Ok in the paper - g_Cl = 7.5 # nS #Ok in the paper # chloride conductance - g_Na = 40.0 # nS # maximal sodiumconductance - g_K = 22.0 # nS # maximal potassium conductance - g_Nal = 0.02 # nS # sodium leak conductance - g_Kl = 0.12 # nS # potassium leak conductance - rho = 250. # 250.,#pA # maximal Na/K pump current + # Chn = 0.4 # dimensionless + # DChn = -8.0 # dimensionless w_i = 2160.0 # umeter**3 # intracellular volume w_o = 720.0 # umeter**3 # extracellular volume Na_i0 = 16.0 # mMol/m**3 # initial concentration of intracellular Na Na_o0 = 138.0 # mMol/m**3 # initial concentration of extracellular Na K_i0 = 130.0 # mMol/m**3 # initial concentration of intracellular K K_o0 = 4.80 # mMol/m**3 # initial concentration of extracellular K - Cl_i0 = 5.0 # mMol/m**3 # initial concentration of intracellular Cl - Cl_o0 = 112.0 # mMol/m**3 # initial concentration of extracellular Cl - - # helper functions - - def m_inf(V): - return 1.0 / (1.0 + numpy.exp((Cmna - V) / DCmna)) - - def n_inf(V): - return 1.0 / (1.0 + numpy.exp((Cnk - V) / DCnk)) - def h(n): - return 1.1 - 1.0 / (1.0 + numpy.exp(-8.0 * (n - 0.4))) - - def I_K_form(V, n, K_o, K_i): - return (g_Kl + g_K * n) * (V - 26.64 * numpy.log(K_o / K_i)) - - def I_Na_form(V, Na_o, Na_i, n): - return (g_Nal + g_Na * m_inf(V) * h(n)) * (V - 26.64 * numpy.log(Na_o / Na_i)) - - def I_Cl_form(V): - return g_Cl * (V + 26.64 * numpy.log(Cl_o0 / Cl_i0)) - - def I_pump_form(Na_i, K_o): - return rho * ( - 1.0 / (1.0 + numpy.exp((Cnap - Na_i) / DCnap)) * (1.0 / (1.0 + numpy.exp((Ckp - K_o) / DCkp)))) - - def V_dot_form(I_Na, I_K, I_Cl, I_pump): - return (-1.0 / Cm) * (I_Na + I_K + I_Cl + I_pump) beta = w_i / w_o DNa_i = -DKi @@ -457,14 +468,13 @@ def V_dot_form(I_Na, I_K, I_Cl, I_pump): r = R_minus[0] * x / numpy.pi Vdot = (-1.0 / Cm[0]) * (I_Na + I_K + I_Cl + I_pump) - if_xdot = Delta[0] + 2 * R_minus[0] * (V - c_minus[0]) * x - J[0] * r * x - else_xdot = Delta[0] + 2 * R_plus[0] * (V - c_plus[0]) * x - J[0] * r * x - - if_Vdot = Vdot - R_minus[0] * x ** 2 + eta[0] + (R_minus[0] / numpy.pi) * Coupling_Term * (E[0] - V) - else_Vdot = Vdot - R_plus[0] * x ** 2 + eta[0] + (R_minus[0] / numpy.pi) * Coupling_Term * (E[0] - V) + if V <= Vstar[0]: + R, c = R_minus[0], c_minus[0] + else: + R, c = R_plus[0], c_plus[0] - dx[0] = numpy.where(V <= (Vstar * numpy.ones_like(V)), if_xdot, else_xdot)[0] - dx[1] = numpy.where(V <= (Vstar * numpy.ones_like(V)), if_Vdot, else_Vdot)[0] + dx[0] = Delta[0] + 2 * R * (V - c) * x - J[0] * r * x + dx[1] = Vdot - R * x ** 2 + eta[0] + (R_minus[0] * numpy.pi) * Coupling_Term * (E[0] - V) dx[2] = (ninf - n) / tau_n[0] dx[3] = -(gamma[0] / w_i) * (I_K - 2.0 * I_pump) dx[4] = epsilon[0] * (K_bath[0] - K_o)