Skip to content

Commit

Permalink
New style to define the model (#3)
Browse files Browse the repository at this point in the history
* New style to define the model

* separate ``update()`` function into three functions:

- ``before_integral()``
- ``compute_derivative()``
- ``after_integral()``

* fix bugs and add thalamus single compartment neuron model examples

* update doc
  • Loading branch information
chaoming0625 authored Jul 9, 2024
1 parent 7c5bc9b commit dc3f191
Show file tree
Hide file tree
Showing 15 changed files with 633 additions and 339 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
<p align="center">
<a href="https://pypi.org/project/dendritex/"><img alt="Supported Python Version" src="https://img.shields.io/pypi/pyversions/dendritex"></a>
<a href="https://github.com/chaoming0625/dendritex/blob/main/LICENSE"><img alt="LICENSE" src="https://img.shields.io/badge/License-Apache%202.0-blue.svg"></a>
<a href='https://dendritex.readthedocs.io/en/latest/?badge=latest'>
<img src='https://readthedocs.org/projects/dendritex/badge/?version=latest' alt='Documentation Status' />
<a href='https://dendrite.readthedocs.io/en/latest/?badge=latest'>
<img src='https://readthedocs.org/projects/dendrite/badge/?version=latest' alt='Documentation Status' />
</a>
<a href="https://badge.fury.io/py/dendritex"><img alt="PyPI version" src="https://badge.fury.io/py/dendritex.svg"></a>
<a href="https://github.com/chaoming0625/dendritex/actions/workflows/CI.yml"><img alt="Continuous Integration" src="https://github.com/chaoming0625/dendritex/actions/workflows/CI.yml/badge.svg"></a>
Expand Down Expand Up @@ -46,7 +46,7 @@ The official documentation is hosted on Read the Docs: [https://dendrite.readthe

- [``brainscale``](https://github.com/chaoming0625/brainscale): The scalable online learning framework for biological neural networks.

- [``dendritex``](https://github.com/chaoming0625/dendritex): The dendritic modeling in JAx.
- [``dendritex``](https://github.com/chaoming0625/dendritex): The dendritic modeling in JAX.

- [``braintools``](https://github.com/chaoming0625/braintools): The toolbox for the brain dynamics simulation, training and analysis.

159 changes: 111 additions & 48 deletions dendritex/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from brainstate.mixin import _JointGenericAlias

__all__ = [
'DendriticDynamics',
'State4Integral',
'HHTypedNeuron',
'IonChannel',
'Ion',
Expand All @@ -48,6 +50,20 @@
# - Channel
#

class State4Integral(bst.ShortTermState):
"""
A state that integrates the state of the system to the integral of the state.
Attributes
----------
derivative: The derivative of the state.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.derivative = None


class DendriticDynamics(bst.Dynamics):
"""
Expand Down Expand Up @@ -97,36 +113,20 @@ def __init__(
def current(self, *args, **kwargs):
raise NotImplementedError('Must be implemented by the subclass.')

def before_integral(self, *args, **kwargs):
raise NotImplementedError

def _root_leaf_pair_check(root: type, leaf: 'TreeNode'):
if hasattr(leaf, 'root_type'):
master_type = leaf.root_type
else:
raise ValueError('Child class should define "root_type" to '
'specify the type of the root node. '
f'But we did not found it in {leaf}')
if not issubclass(root, master_type):
raise TypeError(f'Type does not match. {leaf} requires a master with type '
f'of {leaf.root_type}, but the master now is {root}.')
def compute_derivative(self, *args, **kwargs):
raise NotImplementedError('Must be implemented by the subclass.')

def after_integral(self, *args, **kwargs):
raise NotImplementedError

def check_hierarchies(root: type, *leaves, check_fun: Callable = None, **named_leaves):
if check_fun is None:
check_fun = _root_leaf_pair_check
def init_state(self, *args, **kwargs):
raise NotImplementedError

for leaf in leaves:
if isinstance(leaf, bst.Module):
check_fun(root, leaf)
elif isinstance(leaf, (list, tuple)):
check_hierarchies(root, *leaf, check_fun=check_fun)
elif isinstance(leaf, dict):
check_hierarchies(root, **leaf, check_fun=check_fun)
else:
raise ValueError(f'Do not support {type(leaf)}.')
for leaf in named_leaves.values():
if not isinstance(leaf, bst.Module):
raise ValueError(f'Do not support {type(leaf)}. Must be instance of {bst.Module}')
check_fun(root, leaf)
def reset_state(self, *args, **kwargs):
raise NotImplementedError


class Container(bst.mixin.Mixin):
Expand Down Expand Up @@ -182,6 +182,37 @@ def add_elem(self, *elems, **elements):
class TreeNode(bst.mixin.Mixin):
root_type: type

@staticmethod
def _root_leaf_pair_check(root: type, leaf: 'TreeNode'):
if hasattr(leaf, 'root_type'):
root_type = leaf.root_type
else:
raise ValueError('Child class should define "root_type" to '
'specify the type of the root node. '
f'But we did not found it in {leaf}')
if not issubclass(root, root_type):
raise TypeError(f'Type does not match. {leaf} requires a root with type '
f'of {leaf.root_type}, but the root now is {root}.')

@staticmethod
def check_hierarchies(root: type, *leaves, check_fun: Callable = None, **named_leaves):
if check_fun is None:
check_fun = TreeNode._root_leaf_pair_check

for leaf in leaves:
if isinstance(leaf, bst.Module):
check_fun(root, leaf)
elif isinstance(leaf, (list, tuple)):
TreeNode.check_hierarchies(root, *leaf, check_fun=check_fun)
elif isinstance(leaf, dict):
TreeNode.check_hierarchies(root, **leaf, check_fun=check_fun)
else:
raise ValueError(f'Do not support {type(leaf)}.')
for leaf in named_leaves.values():
if not isinstance(leaf, bst.Module):
raise ValueError(f'Do not support {type(leaf)}. Must be instance of {bst.Module}')
check_fun(root, leaf)


class HHTypedNeuron(DendriticDynamics, Container):
"""
Expand All @@ -207,11 +238,28 @@ def add_elem(self, *elems, **elements):
Args:
elements: children objects.
"""
TreeNode.check_hierarchies(type(self), *elems, **elements)
self.ion_channels.update(self._format_elements(object, *elems, **elements))


class IonChannel(DendriticDynamics):
pass
class IonChannel(DendriticDynamics, TreeNode):
def current(self, *args, **kwargs):
raise NotImplementedError

def before_integral(self, *args, **kwargs):
raise NotImplementedError

def compute_derivative(self, *args, **kwargs):
raise NotImplementedError

def after_integral(self, *args, **kwargs):
raise NotImplementedError

def reset_state(self, *args, **kwargs):
raise NotImplementedError

def init_state(self, *args, **kwargs):
raise NotImplementedError


class IonInfo(NamedTuple):
Expand Down Expand Up @@ -249,10 +297,20 @@ def __init__(

self._external_currents: Dict[str, Callable] = dict()

def update(self, V):
def before_integral(self, V):
nodes = self.nodes(level=1, include_self=False).subset(Channel)
for node in nodes.values():
node.before_integral(V, self.pack_info())

def compute_derivative(self, V):
nodes = self.nodes(level=1, include_self=False).subset(Channel)
for node in nodes.values():
node.compute_derivative(V, self.pack_info())

def after_integral(self, V):
nodes = self.nodes(level=1, include_self=False).subset(Channel)
for node in nodes.values():
node.update(V, IonInfo(E=self.E, C=self.C))
node.after_integral(V, self.pack_info())

def current(self, V, include_external: bool = False):
"""
Expand All @@ -266,9 +324,8 @@ def current(self, V, include_external: bool = False):
Current.
"""
nodes = tuple(self.nodes(level=1, include_self=False).subset(Channel).values())
check_hierarchies(type(self), *nodes)

ion_info = IonInfo(E=self.E, C=self.C)
ion_info = self.pack_info()
current = None
if len(nodes) > 0:
for node in nodes:
Expand All @@ -283,21 +340,20 @@ def current(self, V, include_external: bool = False):

def init_state(self, V, batch_size: int = None):
nodes = self.nodes(level=1, include_self=False).subset(Channel).values()
check_hierarchies(type(self), *tuple(nodes))
self.check_hierarchies(type(self), *tuple(nodes))
ion_info = self.pack_info()
for node in nodes:
node: Channel
node.init_state(V, ion_info, batch_size)

def reset_state(self, V, batch_size: int = None):
nodes = self.nodes(level=1, include_self=False).subset(Channel).values()
check_hierarchies(type(self), *tuple(nodes))
ion_info = self.pack_info()
for node in nodes:
node: Channel
node.reset_state(V, ion_info, batch_size)

def add_external_current(self, key: str, fun: Callable):
def register_external_current(self, key: str, fun: Callable):
if key in self._external_currents:
raise ValueError
self._external_currents[key] = fun
Expand All @@ -314,6 +370,7 @@ def add_elem(self, *elems, **elements):
Args:
elements: children objects.
"""
self.check_hierarchies(type(self), *elems, **elements)
self.channels.update(self._format_elements(object, *elems, **elements))


Expand Down Expand Up @@ -350,12 +407,23 @@ def __init__(
self.channels: Dict[str, Channel] = bst.visible_module_dict()
self.channels.update(self._format_elements(Channel, **channels))

def update(self, V):
def before_integral(self, V):
nodes = tuple(self.nodes(level=1, include_self=False).subset(Channel).values())
for node in nodes:
ion_infos = tuple([self._get_ion(ion).pack_info() for ion in node.root_type.__args__])
node.before_integral(V, *ion_infos)

def compute_derivative(self, V):
nodes = tuple(self.nodes(level=1, include_self=False).subset(Channel).values())
for node in nodes:
ion_infos = tuple([self._get_ion(ion).pack_info() for ion in node.root_type.__args__])
node.compute_derivative(V, *ion_infos)

def after_integral(self, V):
nodes = tuple(self.nodes(level=1, include_self=False).subset(Channel).values())
check_hierarchies(self._ion_types, *nodes, check_fun=self._check_hierarchy)
for node in nodes:
ion_infos = tuple([self._get_ion(ion).pack_info() for ion in node.root_type.__args__])
node.update(V, *ion_infos)
node.after_integral(V, *ion_infos)

def current(self, V):
"""Generate ion channel current.
Expand All @@ -367,7 +435,6 @@ def current(self, V):
Current.
"""
nodes = tuple(self.nodes(level=1, include_self=False).subset(Channel).values())
check_hierarchies(self._ion_types, *nodes, check_fun=self._check_hierarchy)

if len(nodes) == 0:
return 0.
Expand All @@ -380,15 +447,14 @@ def current(self, V):

def init_state(self, V, batch_size: int = None):
nodes = self.nodes(level=1, include_self=False).subset(Channel).values()
check_hierarchies(type(self), *tuple(nodes))
self.check_hierarchies(self._ion_types, *tuple(nodes), check_fun=self._check_hierarchy)
for node in nodes:
node: Channel
infos = tuple([self._get_ion(root).pack_info() for root in node.root_type.__args__])
node.reset_state(V, *infos, batch_size)
node.init_state(V, *infos, batch_size)

def reset_state(self, V, batch_size=None):
nodes = tuple(self.nodes(level=1, include_self=False).subset(Channel).values())
check_hierarchies(self._ion_types, *nodes, check_fun=self._check_hierarchy)
for node in nodes:
infos = tuple([self._get_ion(root).pack_info() for root in node.root_type.__args__])
node.reset_state(V, *infos, batch_size)
Expand All @@ -410,13 +476,13 @@ def add_elem(self, *elems, **elements):
Args:
elements: children objects.
"""
check_hierarchies(self._ion_types, *elems, check_fun=self._check_hierarchy, **elements)
self.check_hierarchies(self._ion_types, *elems, check_fun=self._check_hierarchy, **elements)
self.channels.update(self._format_elements(Channel, *elems, **elements))
for elem in tuple(elems) + tuple(elements.values()):
elem: Channel
for ion_root in elem.root_type.__args__:
ion = self._get_ion(ion_root)
ion.add_external_current(elem.name, self._get_ion_fun(ion, elem))
ion.register_external_current(elem.name, self._get_ion_fun(ion, elem))

def _get_ion_fun(self, ion: 'Ion', node: 'Channel'):
def fun(V, ion_info):
Expand Down Expand Up @@ -458,8 +524,5 @@ def mix_ions(*ions) -> MixIons:
return MixIons(*ions)


class Channel(IonChannel, TreeNode):
class Channel(IonChannel):
"""Base class for ion channels."""

def current(self, *args, **kwargs):
raise NotImplementedError('Must be implemented by the subclass.')
Loading

0 comments on commit dc3f191

Please sign in to comment.