Skip to content

Commit

Permalink
signature for elementwise if/else, njit for helper funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
DeLaVlag committed Sep 17, 2024
1 parent 003572c commit 983ac21
Showing 1 changed file with 66 additions and 56 deletions.
122 changes: 66 additions & 56 deletions tvb_library/tvb/simulator/models/k_ion_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Giovanni Rabuffo <giovanni.rabuffo@univ-amu.fr>,
Carmela Calabrese <carmela.calabrese@iit.it>,
Jan Fousek <jan.fousek@univ-amu.fr>,
Michiel van der Vlag <m.van.der.vlag@fz-juelich.de>
Under the "NextGen" Research Infrastructure Voucher SC3 associated to the HBP Flagship as a Partnering Project (PP)
Project leader: Simona Olmi <simone.olmi@gmail.com>
Expand All @@ -42,7 +43,7 @@

import numpy

from numba import guvectorize, float64
from numba import guvectorize, float64, njit

class KIonEx(Model):
r"""
Expand Down Expand Up @@ -348,13 +349,65 @@ def V_dot_form(I_Na,I_K,I_Cl,I_pump):

def dfun(self, x, c, local_coupling=0.0):

x_ = x
c_ = c + local_coupling * x[0]
x_ = x.reshape(x.shape[:-1]).T
c_ = c.reshape(c.shape[:-1]).T + local_coupling * x[0]

deriv = _numba_dfun(x_, c_, self.E, self.K_bath, self.J, self.eta, self.Delta, self.c_minus, self.R_minus,
self.c_plus, self.R_plus, self.Vstar, self.Cm, self.tau_n, self.gamma, self.epsilon)
return deriv
return deriv.T[..., numpy.newaxis]
# helper functions

@njit
def m_inf(V):
Cmna = -24.0 # mV
DCmna = 12.0 # mV
return 1.0 / (1.0 + numpy.exp((Cmna - V) / DCmna))

@njit
def n_inf(V):
Cnk = -19.0 # mV
DCnk = 18.0 # mV #Ok in the paper
return 1.0 / (1.0 + numpy.exp((Cnk - V) / DCnk))

@njit
def h(n):
return 1.1 - 1.0 / (1.0 + numpy.exp(-8.0 * (n - 0.4)))

@njit
def I_K_form(V, n, K_o, K_i):
g_K = 22.0 # nS # maximal potassium conductance
g_Kl = 0.12 # nS # potassium leak conductance
return (g_Kl + g_K * n) * (V - 26.64 * numpy.log(K_o / K_i))

@njit
def I_Na_form(V, Na_o, Na_i, n):
g_Na = 40.0 # nS # maximal sodiumconductance
g_Nal = 0.02 # nS # sodium leak conductance
return (g_Nal + g_Na * m_inf(V) * h(n)) * (V - 26.64 * numpy.log(Na_o / Na_i))

@njit
def I_Cl_form(V):
g_Cl = 7.5 # nS #Ok in the paper # chloride conductance
Cl_i0 = 5.0 # mMol/m**3 # initial concentration of intracellular Cl
Cl_o0 = 112.0 # mMol/m**3 # initial concentration of extracellular Cl
return g_Cl * (V + 26.64 * numpy.log(Cl_o0 / Cl_i0))

@njit
def I_pump_form(Na_i, K_o):
rho = 250. # 250.,#pA # maximal Na/K pump current
Cnap = 21.0 # mol.m**-3
DCnap = 2.0 # mol.m**-3
Ckp = 5.5 # mol.m**-3
DCkp = 1.0 # mol.m**-3
return rho * (
1.0 / (1.0 + numpy.exp((Cnap - Na_i) / DCnap)) * (1.0 / (1.0 + numpy.exp((Ckp - K_o) / DCkp))))

# @njit
# def V_dot_form(Cm, I_Na, I_K, I_Cl, I_pump):
# return (-1.0 / Cm) * (I_Na + I_K + I_Cl + I_pump)


@guvectorize([(float64[:],) * 17], '(n),(m)' + ',()' * 14 + '->(n)', nopython=True)
@guvectorize([(float64[:],) * 17], '(n),(m)' + ',()' * 14 + '->(n)', target='parallel', nopython=True)
def _numba_dfun(state_variables, coupling, E, K_bath, J, eta, Delta, c_minus, R_minus, c_plus, R_plus, Vstar, Cm,
tau_n, gamma, epsilon, dx):
r"""
Expand Down Expand Up @@ -387,57 +440,15 @@ def _numba_dfun(state_variables, coupling, E, K_bath, J, eta, Delta, c_minus, R_
Coupling_Term = coupling[0] # This zero refers to the first element of cvar (trivial in this case)

# Constants
Cnap = 21.0 # mol.m**-3
DCnap = 2.0 # mol.m**-3
Ckp = 5.5 # mol.m**-3
DCkp = 1.0 # mol.m**-3
Cmna = -24.0 # mV
DCmna = 12.0 # mV
Chn = 0.4 # dimensionless
DChn = -8.0 # dimensionless
Cnk = -19.0 # mV
DCnk = 18.0 # mV #Ok in the paper
g_Cl = 7.5 # nS #Ok in the paper # chloride conductance
g_Na = 40.0 # nS # maximal sodiumconductance
g_K = 22.0 # nS # maximal potassium conductance
g_Nal = 0.02 # nS # sodium leak conductance
g_Kl = 0.12 # nS # potassium leak conductance
rho = 250. # 250.,#pA # maximal Na/K pump current
# Chn = 0.4 # dimensionless
# DChn = -8.0 # dimensionless
w_i = 2160.0 # umeter**3 # intracellular volume
w_o = 720.0 # umeter**3 # extracellular volume
Na_i0 = 16.0 # mMol/m**3 # initial concentration of intracellular Na
Na_o0 = 138.0 # mMol/m**3 # initial concentration of extracellular Na
K_i0 = 130.0 # mMol/m**3 # initial concentration of intracellular K
K_o0 = 4.80 # mMol/m**3 # initial concentration of extracellular K
Cl_i0 = 5.0 # mMol/m**3 # initial concentration of intracellular Cl
Cl_o0 = 112.0 # mMol/m**3 # initial concentration of extracellular Cl

# helper functions

def m_inf(V):
return 1.0 / (1.0 + numpy.exp((Cmna - V) / DCmna))

def n_inf(V):
return 1.0 / (1.0 + numpy.exp((Cnk - V) / DCnk))

def h(n):
return 1.1 - 1.0 / (1.0 + numpy.exp(-8.0 * (n - 0.4)))

def I_K_form(V, n, K_o, K_i):
return (g_Kl + g_K * n) * (V - 26.64 * numpy.log(K_o / K_i))

def I_Na_form(V, Na_o, Na_i, n):
return (g_Nal + g_Na * m_inf(V) * h(n)) * (V - 26.64 * numpy.log(Na_o / Na_i))

def I_Cl_form(V):
return g_Cl * (V + 26.64 * numpy.log(Cl_o0 / Cl_i0))

def I_pump_form(Na_i, K_o):
return rho * (
1.0 / (1.0 + numpy.exp((Cnap - Na_i) / DCnap)) * (1.0 / (1.0 + numpy.exp((Ckp - K_o) / DCkp))))

def V_dot_form(I_Na, I_K, I_Cl, I_pump):
return (-1.0 / Cm) * (I_Na + I_K + I_Cl + I_pump)

beta = w_i / w_o
DNa_i = -DKi
Expand All @@ -457,14 +468,13 @@ def V_dot_form(I_Na, I_K, I_Cl, I_pump):
r = R_minus[0] * x / numpy.pi
Vdot = (-1.0 / Cm[0]) * (I_Na + I_K + I_Cl + I_pump)

if_xdot = Delta[0] + 2 * R_minus[0] * (V - c_minus[0]) * x - J[0] * r * x
else_xdot = Delta[0] + 2 * R_plus[0] * (V - c_plus[0]) * x - J[0] * r * x

if_Vdot = Vdot - R_minus[0] * x ** 2 + eta[0] + (R_minus[0] / numpy.pi) * Coupling_Term * (E[0] - V)
else_Vdot = Vdot - R_plus[0] * x ** 2 + eta[0] + (R_minus[0] / numpy.pi) * Coupling_Term * (E[0] - V)
if V <= Vstar[0]:
R, c = R_minus[0], c_minus[0]
else:
R, c = R_plus[0], c_plus[0]

dx[0] = numpy.where(V <= (Vstar * numpy.ones_like(V)), if_xdot, else_xdot)[0]
dx[1] = numpy.where(V <= (Vstar * numpy.ones_like(V)), if_Vdot, else_Vdot)[0]
dx[0] = Delta[0] + 2 * R * (V - c) * x - J[0] * r * x
dx[1] = Vdot - R * x ** 2 + eta[0] + (R_minus[0] * numpy.pi) * Coupling_Term * (E[0] - V)
dx[2] = (ninf - n) / tau_n[0]
dx[3] = -(gamma[0] / w_i) * (I_K - 2.0 * I_pump)
dx[4] = epsilon[0] * (K_bath[0] - K_o)

0 comments on commit 983ac21

Please sign in to comment.