Skip to content

Commit

Permalink
Adding documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
ConnorStoneAstro committed Oct 17, 2024
1 parent 797eb54 commit 47039c6
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 13 deletions.
17 changes: 12 additions & 5 deletions src/caskade/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Optional
from typing import Optional, Union

import torch


class Node:
class Node(object):
"""
Base graph node class for caskade objects.
Expand Down Expand Up @@ -51,7 +51,8 @@ def children(self) -> dict:
def parents(self) -> set:
return self._parents

def link(self, key, child):
def link(self, key: str, child: "Node"):
"""Link the current `Node` object to another `Node` object as a child."""
# Avoid double linking to the same object
if key in self.children:
raise ValueError(f"Child key {key} already linked to parent {self.name}")
Expand All @@ -63,7 +64,8 @@ def link(self, key, child):
child._parents.add(self)
self.update_dynamic_params()

def unlink(self, key):
def unlink(self, key: Union[str, "Node"]):
"""Unlink a child `Node` from the current `Node`. `key` may be either the
child's key string or the child `Node` object itself."""
if isinstance(key, Node):
for node in self.children:
if self.children[node] == key:
Expand All @@ -74,7 +76,8 @@ def unlink(self, key):
del self._children[key]
self.update_dynamic_params()

def topological_ordering(self, with_type=None) -> tuple:
def topological_ordering(self, with_type: Optional[str] = None) -> tuple:
"""Return a topological ordering of the graph below the current node."""
ordering = [self]
for node in self.children.values():
for subnode in node.topological_ordering():
Expand All @@ -85,6 +88,8 @@ def topological_ordering(self, with_type=None) -> tuple:
return tuple(filter(lambda n: n._type == with_type, ordering))

def update_dynamic_params(self):
"""Propagate a parameter-update notification up through the parent nodes.
The base implementation only notifies parents; subclasses override this
to also refresh their own cached dynamic parameters."""
for parent in self.parents:
parent.update_dynamic_params()

Expand Down Expand Up @@ -121,6 +126,8 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]
child.to(device=device, dtype=dtype)

def graph_dict(self) -> dict:
"""Return a dictionary representation of the graph below the current
node."""
graph = {
f"{self.name}|{self._type}": {},
}
Expand Down
9 changes: 5 additions & 4 deletions src/caskade/context.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Union, Mapping, Sequence

from torch import Tensor

from .module import Module


class ActiveContext:
"""
Context manager to activate a module for a simulation. Only inside an
ActiveContext is it possible to fill/clear the dynamic and live parameters.
"""

def __init__(self, module: Module):
self.module = module

Expand Down
1 change: 1 addition & 0 deletions src/caskade/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def forward(method):
for arg in inspect.signature(method).parameters.values():
if arg.default is not arg.empty:
method_kwargs.append(arg.name)
method_kwargs = tuple(method_kwargs)

@functools.wraps(method)
def wrapped(self, *args, **kwargs):
Expand Down
38 changes: 34 additions & 4 deletions src/caskade/module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence, Mapping, Optional
from typing import Sequence, Mapping, Optional, Union, Any
from math import prod

from torch import Tensor
Expand Down Expand Up @@ -71,11 +71,33 @@ def batch(self, value):
self._batch = value

def update_dynamic_params(self):
"""Maintain a tuple of dynamic and live parameters at all points lower
in the DAG."""
super().update_dynamic_params()
self.dynamic_params = tuple(self.topological_ordering("dynamic"))
self.live_params = tuple(self.topological_ordering("live"))

def fill_params(self, params):
def fill_params(self, params: Union[Tensor, Sequence, Mapping]):
"""
Fill the dynamic parameters of the module with the input values from
params.
Parameters
----------
params: (Union[Tensor, Sequence, Mapping])
The input values to fill the dynamic parameters with. The input can
be a Tensor, a Sequence, or a Mapping. If the input is a Tensor, the
values are filled in order of the dynamic parameters. `params`
should be a flattened tensor with all parameters concatenated in the
order of the dynamic parameters. If `self.batch` is `True` then all
dimensions but the last one are considered batch dimensions. If the
input is a Sequence, the values are filled in order of the dynamic
parameters. If the input is a Mapping, the values are filled by
matching the keys of the Mapping to the names of the dynamic
parameters. Note that missing keys are not checked up front; if a value
is absent, an error will be raised later when that parameter is
actually needed.
"""
assert self.active, "Module must be active to fill params"

if isinstance(params, Tensor):
Expand Down Expand Up @@ -119,6 +141,9 @@ def fill_params(self, params):
)

def clear_params(self):
"""Set all dynamic parameters to None and live parameters to LiveParam.
This is to be used on exiting an `ActiveContext` and so should not be
used by a user."""
assert self.active, "Module must be active to clear params"

for param in self.dynamic_params:
Expand All @@ -127,10 +152,15 @@ def clear_params(self):
for param in self.live_params:
param.value = LiveParam

def fill_kwargs(self, keys) -> dict[str, Tensor]:
def fill_kwargs(self, keys: tuple[str]) -> dict[str, Tensor]:
"""
Fill the kwargs for an `@forward` method with the values of the dynamic
parameters. The requested keys are matched to names of `Param` objects
owned by the `Module`.
"""
return {key: getattr(self, key).value for key in keys}

def __setattr__(self, key, value):
def __setattr__(self, key: str, value: Any):
try:
if key in self.children and isinstance(self.children[key], Param):
self.children[key].value = value
Expand Down

0 comments on commit 47039c6

Please sign in to comment.