Merge pull request #3 from ConnorStoneAstro/metadata
feat: valid, cyclic, and units added with tests
ConnorStoneAstro authored Oct 23, 2024
2 parents a9e8341 + f59d883 commit bdd76b5
Showing 7 changed files with 573 additions and 220 deletions.
379 changes: 190 additions & 189 deletions docs/source/notebooks/BeginnersGuide.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/caskade/__init__.py
@@ -1,7 +1,7 @@
from ._version import version as VERSION # noqa

from .base import Node
-from .context import ActiveContext
+from .context import ActiveContext, ValidContext
from .decorators import forward
from .module import Module
from .param import Param
@@ -11,4 +11,4 @@
__version__ = VERSION
__author__ = "Connor and Alexandre"

__all__ = ("Node", "Module", "Param", "ActiveContext", "forward", "test")
__all__ = ("Node", "Module", "Param", "ActiveContext", "ValidContext", "forward", "test")
16 changes: 16 additions & 0 deletions src/caskade/context.py
@@ -16,3 +16,19 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, traceback):
self.module.clear_params()
self.module.active = False


class ValidContext:
"""
Context manager to set valid values for parameters. Only inside a
ValidContext will parameters automatically be assumed valid.
"""

def __init__(self, module: Module):
self.module = module

def __enter__(self):
self.module.valid_context = True

def __exit__(self, exc_type, exc_value, traceback):
self.module.valid_context = False
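
As an illustration (a sketch, not part of this diff): how the new context manager might be used, assuming a hypothetical Sim module and that @forward methods accept a params tensor as in the rest of caskade:

import torch
from caskade import Module, Param, ValidContext, forward

class Sim(Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__("sim")
        self.x = Param("x", valid=(0.0, 1.0))  # dynamic, bounded to (0, 1)

    @forward
    def scaled(self, x):
        return 2 * x

sim = Sim()
with ValidContext(sim):
    # inputs are read in the unconstrained "valid" representation and
    # converted with from_valid on entry; 0.0 maps to the midpoint 0.5
    out = sim.scaled(torch.tensor([0.0]))
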
144 changes: 125 additions & 19 deletions src/caskade/module.py
@@ -2,6 +2,7 @@
from math import prod

from torch import Tensor
import torch

from .base import Node
from .param import Param
@@ -58,13 +59,24 @@ def __init__(self, name: Optional[str] = None):
self.dynamic_params = ()
self.pointer_params = ()
self._type = "module"
self.valid_context = False

def update_graph(self):
"""Maintain a tuple of dynamic and live parameters at all points lower
in the DAG."""
super().update_graph()
self.dynamic_params = tuple(self.topological_ordering("dynamic"))
self.pointer_params = tuple(self.topological_ordering("pointer"))
self._n_dynamic_children = sum(1 for child in self.children.values() if child.dynamic)

@property
def dynamic(self):
"""Return True if the module has dynamic parameters"""
return self.dynamic_params != ()

@property
def n_dynamic_children(self):
    """Number of direct children (Param or Module) which are dynamic."""
    return self._n_dynamic_children

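For illustration (not part of this diff), the two properties distinguish "any dynamic parameter below this module" from "dynamic direct children"; a hypothetical two-parameter module, assuming children attach by attribute assignment as elsewhere in caskade:

from caskade import Module, Param

class Pair(Module):  # hypothetical
    def __init__(self):
        super().__init__("pair")
        self.a = Param("a")       # no value -> dynamic
        self.b = Param("b", 1.0)  # static

pair = Pair()
assert pair.dynamic                  # pair.dynamic_params is non-empty
assert pair.n_dynamic_children == 1  # only `a` among direct children
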
def fill_params(self, params: Union[Tensor, Sequence, Mapping]):
"""
@@ -89,6 +101,9 @@ def fill_params(self, params: Union[Tensor, Sequence, Mapping]):
"""
assert self.active, "Module must be active to fill params"

if self.valid_context:
params = self.from_valid(params)

if isinstance(params, Tensor):
# check for batch dimension
batch = len(params.shape) > 1
@@ -120,32 +135,27 @@ def fill_params(self, params: Union[Tensor, Sequence, Mapping]):
             if len(params) == len(self.dynamic_params):
                 for param, value in zip(self.dynamic_params, params):
                     param._value = value
-            elif len(params) <= len(self.children):
+            elif len(params) == self.n_dynamic_children:
                 i = 0
-                keys = list(self.children.keys())
-                try:
-                    for param in params:
-                        while True:  # find next dynamic param, or Module
-                            if isinstance(self[keys[i]], Param) and self[keys[i]].dynamic:
-                                self[keys[i]]._value = param
-                                i += 1
-                                break
-                            elif isinstance(self[keys[i]], Module):
-                                self[keys[i]].fill_params(param)
-                                i += 1
-                                break
-                except Exception as e:
-                    raise RuntimeError(
-                        f"Error filling params: {e}. Filling params with a list-by-children, rather than one element per dynamic parameter is tricky, consider using alternate format."
-                    )
+                keys = tuple(self.children.keys())
+                for value in params:
+                    while i < len(keys):  # find next dynamic param, or Module
+                        if isinstance(self[keys[i]], Param) and self[keys[i]].dynamic:
+                            self[keys[i]]._value = value
+                            i += 1
+                            break
+                        elif isinstance(self[keys[i]], Module) and self[keys[i]].dynamic:
+                            self[keys[i]].fill_params(value)
+                            i += 1
+                            break
+                        i += 1
             else:
                 raise AssertionError(
-                    f"Input params length ({len(params)}) does not match dynamic params length ({len(self.dynamic_params)})"
+                    f"Input params length ({len(params)}) does not match dynamic params length ({len(self.dynamic_params)}) or number of dynamic children ({self.n_dynamic_children})"
                 )
         elif isinstance(params, Mapping):
             for key in params:
-                if key in self.children:
+                if key in self.children and self[key].dynamic:
                     if isinstance(self[key], Param):
                         self[key]._value = params[key]
                     else:  # assumed Module
@@ -178,6 +188,102 @@ def fill_kwargs(self, keys: tuple[str]) -> dict[str, Tensor]:
kwargs[key] = self[key].value
return kwargs

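For reference (a sketch, not part of this diff), the fill_params logic above accepts three formats; for a hypothetical active module m (e.g. inside an ActiveContext) with two scalar dynamic parameters a and b:

import torch

m.fill_params(torch.tensor([1.0, 2.0]))                          # flat tensor, packed in DAG order
m.fill_params([torch.tensor(1.0), torch.tensor(2.0)])            # one entry per dynamic param
m.fill_params({"a": torch.tensor(1.0), "b": torch.tensor(2.0)})  # mapping by child name
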
def to_valid(self, params: Union[Tensor, Sequence, Mapping]):
"""Convert input params to valid params."""
if isinstance(params, Tensor):
valid_params = torch.zeros_like(params)
batch = len(params.shape) > 1
if batch:
*B, _ = params.shape
pos = 0
for param in self.dynamic_params:
size = max(1, prod(param.shape)) # Handle scalar parameters
get_shape = tuple(B) + param.shape if batch else param.shape
return_shape = params[..., pos : pos + size].shape
valid_params[..., pos : pos + size] = param.to_valid(
params[..., pos : pos + size].view(get_shape)
).view(return_shape)
pos += size
elif isinstance(params, Sequence):
valid_params = []
if len(params) == len(self.dynamic_params):
for param, value in zip(self.dynamic_params, params):
valid_params.append(param.to_valid(value))
elif len(params) == self.n_dynamic_children:
i = 0
keys = tuple(self.children.keys())
for value in params:
while i < len(keys): # find next dynamic param, or Module
if self[keys[i]].dynamic:
valid_params.append(self[keys[i]].to_valid(value))
i += 1
break
i += 1
else:
raise AssertionError(
f"Input params length ({len(valid_params)}) does not match dynamic params length ({len(self.dynamic_params)}) or number of dynamic children ({len(self.children)})"
)
elif isinstance(params, Mapping):
valid_params = {}
for key in params:
if key in self.children:
valid_params[key] = self[key].to_valid(params[key])
else:
raise ValueError(f"Key {key} not found in {self.name} children")
else:
raise ValueError(
f"Input params type {type(params)} not supported. Should be Tensor, Sequence, or Mapping."
)
return valid_params

def from_valid(self, valid_params: Union[Tensor, Sequence, Mapping]):
"""Convert valid params to input params."""
if isinstance(valid_params, Tensor):
params = torch.zeros_like(valid_params)
batch = len(valid_params.shape) > 1
if batch:
*B, _ = valid_params.shape
pos = 0
for param in self.dynamic_params:
size = max(1, prod(param.shape))
get_shape = tuple(B) + param.shape if batch else param.shape
return_shape = valid_params[..., pos : pos + size].shape
params[..., pos : pos + size] = param.from_valid(
valid_params[..., pos : pos + size].view(get_shape)
).view(return_shape)
pos += size
elif isinstance(valid_params, Sequence):
params = []
if len(valid_params) == len(self.dynamic_params):
for param, value in zip(self.dynamic_params, valid_params):
params.append(param.from_valid(value))
elif len(valid_params) == self.n_dynamic_children:
i = 0
keys = tuple(self.children.keys())
for value in valid_params:
while i < len(keys): # find next dynamic param, or Module
if self[keys[i]].dynamic:
params.append(self[keys[i]].from_valid(value))
i += 1
break
i += 1
else:
raise AssertionError(
f"Input params length ({len(valid_params)}) does not match dynamic params length ({len(self.dynamic_params)}) or number of dynamic children ({len(self.children)})"
)
elif isinstance(valid_params, Mapping):
params = {}
for key in valid_params:
if key in self.children:
params[key] = self[key].from_valid(valid_params[key])
else:
raise ValueError(f"Key {key} not found in {self.name} children")
else:
raise ValueError(
f"Input params type {type(valid_params)} not supported. Should be Tensor, Sequence or Mapping."
)
return params

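As a sanity check (sketch, not part of this diff), the two methods above are inverses; assuming m holds a single dynamic parameter with valid=(0.0, 1.0):

import torch

x = torch.tensor([0.25])
v = m.to_valid(x)         # tan-based map into the unconstrained space
x_back = m.from_valid(v)  # inverse map
assert torch.allclose(x, x_back)
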
@property
def _name(self) -> str:
return self.__name
111 changes: 111 additions & 0 deletions src/caskade/param.py
@@ -2,6 +2,7 @@

import torch
from torch import Tensor
from torch import pi

from .base import Node

@@ -38,6 +39,12 @@ class Param(Node):
The value of the parameter. Defaults to None meaning dynamic.
shape: (Optional[tuple[int, ...]], optional)
The shape of the parameter. Defaults to () meaning scalar.
cyclic: (bool, optional)
Whether the parameter is cyclic, such as a rotation from 0 to 2pi. A cyclic parameter must also be given a finite valid range. Defaults to False.
valid: (Optional[tuple[Union[Tensor, float, int, None]]], optional)
The valid range of the parameter. Defaults to None meaning all of -inf to inf is valid.
units: (Optional[str], optional)
The units of the parameter. Defaults to None.
"""

graphviz_types = {
@@ -51,6 +58,9 @@ def __init__(
name: str,
value: Optional[Union[Tensor, float, int]] = None,
shape: Optional[tuple[int, ...]] = (),
cyclic: bool = False,
valid: Optional[tuple[Union[Tensor, float, int, None]]] = None,
units: Optional[str] = None,
):
super().__init__(name=name)
if value is None:
Expand All @@ -65,6 +75,9 @@ def __init__(
shape == () or shape is None or shape == value.shape
), f"Shape {shape} does not match value shape {value.shape}"
self.value = value
self.cyclic = cyclic
self.valid = valid
self.units = units

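The new metadata in use (illustrative values, not part of this diff):

from math import pi
from caskade import Param

angle = Param("angle", cyclic=True, valid=(0.0, 2 * pi), units="rad")  # wraps modulo 2*pi
flux = Param("flux", valid=(0.0, None), units="Jy")                    # left-bounded only
mass = Param("mass", 1.5, units="M_sun")                               # static, units are annotation only
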
@property
def dynamic(self) -> bool:
@@ -144,5 +157,103 @@ def to(self, device=None, dtype=None):
super().to(device=device, dtype=dtype)
if self.static:
self._value = self._value.to(device=device, dtype=dtype)
if self.valid[0] is not None:
self.valid = (self.valid[0].to(device=device, dtype=dtype), self.valid[1])
if self.valid[1] is not None:
self.valid = (self.valid[0], self.valid[1].to(device=device, dtype=dtype))

return self

@property
def cyclic(self):
return self._cyclic

@cyclic.setter
def cyclic(self, cyclic: bool):
self._cyclic = cyclic
try:
self.valid = self._valid
except AttributeError:
pass

@property
def valid(self):
return self._valid

@valid.setter
def valid(self, valid: tuple[Union[Tensor, float, int, None]]):
if valid is None:
valid = (None, None)

assert isinstance(valid, tuple) and len(valid) == 2, "Valid must be a tuple of length 2"

if valid == (None, None):
assert not self.cyclic, "Cannot set valid to None for cyclic parameter"
self.to_valid = self._to_valid_base
self.from_valid = self._from_valid_base
elif valid[0] is None:
assert not self.cyclic, "Cannot set left valid to None for cyclic parameter"
self.to_valid = self._to_valid_rightvalid
self.from_valid = self._from_valid_rightvalid
valid = (None, torch.as_tensor(valid[1]))
elif valid[1] is None:
assert not self.cyclic, "Cannot set right valid to None for cyclic parameter"
self.to_valid = self._to_valid_leftvalid
self.from_valid = self._from_valid_leftvalid
valid = (torch.as_tensor(valid[0]), None)
else:
if self.cyclic:
self.to_valid = self._to_valid_cyclic
self.from_valid = self._from_valid_cyclic
else:
self.to_valid = self._to_valid_fullvalid
self.from_valid = self._from_valid_fullvalid
valid = (torch.as_tensor(valid[0]), torch.as_tensor(valid[1]))

self._valid = valid

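The setter above dispatches to one of the transform pairs below depending on the bound configuration (summary with hypothetical parameters, not part of this diff):

from caskade import Param

Param("a")                                 # (None, None): identity pass-through
Param("b", valid=(0.0, None))              # left bound only: reciprocal-shift map
Param("c", valid=(None, 1.0))              # right bound only: mirrored reciprocal-shift
Param("d", valid=(0.0, 1.0))               # both bounds: tan/atan map
Param("e", cyclic=True, valid=(0.0, 1.0))  # both bounds, cyclic: modular wrap
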
def _to_valid_base(self, value):
if self.pointer:
raise ValueError("Cannot apply valid transformation to pointer parameter")
return value

def _to_valid_fullvalid(self, value):
value = self._to_valid_base(value)
return torch.tan((value - self.valid[0]) * pi / (self.valid[1] - self.valid[0]) - pi / 2)

def _to_valid_cyclic(self, value):
value = self._to_valid_base(value)
return (value - self.valid[0]) % (self.valid[1] - self.valid[0]) + self.valid[0]

def _to_valid_leftvalid(self, value):
value = self._to_valid_base(value)
return value - 1.0 / (value - self.valid[0])

def _to_valid_rightvalid(self, value):
value = self._to_valid_base(value)
return value + 1.0 / (self.valid[1] - value)

def _from_valid_base(self, value):
if self.pointer:
raise ValueError("Cannot apply valid transformation to pointer parameter")
return value

def _from_valid_fullvalid(self, value):
value = self._from_valid_base(value)
value = (torch.atan(value) + pi / 2) * (self.valid[1] - self.valid[0]) / pi + self.valid[0]
return value

def _from_valid_cyclic(self, value):
value = self._from_valid_base(value)
value = (value - self.valid[0]) % (self.valid[1] - self.valid[0]) + self.valid[0]
return value

def _from_valid_leftvalid(self, value):
value = self._from_valid_base(value)
value = (value + self.valid[0] + ((value - self.valid[0]) ** 2 + 4).sqrt()) / 2
return value

def _from_valid_rightvalid(self, value):
value = self._from_valid_base(value)
value = (value + self.valid[1] - ((value - self.valid[1]) ** 2 + 4).sqrt()) / 2
return value
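
A quick numeric check of the bounded and cyclic maps above (values assumed, not part of this diff):

import torch
from caskade import Param

p = Param("p", valid=(0.0, 1.0))
p.to_valid(torch.tensor(0.5))     # -> tan(0) = 0: the midpoint maps to zero
p.from_valid(torch.tensor(0.0))   # -> 0.5

q = Param("q", cyclic=True, valid=(0.0, 1.0))
q.from_valid(torch.tensor(1.25))  # -> 0.25, wrapped back into [0, 1)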