Skip to content

Commit

Permalink
Merge pull request #5 from ConnorStoneAstro/adderrors
Browse files Browse the repository at this point in the history
Add custom errors to caskade
  • Loading branch information
ConnorStoneAstro authored Oct 26, 2024
2 parents 0a02276 + dd2ce64 commit 136b6cb
Show file tree
Hide file tree
Showing 12 changed files with 691 additions and 442 deletions.
708 changes: 354 additions & 354 deletions docs/source/notebooks/BeginnersGuide.ipynb

Large diffs are not rendered by default.

35 changes: 34 additions & 1 deletion src/caskade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,42 @@
from .module import Module
from .param import Param
from .tests import test
from .errors import (
CaskadeException,
GraphError,
NodeConfigurationError,
ParamConfigurationError,
ParamTypeError,
ActiveStateError,
FillDynamicParamsError,
FillDynamicParamsTensorError,
FillDynamicParamsSequenceError,
FillDynamicParamsMappingError,
)
from .warnings import CaskadeWarning, InvalidValueWarning


__version__ = VERSION
__author__ = "Connor Stone and Alexandre Adam"

__all__ = ("Node", "Module", "Param", "ActiveContext", "ValidContext", "forward", "test")
__all__ = (
"Node",
"Module",
"Param",
"ActiveContext",
"ValidContext",
"forward",
"test",
"CaskadeException",
"GraphError",
"NodeConfigurationError",
"ParamConfigurationError",
"ParamTypeError",
"ActiveStateError",
"FillDynamicParamsError",
"FillDynamicParamsTensorError",
"FillDynamicParamsSequenceError",
"FillDynamicParamsMappingError",
"CaskadeWarning",
"InvalidValueWarning",
)
16 changes: 10 additions & 6 deletions src/caskade/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Optional, Union

from .errors import GraphError, NodeConfigurationError


class Node(object):
"""
Expand Down Expand Up @@ -30,8 +32,10 @@ class Node(object):
def __init__(self, name: Optional[str] = None):
if name is None:
name = self.__class__.__name__
assert isinstance(name, str), f"{self.__class__.__name__} name must be a string"
assert "|" not in name, f"{self.__class__.__name__} cannot contain '|'"
if not isinstance(name, str):
raise NodeConfigurationError(f"{self.__class__.__name__} name must be a string")
if "|" in name:
raise NodeConfigurationError(f"{self.__class__.__name__} cannot contain '|'")
self._name = name
self._children = {}
self._parents = set()
Expand Down Expand Up @@ -80,12 +84,12 @@ def link(self, key: Union[str, "Node"], child: Optional["Node"] = None):
key = child.name
# Avoid double linking to the same object
if key in self.children:
raise ValueError(f"Child key {key} already linked to parent {self.name}")
raise GraphError(f"Child key {key} already linked to parent {self.name}")
if child in self.children.values():
raise ValueError(f"Child {child.name} already linked to parent {self.name}")
raise GraphError(f"Child {child.name} already linked to parent {self.name}")
# avoid cycles
if self in child.topological_ordering():
raise ValueError(
raise GraphError(
f"Linking {child.name} to {self.name} would create a cycle in the graph"
)

Expand Down Expand Up @@ -175,7 +179,7 @@ def add_node(node, dot):
if node in components:
return
dot.attr("node", **node.graphviz_types[node._type])
dot.node(str(id(node)), f"{node.__class__.__name__}('{node.name}')")
dot.node(str(id(node)), repr(node))
components.add(node)

for child in node.children.values():
Expand Down
2 changes: 1 addition & 1 deletion src/caskade/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def wrapped(self, *args, **kwargs):
args = args[:-1]
else:
raise ValueError(
f"Params must be provided for dynamic modules. Expected {len(self.dynamic_params)} params."
f"Params must be provided for a top level @forward method. Either by keyword 'method(params=params)' or as the last positional argument 'method(a, b, c, params)'"
)

with ActiveContext(self):
Expand Down
93 changes: 93 additions & 0 deletions src/caskade/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from math import prod
from textwrap import dedent


class CaskadeException(Exception):
"""Base class for all exceptions in Caskade."""


class GraphError(CaskadeException):
"""Class for graph exceptions in Caskade."""


class NodeConfigurationError(CaskadeException):
"""Class for node configuration exceptions in Caskade."""


class ParamConfigurationError(NodeConfigurationError):
"""Class for parameter configuration exceptions in Caskade."""


class ParamTypeError(CaskadeException):
"""Class for exceptions related to the type of a parameter in Caskade."""


class ActiveStateError(CaskadeException):
"""Class for exceptions related to the active state of a node in Caskade."""


class FillDynamicParamsError(CaskadeException):
"""Class for exceptions related to filling dynamic parameters in Caskade."""


class FillDynamicParamsTensorError(FillDynamicParamsError):
def __init__(self, name, input_params, dynamic_params):
fullnumel = sum(max(1, prod(p.shape)) for p in dynamic_params)
message = dedent(
f"""
For flattened Tensor input, the (last) dim of the Tensor should
equal the sum of all flattened dynamic params ({fullnumel}).
Input params shape {input_params.shape} does not match dynamic
params shape of: {name}.
Registered dynamic params (name: shape):
{', '.join(f"{repr(p)}: {str(p.shape)}" for p in dynamic_params)}"""
)
super().__init__(message)


class FillDynamicParamsSequenceError(FillDynamicParamsError):
def __init__(self, name, input_params, dynamic_params, dynamic_modules):
message = dedent(
f"""
Input params length ({len(input_params)}) does not match dynamic
params length ({len(dynamic_params)}) or number of dynamic
modules ({len(dynamic_modules)}) of: {name}.
Registered dynamic modules:
{', '.join(repr(m) for m in dynamic_modules)}
Registered dynamic params:
{', '.join(repr(p) for p in dynamic_params)}"""
)
super().__init__(message)


class FillDynamicParamsMappingError(FillDynamicParamsError):
def __init__(self, name, children, dynamic_modules, missing_key=None, missing_param=None):
if missing_key is not None:
message = dedent(
f"""
Input params key "{missing_key}" not found in dynamic modules or children of: {name}.
Registered dynamic modules:
{', '.join(repr(m) for m in dynamic_modules)}
Registered dynamic children:
{', '.join(repr(c) for c in children.values() if c.dynamic)}"""
)
else:
message = dedent(
f"""
Dynamic param "{missing_param.name}" not filled with given input params dict passed to {name}.
Dynamic param parent(s):
{', '.join(repr(p) for p in missing_param.parents)}
Registered dynamic modules:
{', '.join(repr(m) for m in dynamic_modules)}
Registered dynamic children:
{', '.join(repr(c) for c in children.values() if c.dynamic)}"""
)
super().__init__(message)
64 changes: 37 additions & 27 deletions src/caskade/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@

from .base import Node
from .param import Param
from .errors import (
ActiveStateError,
ParamConfigurationError,
FillDynamicParamsTensorError,
FillDynamicParamsSequenceError,
FillDynamicParamsMappingError,
)


class Module(Node):
Expand Down Expand Up @@ -100,7 +107,8 @@ def fill_params(self, params: Union[Tensor, Sequence, Mapping], local=False):
the dictionary, but you will get an error eventually if a value is
missing.
"""
assert self.active, "Module must be active to fill params"
if not self.active:
raise ActiveStateError("Module must be active to fill params")

if self.valid_context and not local:
params = self.from_valid(params)
Expand All @@ -114,24 +122,19 @@ def fill_params(self, params: Union[Tensor, Sequence, Mapping], local=False):
pos = 0
for param in dynamic_params:
if not isinstance(param.shape, tuple):
raise ValueError(
raise ParamConfigurationError(
f"Param {param.name} has no shape. dynamic parameters must have a shape to use Tensor input."
)
# Handle scalar parameters
size = max(1, prod(param.shape))
try:
param._value = params[..., pos : pos + size].view(B + param.shape)
except (RuntimeError, IndexError):
fullnumel = sum(max(1, prod(p.shape)) for p in dynamic_params)
raise AssertionError(
f"Input params shape {params.shape} does not match dynamic params shape of {self.name}. Make sure the last dimension has size equal to the sum of all dynamic params sizes ({fullnumel})."
)
raise FillDynamicParamsTensorError(self.name, params, dynamic_params)

pos += size
if pos != params.shape[-1]:
fullnumel = sum(max(1, prod(p.shape)) for p in dynamic_params)
raise AssertionError(
f"Input params length {params.shape} does not match dynamic params length ({fullnumel}) of {self.name}. Not all dynamic params were filled."
)
raise FillDynamicParamsTensorError(self.name, params, dynamic_params)
elif isinstance(params, Sequence):
if len(params) == len(dynamic_params):
for param, value in zip(dynamic_params, params):
Expand All @@ -140,8 +143,8 @@ def fill_params(self, params: Union[Tensor, Sequence, Mapping], local=False):
for module, value in zip(self.dynamic_modules.values(), params):
module.fill_params(value, local=True)
else:
raise AssertionError(
f"Input params length ({len(params)}) does not match dynamic params length ({len(dynamic_params)}) or number of dynamic modules ({len(self.dynamic_modules)}) for {self.name}"
raise FillDynamicParamsSequenceError(
self.name, params, dynamic_params, self.dynamic_modules
)
elif isinstance(params, Mapping):
for key in params:
Expand All @@ -150,19 +153,26 @@ def fill_params(self, params: Union[Tensor, Sequence, Mapping], local=False):
elif key in self.children and self[key].dynamic:
self[key]._value = params[key]
else:
raise ValueError(
f"Key {key} not found in dynamic modules or {self.name} children"
raise FillDynamicParamsMappingError(
self.name, self.children, self.dynamic_modules, missing_key=key
)
if not local:
for param in dynamic_params:
if param._value is None:
raise FillDynamicParamsMappingError(
self.name, self.children, self.dynamic_modules, missing_param=param
)
else:
raise ValueError(
f"Input params type {type(params)} not supported. Should be Tensor, Sequence or Mapping."
raise TypeError(
f"Input params type {type(params)} not supported. Should be Tensor, Sequence, or Mapping."
)

def clear_params(self):
"""Set all dynamic parameters to None and live parameters to LiveParam.
This is to be used on exiting an `ActiveContext` and so should not be
used by a user."""
assert self.active, "Module must be active to clear params"
if not self.active:
raise ActiveStateError("Module must be active to clear params")

for param in self.dynamic_params + self.pointer_params:
param._value = None
Expand Down Expand Up @@ -203,8 +213,8 @@ def to_valid(self, params: Union[Tensor, Sequence, Mapping], local=False):
for module, value in zip(self.dynamic_modules.values(), params):
valid_params.append(module.to_valid(value, local=True))
else:
raise AssertionError(
f"Input params length ({len(valid_params)}) does not match dynamic params length ({len(dynamic_params)}) or number of dynamic children ({len(self.children)})"
raise FillDynamicParamsSequenceError(
self.name, params, dynamic_params, self.dynamic_modules
)
elif isinstance(params, Mapping):
valid_params = {}
Expand All @@ -214,11 +224,11 @@ def to_valid(self, params: Union[Tensor, Sequence, Mapping], local=False):
elif key in self.children and self[key].dynamic:
valid_params[key] = self[key].to_valid(params[key])
else:
raise ValueError(
f"Key {key} not found in dynamic modules or {self.name} children"
raise FillDynamicParamsMappingError(
self.name, self.children, self.dynamic_modules, missing_key=key
)
else:
raise ValueError(
raise TypeError(
f"Input params type {type(params)} not supported. Should be Tensor, Sequence, or Mapping."
)
return valid_params
Expand Down Expand Up @@ -249,8 +259,8 @@ def from_valid(self, valid_params: Union[Tensor, Sequence, Mapping], local=False
for module, value in zip(self.dynamic_modules.values(), valid_params):
params.append(module.from_valid(value, local=True))
else:
raise AssertionError(
f"Input params length ({len(params)}) does not match dynamic params length ({len(dynamic_params)}) or number of dynamic children ({len(self.children)})"
raise FillDynamicParamsSequenceError(
self.name, valid_params, dynamic_params, self.dynamic_modules
)
elif isinstance(valid_params, Mapping):
params = {}
Expand All @@ -262,11 +272,11 @@ def from_valid(self, valid_params: Union[Tensor, Sequence, Mapping], local=False
elif key in self.children and self[key].dynamic:
params[key] = self[key].from_valid(valid_params[key])
else:
raise ValueError(
f"Key {key} not found in dynamic modules or {self.name} children"
raise FillDynamicParamsMappingError(
self.name, self.children, self.dynamic_modules, missing_key=key
)
else:
raise ValueError(
raise TypeError(
f"Input params type {type(valid_params)} not supported. Should be Tensor, Sequence or Mapping."
)
return params
Expand Down
Loading

0 comments on commit 136b6cb

Please sign in to comment.