diff --git a/docs/source/notebooks/BeginnersGuide.ipynb b/docs/source/notebooks/BeginnersGuide.ipynb index 58be889..23b6239 100644 --- a/docs/source/notebooks/BeginnersGuide.ipynb +++ b/docs/source/notebooks/BeginnersGuide.ipynb @@ -100,82 +100,82 @@ "\n", "\n", - "\n", + "\n", "\n", "%3\n", - "\n", - "\n", + "\n", + "\n", "\n", - "139989240272880\n", - "\n", - "Gaussian('my first module')\n", + "140179085468720\n", + "\n", + "Gaussian(my first module)\n", "\n", - "\n", + "\n", "\n", - "139989240271392\n", - "\n", - "Param('x0')\n", + "140179085468672\n", + "\n", + "x0\n", "\n", - "\n", + "\n", "\n", - "139989240272880->139989240271392\n", - "\n", - "\n", + "140179085468720->140179085468672\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139989240273024\n", - "\n", - "Param('q')\n", + "140179085468336\n", + "\n", + "q\n", "\n", - "\n", + "\n", "\n", - "139989240272880->139989240273024\n", - "\n", - "\n", + "140179085468720->140179085468336\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139989240271776\n", - "\n", - "Param('phi')\n", + "140179085469008\n", + "\n", + "phi\n", "\n", - "\n", + "\n", "\n", - "139989240272880->139989240271776\n", - "\n", - "\n", + "140179085468720->140179085469008\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139989240274560\n", - "\n", - "Param('sigma')\n", + "140179085470928\n", + "\n", + "sigma\n", "\n", - "\n", + "\n", "\n", - "139989240272880->139989240274560\n", - "\n", - "\n", + "140179085468720->140179085470928\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139993053830592\n", - "\n", - "Param('I0')\n", + "140179085468432\n", + "\n", + "I0\n", "\n", - "\n", + "\n", "\n", - "139989240272880->139993053830592\n", - "\n", - "\n", + "140179085468720->140179085468432\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -369,94 +369,94 @@ "\n", "\n", - "\n", + "\n", "\n", "%3\n", - "\n", - "\n", + "\n", + "\n", "\n", - "139993079960672\n", - "\n", - "Gaussian('my second module')\n", + "140182370881984\n", + "\n", + "Gaussian(my second module)\n", "\n", - "\n", + "\n", "\n", - "139992458928032\n", - "\n", - "Param('x0')\n", + "140182370882032\n", + "\n", + "x0\n", "\n", - "\n", + "\n", "\n", - "139993079960672->139992458928032\n", - "\n", - "\n", + "140182370881984->140182370882032\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458927216\n", - "\n", - "Param('q')\n", + "140182370884528\n", + "\n", + "q\n", "\n", - "\n", + "\n", "\n", - "139993079960672->139992458927216\n", - "\n", - "\n", + "140182370881984->140182370884528\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458926784\n", - "\n", - "Param('phi')\n", + "140182371922800\n", + "\n", + "phi\n", "\n", - "\n", + "\n", "\n", - "139993079960672->139992458926784\n", - "\n", - "\n", + "140182370881984->140182371922800\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458927792\n", - "\n", - "Param('sigma')\n", + "140182371924000\n", + "\n", + "sigma\n", "\n", - "\n", + "\n", "\n", - "139993079960672->139992458927792\n", - "\n", - "\n", + "140182370881984->140182371924000\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458925776\n", - "\n", - "Param('I0')\n", + "140182371923280\n", + "\n", + "I0\n", "\n", - "\n", + "\n", "\n", - "139993079960672->139992458925776\n", - "\n", - "\n", + "140182370881984->140182371923280\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139989240271392\n", - "\n", - "Param('x0')\n", + "140179085468672\n", + "\n", + "x0\n", "\n", - "\n", + "\n", "\n", - "139992458928032->139989240271392\n", - "\n", - "\n", + "140182370882032->140179085468672\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 9, @@ -535,172 +535,172 @@ "\n", "\n", - "\n", + "\n", "\n", "%3\n", - "\n", - "\n", + "\n", + "\n", "\n", - "139992458903504\n", - "\n", - "Combined('my combined module')\n", + "140179085468528\n", + "\n", + "Combined(my combined module)\n", "\n", - "\n", + "\n", "\n", - "139989240272880\n", - "\n", - "Gaussian('my first module')\n", + "140179085468720\n", + "\n", + "Gaussian(my first module)\n", "\n", - "\n", + "\n", "\n", - "139992458903504->139989240272880\n", - "\n", - "\n", + "140179085468528->140179085468720\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139993079960672\n", - "\n", - "Gaussian('my second module')\n", + "140182370881984\n", + "\n", + "Gaussian(my second module)\n", "\n", - "\n", + "\n", "\n", - "139992458903504->139993079960672\n", - "\n", - "\n", + "140179085468528->140182370881984\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139989240271392\n", - "\n", - "Param('x0')\n", + "140179085468672\n", + "\n", + "x0\n", "\n", - "\n", + "\n", "\n", - "139989240272880->139989240271392\n", - "\n", - "\n", + "140179085468720->140179085468672\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139989240273024\n", - "\n", - "Param('q')\n", + "140179085468336\n", + "\n", + "q\n", "\n", - "\n", + "\n", "\n", - "139989240272880->139989240273024\n", - "\n", - "\n", + "140179085468720->140179085468336\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139989240271776\n", - "\n", - "Param('phi')\n", + "140179085469008\n", + "\n", + "phi\n", "\n", - "\n", + "\n", "\n", - "139989240272880->139989240271776\n", - "\n", - "\n", + "140179085468720->140179085469008\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139989240274560\n", - "\n", - "Param('sigma')\n", + "140179085470928\n", + "\n", + "sigma\n", "\n", - "\n", + "\n", "\n", - "139989240272880->139989240274560\n", - "\n", - "\n", + "140179085468720->140179085470928\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139993053830592\n", - "\n", - "Param('I0')\n", + "140179085468432\n", + "\n", + "I0\n", "\n", - "\n", + "\n", "\n", - "139989240272880->139993053830592\n", - "\n", - "\n", + "140179085468720->140179085468432\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458928032\n", - "\n", - "Param('x0')\n", + "140182370882032\n", + "\n", + "x0\n", "\n", - "\n", + "\n", "\n", - "139993079960672->139992458928032\n", - "\n", - "\n", + "140182370881984->140182370882032\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458927216\n", - "\n", - "Param('q')\n", + "140182370884528\n", + "\n", + "q\n", "\n", - "\n", + "\n", "\n", - "139993079960672->139992458927216\n", - "\n", - "\n", + "140182370881984->140182370884528\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458926784\n", - "\n", - "Param('phi')\n", + "140182371922800\n", + "\n", + "phi\n", "\n", - "\n", + "\n", "\n", - "139993079960672->139992458926784\n", - "\n", - "\n", + "140182370881984->140182371922800\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458927792\n", - "\n", - "Param('sigma')\n", + "140182371924000\n", + "\n", + "sigma\n", "\n", - "\n", + "\n", "\n", - "139993079960672->139992458927792\n", - "\n", - "\n", + "140182370881984->140182371924000\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458925776\n", - "\n", - "Param('I0')\n", + "140182371923280\n", + "\n", + "I0\n", "\n", - "\n", + "\n", "\n", - "139993079960672->139992458925776\n", - "\n", - "\n", + "140182370881984->140182371923280\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458928032->139989240271392\n", - "\n", - "\n", + "140182370882032->140179085468672\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 11, @@ -761,184 +761,184 @@ "\n", "\n", - "\n", + "\n", "\n", "%3\n", - "\n", - "\n", + "\n", + "\n", "\n", - "139992458903504\n", - "\n", - "Combined('my combined module')\n", + "140179085468528\n", + "\n", + "Combined(my combined module)\n", "\n", - "\n", + "\n", "\n", - "139989240272880\n", - "\n", - "Gaussian('my first module')\n", + "140179085468720\n", + "\n", + "Gaussian(my first module)\n", "\n", - "\n", + "\n", "\n", - "139992458903504->139989240272880\n", - "\n", - "\n", + "140179085468528->140179085468720\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139993079960672\n", - "\n", - "Gaussian('my second module')\n", + "140182370881984\n", + "\n", + "Gaussian(my second module)\n", "\n", - "\n", + "\n", "\n", - "139992458903504->139993079960672\n", - "\n", - "\n", + "140179085468528->140182370881984\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139989240271392\n", - "\n", - "Param('x0')\n", + "140179085468672\n", + "\n", + "x0\n", "\n", - "\n", + "\n", "\n", - "139989240272880->139989240271392\n", - "\n", - "\n", + "140179085468720->140179085468672\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139989240273024\n", - "\n", - "Param('q')\n", + "140179085468336\n", + "\n", + "q\n", "\n", - "\n", + "\n", "\n", - "139989240272880->139989240273024\n", - "\n", - "\n", + "140179085468720->140179085468336\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139989240271776\n", - "\n", - "Param('phi')\n", + "140179085469008\n", + "\n", + "phi\n", "\n", - "\n", + "\n", "\n", - "139989240272880->139989240271776\n", - "\n", - "\n", + "140179085468720->140179085469008\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139989240274560\n", - "\n", - "Param('sigma')\n", + "140179085470928\n", + "\n", + "sigma\n", "\n", - "\n", + "\n", "\n", - "139989240272880->139989240274560\n", - "\n", - "\n", + "140179085468720->140179085470928\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139993053830592\n", - "\n", - "Param('I0')\n", + "140179085468432\n", + "\n", + "I0\n", "\n", - "\n", + "\n", "\n", - "139989240272880->139993053830592\n", - "\n", - "\n", + "140179085468720->140179085468432\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992457824096\n", - "\n", - "Param('time')\n", + "140182369524896\n", + "\n", + "time\n", "\n", - "\n", + "\n", "\n", - "139989240271392->139992457824096\n", - "\n", - "\n", + "140179085468672->140182369524896\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458928032\n", - "\n", - "Param('x0')\n", + "140182370882032\n", + "\n", + "x0\n", "\n", - "\n", + "\n", "\n", - "139993079960672->139992458928032\n", - "\n", - "\n", + "140182370881984->140182370882032\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458927216\n", - "\n", - "Param('q')\n", + "140182370884528\n", + "\n", + "q\n", "\n", - "\n", + "\n", "\n", - "139993079960672->139992458927216\n", - "\n", - "\n", + "140182370881984->140182370884528\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458926784\n", - "\n", - "Param('phi')\n", + "140182371922800\n", + "\n", + "phi\n", "\n", - "\n", + "\n", "\n", - "139993079960672->139992458926784\n", - "\n", - "\n", + "140182370881984->140182371922800\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458927792\n", - "\n", - "Param('sigma')\n", + "140182371924000\n", + "\n", + "sigma\n", "\n", - "\n", + "\n", "\n", - "139993079960672->139992458927792\n", - "\n", - "\n", + "140182370881984->140182371924000\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458925776\n", - "\n", - "Param('I0')\n", + "140182371923280\n", + "\n", + "I0\n", "\n", - "\n", + "\n", "\n", - "139993079960672->139992458925776\n", - "\n", - "\n", + "140182370881984->140182371923280\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "139992458928032->139992457824096\n", - "\n", - "\n", + "140182370882032->140182369524896\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 13, @@ -1156,42 +1156,42 @@ "\n", "\n", "
\n", - " \n", + " \n", "
\n", - " \n", + " oninput=\"animf0ea50e791504cda99b2a274f3032690.set_frame(parseInt(this.value));\">\n", "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", "
\n", - "
\n", - " \n", - " \n", - " Once\n", + " \n", - " \n", - " Loop\n", + " \n", - " \n", + " \n", "
\n", "
\n", "
\n", @@ -1201,9 +1201,9 @@ " /* Instantiate the Animation class. */\n", " /* The IDs given should match those used in the template above. */\n", " (function() {\n", - " var img_id = \"_anim_img68590b60fa8144c6992d1c9805748126\";\n", - " var slider_id = \"_anim_slider68590b60fa8144c6992d1c9805748126\";\n", - " var loop_select_id = \"_anim_loop_select68590b60fa8144c6992d1c9805748126\";\n", + " var img_id = \"_anim_imgf0ea50e791504cda99b2a274f3032690\";\n", + " var slider_id = \"_anim_sliderf0ea50e791504cda99b2a274f3032690\";\n", + " var loop_select_id = \"_anim_loop_selectf0ea50e791504cda99b2a274f3032690\";\n", " var frames = new Array(64);\n", " \n", " frames[0] = \"\\\n", @@ -27072,7 +27072,7 @@ " /* set a timeout to make sure all the above elements are created before\n", " the object is initialized. */\n", " setTimeout(function() {\n", - " anim68590b60fa8144c6992d1c9805748126 = new Animation(frames, img_id, slider_id, 59.0,\n", + " animf0ea50e791504cda99b2a274f3032690 = new Animation(frames, img_id, slider_id, 59.0,\n", " loop_select_id);\n", " }, 0);\n", " })()\n", @@ -27124,8 +27124,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "for-loop time taken: 0.3419454097747803\n", - "vmap time taken: 0.0344240665435791\n" + "for-loop time taken: 0.23389983177185059\n", + "vmap time taken: 0.06905579566955566\n" ] } ], @@ -27192,7 +27192,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "batched time taken: 0.03307318687438965\n" + "batched time taken: 0.033159494400024414\n" ] } ], diff --git a/src/caskade/__init__.py b/src/caskade/__init__.py index d43925b..95e00b5 100644 --- a/src/caskade/__init__.py +++ b/src/caskade/__init__.py @@ -6,9 +6,42 @@ from .module import Module from .param import Param from .tests import test +from .errors import ( + CaskadeException, + GraphError, + NodeConfigurationError, + ParamConfigurationError, + ParamTypeError, + ActiveStateError, + FillDynamicParamsError, + FillDynamicParamsTensorError, + FillDynamicParamsSequenceError, + FillDynamicParamsMappingError, +) +from .warnings import CaskadeWarning, InvalidValueWarning __version__ = VERSION __author__ = "Connor Stone and Alexandre Adam" -__all__ = ("Node", "Module", "Param", "ActiveContext", "ValidContext", "forward", "test") +__all__ = ( + "Node", + "Module", + "Param", + "ActiveContext", + "ValidContext", + "forward", + "test", + "CaskadeException", + "GraphError", + "NodeConfigurationError", + "ParamConfigurationError", + "ParamTypeError", + "ActiveStateError", + "FillDynamicParamsError", + "FillDynamicParamsTensorError", + "FillDynamicParamsSequenceError", + "FillDynamicParamsMappingError", + "CaskadeWarning", + "InvalidValueWarning", +) diff --git a/src/caskade/base.py b/src/caskade/base.py index b180dfe..c844353 100644 --- a/src/caskade/base.py +++ b/src/caskade/base.py @@ -1,5 +1,7 @@ from typing import Optional, Union +from .errors import GraphError, NodeConfigurationError + class Node(object): """ @@ -30,8 +32,10 @@ class Node(object): def __init__(self, name: Optional[str] = None): if name is None: name = self.__class__.__name__ - assert isinstance(name, str), f"{self.__class__.__name__} name must be a string" - assert "|" not in name, f"{self.__class__.__name__} cannot contain '|'" + if not isinstance(name, str): + raise NodeConfigurationError(f"{self.__class__.__name__} name must be a string") + if "|" in name: + raise NodeConfigurationError(f"{self.__class__.__name__} cannot contain '|'") self._name = name self._children = {} self._parents = set() @@ -80,12 +84,12 @@ def link(self, key: Union[str, "Node"], child: Optional["Node"] = None): key = child.name # Avoid double linking to the same object if key in self.children: - raise ValueError(f"Child key {key} already linked to parent {self.name}") + raise GraphError(f"Child key {key} already linked to parent {self.name}") if child in self.children.values(): - raise ValueError(f"Child {child.name} already linked to parent {self.name}") + raise GraphError(f"Child {child.name} already linked to parent {self.name}") # avoid cycles if self in child.topological_ordering(): - raise ValueError( + raise GraphError( f"Linking {child.name} to {self.name} would create a cycle in the graph" ) @@ -175,7 +179,7 @@ def add_node(node, dot): if node in components: return dot.attr("node", **node.graphviz_types[node._type]) - dot.node(str(id(node)), f"{node.__class__.__name__}('{node.name}')") + dot.node(str(id(node)), repr(node)) components.add(node) for child in node.children.values(): diff --git a/src/caskade/decorators.py b/src/caskade/decorators.py index 9c483de..17428b2 100644 --- a/src/caskade/decorators.py +++ b/src/caskade/decorators.py @@ -45,7 +45,7 @@ def wrapped(self, *args, **kwargs): args = args[:-1] else: raise ValueError( - f"Params must be provided for dynamic modules. Expected {len(self.dynamic_params)} params." + f"Params must be provided for a top level @forward method. Either by keyword 'method(params=params)' or as the last positional argument 'method(a, b, c, params)'" ) with ActiveContext(self): diff --git a/src/caskade/errors.py b/src/caskade/errors.py new file mode 100644 index 0000000..7da52fe --- /dev/null +++ b/src/caskade/errors.py @@ -0,0 +1,93 @@ +from math import prod +from textwrap import dedent + + +class CaskadeException(Exception): + """Base class for all exceptions in Caskade.""" + + +class GraphError(CaskadeException): + """Class for graph exceptions in Caskade.""" + + +class NodeConfigurationError(CaskadeException): + """Class for node configuration exceptions in Caskade.""" + + +class ParamConfigurationError(NodeConfigurationError): + """Class for parameter configuration exceptions in Caskade.""" + + +class ParamTypeError(CaskadeException): + """Class for exceptions related to the type of a parameter in Caskade.""" + + +class ActiveStateError(CaskadeException): + """Class for exceptions related to the active state of a node in Caskade.""" + + +class FillDynamicParamsError(CaskadeException): + """Class for exceptions related to filling dynamic parameters in Caskade.""" + + +class FillDynamicParamsTensorError(FillDynamicParamsError): + def __init__(self, name, input_params, dynamic_params): + fullnumel = sum(max(1, prod(p.shape)) for p in dynamic_params) + message = dedent( + f""" + For flattened Tensor input, the (last) dim of the Tensor should + equal the sum of all flattened dynamic params ({fullnumel}). + Input params shape {input_params.shape} does not match dynamic + params shape of: {name}. + + Registered dynamic params (name: shape): + {', '.join(f"{repr(p)}: {str(p.shape)}" for p in dynamic_params)}""" + ) + super().__init__(message) + + +class FillDynamicParamsSequenceError(FillDynamicParamsError): + def __init__(self, name, input_params, dynamic_params, dynamic_modules): + message = dedent( + f""" + Input params length ({len(input_params)}) does not match dynamic + params length ({len(dynamic_params)}) or number of dynamic + modules ({len(dynamic_modules)}) of: {name}. + + Registered dynamic modules: + {', '.join(repr(m) for m in dynamic_modules)} + + Registered dynamic params: + {', '.join(repr(p) for p in dynamic_params)}""" + ) + super().__init__(message) + + +class FillDynamicParamsMappingError(FillDynamicParamsError): + def __init__(self, name, children, dynamic_modules, missing_key=None, missing_param=None): + if missing_key is not None: + message = dedent( + f""" + Input params key "{missing_key}" not found in dynamic modules or children of: {name}. + + Registered dynamic modules: + {', '.join(repr(m) for m in dynamic_modules)} + + Registered dynamic children: + {', '.join(repr(c) for c in children.values() if c.dynamic)}""" + ) + else: + message = dedent( + f""" + Dynamic param "{missing_param.name}" not filled with given input params dict passed to {name}. + + Dynamic param parent(s): + {', '.join(repr(p) for p in missing_param.parents)} + + Registered dynamic modules: + {', '.join(repr(m) for m in dynamic_modules)} + + Registered dynamic children: + {', '.join(repr(c) for c in children.values() if c.dynamic)}""" + ) + super().__init__(message) diff --git a/src/caskade/module.py b/src/caskade/module.py index 7cee98f..a60c4a0 100644 --- a/src/caskade/module.py +++ b/src/caskade/module.py @@ -6,6 +6,13 @@ from .base import Node from .param import Param +from .errors import ( + ActiveStateError, + ParamConfigurationError, + FillDynamicParamsTensorError, + FillDynamicParamsSequenceError, + FillDynamicParamsMappingError, +) class Module(Node): @@ -100,7 +107,8 @@ def fill_params(self, params: Union[Tensor, Sequence, Mapping], local=False): the dictionary, but you will get an error eventually if a value is missing. """ - assert self.active, "Module must be active to fill params" + if not self.active: + raise ActiveStateError("Module must be active to fill params") if self.valid_context and not local: params = self.from_valid(params) @@ -114,7 +122,7 @@ def fill_params(self, params: Union[Tensor, Sequence, Mapping], local=False): pos = 0 for param in dynamic_params: if not isinstance(param.shape, tuple): - raise ValueError( + raise ParamConfigurationError( f"Param {param.name} has no shape. dynamic parameters must have a shape to use Tensor input." ) # Handle scalar parameters @@ -122,16 +130,11 @@ def fill_params(self, params: Union[Tensor, Sequence, Mapping], local=False): try: param._value = params[..., pos : pos + size].view(B + param.shape) except (RuntimeError, IndexError): - fullnumel = sum(max(1, prod(p.shape)) for p in dynamic_params) - raise AssertionError( - f"Input params shape {params.shape} does not match dynamic params shape of {self.name}. Make sure the last dimension has size equal to the sum of all dynamic params sizes ({fullnumel})." - ) + raise FillDynamicParamsTensorError(self.name, params, dynamic_params) + pos += size if pos != params.shape[-1]: - fullnumel = sum(max(1, prod(p.shape)) for p in dynamic_params) - raise AssertionError( - f"Input params length {params.shape} does not match dynamic params length ({fullnumel}) of {self.name}. Not all dynamic params were filled." - ) + raise FillDynamicParamsTensorError(self.name, params, dynamic_params) elif isinstance(params, Sequence): if len(params) == len(dynamic_params): for param, value in zip(dynamic_params, params): @@ -140,8 +143,8 @@ def fill_params(self, params: Union[Tensor, Sequence, Mapping], local=False): for module, value in zip(self.dynamic_modules.values(), params): module.fill_params(value, local=True) else: - raise AssertionError( - f"Input params length ({len(params)}) does not match dynamic params length ({len(dynamic_params)}) or number of dynamic modules ({len(self.dynamic_modules)}) for {self.name}" + raise FillDynamicParamsSequenceError( + self.name, params, dynamic_params, self.dynamic_modules ) elif isinstance(params, Mapping): for key in params: @@ -150,19 +153,26 @@ def fill_params(self, params: Union[Tensor, Sequence, Mapping], local=False): elif key in self.children and self[key].dynamic: self[key]._value = params[key] else: - raise ValueError( - f"Key {key} not found in dynamic modules or {self.name} children" + raise FillDynamicParamsMappingError( + self.name, self.children, self.dynamic_modules, missing_key=key ) + if not local: + for param in dynamic_params: + if param._value is None: + raise FillDynamicParamsMappingError( + self.name, self.children, self.dynamic_modules, missing_param=param + ) else: - raise ValueError( - f"Input params type {type(params)} not supported. Should be Tensor, Sequence or Mapping." + raise TypeError( + f"Input params type {type(params)} not supported. Should be Tensor, Sequence, or Mapping." ) def clear_params(self): """Set all dynamic parameters to None and live parameters to LiveParam. This is to be used on exiting an `ActiveContext` and so should not be used by a user.""" - assert self.active, "Module must be active to clear params" + if not self.active: + raise ActiveStateError("Module must be active to clear params") for param in self.dynamic_params + self.pointer_params: param._value = None @@ -203,8 +213,8 @@ def to_valid(self, params: Union[Tensor, Sequence, Mapping], local=False): for module, value in zip(self.dynamic_modules.values(), params): valid_params.append(module.to_valid(value, local=True)) else: - raise AssertionError( - f"Input params length ({len(valid_params)}) does not match dynamic params length ({len(dynamic_params)}) or number of dynamic children ({len(self.children)})" + raise FillDynamicParamsSequenceError( + self.name, params, dynamic_params, self.dynamic_modules ) elif isinstance(params, Mapping): valid_params = {} @@ -214,11 +224,11 @@ def to_valid(self, params: Union[Tensor, Sequence, Mapping], local=False): elif key in self.children and self[key].dynamic: valid_params[key] = self[key].to_valid(params[key]) else: - raise ValueError( - f"Key {key} not found in dynamic modules or {self.name} children" + raise FillDynamicParamsMappingError( + self.name, self.children, self.dynamic_modules, missing_key=key ) else: - raise ValueError( + raise TypeError( f"Input params type {type(params)} not supported. Should be Tensor, Sequence, or Mapping." ) return valid_params @@ -249,8 +259,8 @@ def from_valid(self, valid_params: Union[Tensor, Sequence, Mapping], local=False for module, value in zip(self.dynamic_modules.values(), valid_params): params.append(module.from_valid(value, local=True)) else: - raise AssertionError( - f"Input params length ({len(params)}) does not match dynamic params length ({len(dynamic_params)}) or number of dynamic children ({len(self.children)})" + raise FillDynamicParamsSequenceError( + self.name, valid_params, dynamic_params, self.dynamic_modules ) elif isinstance(valid_params, Mapping): params = {} @@ -262,11 +272,11 @@ def from_valid(self, valid_params: Union[Tensor, Sequence, Mapping], local=False elif key in self.children and self[key].dynamic: params[key] = self[key].from_valid(valid_params[key]) else: - raise ValueError( - f"Key {key} not found in dynamic modules or {self.name} children" + raise FillDynamicParamsMappingError( + self.name, self.children, self.dynamic_modules, missing_key=key ) else: - raise ValueError( + raise TypeError( f"Input params type {type(valid_params)} not supported. Should be Tensor, Sequence or Mapping." ) return params diff --git a/src/caskade/param.py b/src/caskade/param.py index 535a1c8..5c4c05c 100644 --- a/src/caskade/param.py +++ b/src/caskade/param.py @@ -1,10 +1,13 @@ from typing import Optional, Union, Callable +from warnings import warn import torch from torch import Tensor from torch import pi from .base import Node +from .errors import ParamConfigurationError, ParamTypeError, ActiveStateError +from .warnings import InvalidValueWarning class Param(Node): @@ -65,15 +68,16 @@ def __init__( super().__init__(name=name) if value is None: if shape is None: - raise ValueError("Either value or shape must be provided") + raise ParamConfigurationError("Either value or shape must be provided") if not isinstance(shape, tuple): - raise ValueError("Shape must be a tuple") + raise ParamConfigurationError("Shape must be a tuple") self.shape = shape elif not isinstance(value, (Param, Callable)): value = torch.as_tensor(value) - assert ( - shape == () or shape is None or shape == value.shape - ), f"Shape {shape} does not match value shape {value.shape}" + if not (shape == () or shape is None or shape == value.shape): + raise ParamConfigurationError( + f"Shape {shape} does not match value shape {value.shape}" + ) self.value = value self.cyclic = cyclic self.valid = valid @@ -98,7 +102,7 @@ def shape(self) -> tuple: @shape.setter def shape(self, shape): if self.pointer: - raise RuntimeError("Cannot set shape of parameter with type 'pointer'") + raise ParamTypeError("Cannot set shape of parameter with type 'pointer'") self._shape = shape @property @@ -111,9 +115,7 @@ def value(self) -> Union[Tensor, None]: def value(self, value): # While active no value can be set if self.active: - raise RuntimeError( - f"Cannot set value of parameter {self.name}|{self._type} while active" - ) + raise ActiveStateError(f"Cannot set value of parameter {self.name} while active") # unlink if pointer to avoid floating references if self.pointer: @@ -140,6 +142,10 @@ def value(self, value): value = torch.as_tensor(value) self.shape = value.shape self._value = value + try: + self.valid = self._valid # re-check valid range + except AttributeError: + pass self.update_graph() @@ -185,22 +191,32 @@ def valid(self, valid: tuple[Union[Tensor, float, int, None]]): if valid is None: valid = (None, None) - assert isinstance(valid, tuple) and len(valid) == 2, "Valid must be a tuple of length 2" + if not isinstance(valid, tuple): + raise ParamConfigurationError("Valid must be a tuple") + if len(valid) != 2: + raise ParamConfigurationError("Valid must be a tuple of length 2") if valid == (None, None): - assert not self.cyclic, "Cannot set valid to None for cyclic parameter" + if self.cyclic: + raise ParamConfigurationError("Cannot set valid to None for cyclic parameter") self.to_valid = self._to_valid_base self.from_valid = self._from_valid_base elif valid[0] is None: - assert not self.cyclic, "Cannot set left valid to None for cyclic parameter" + if self.cyclic: + raise ParamConfigurationError("Cannot set left valid to None for cyclic parameter") self.to_valid = self._to_valid_rightvalid self.from_valid = self._from_valid_rightvalid valid = (None, torch.as_tensor(valid[1])) + if self.static and torch.any(self.value > valid[1]): + warn(InvalidValueWarning(self.name, self.value, valid)) elif valid[1] is None: - assert not self.cyclic, "Cannot set right valid to None for cyclic parameter" + if self.cyclic: + raise ParamConfigurationError("Cannot set right valid to None for cyclic parameter") self.to_valid = self._to_valid_leftvalid self.from_valid = self._from_valid_leftvalid valid = (torch.as_tensor(valid[0]), None) + if self.static and torch.any(self.value < valid[0]): + warn(InvalidValueWarning(self.name, self.value, valid)) else: if self.cyclic: self.to_valid = self._to_valid_cyclic @@ -209,12 +225,20 @@ def valid(self, valid: tuple[Union[Tensor, float, int, None]]): self.to_valid = self._to_valid_fullvalid self.from_valid = self._from_valid_fullvalid valid = (torch.as_tensor(valid[0]), torch.as_tensor(valid[1])) + if torch.any(valid[0] >= valid[1]): + raise ParamConfigurationError("Valid range (valid[1] - valid[0]) must be positive") + if ( + self.static + and not self.cyclic + and (torch.any(self.value < valid[0]) or torch.any(self.value > valid[1])) + ): + warn(InvalidValueWarning(self.name, self.value, valid)) self._valid = valid def _to_valid_base(self, value): if self.pointer: - raise ValueError("Cannot apply valid transformation to pointer parameter") + raise ParamTypeError("Cannot apply valid transformation to pointer parameter") return value def _to_valid_fullvalid(self, value): @@ -235,7 +259,7 @@ def _to_valid_rightvalid(self, value): def _from_valid_base(self, value): if self.pointer: - raise ValueError("Cannot apply valid transformation to pointer parameter") + raise ParamTypeError("Cannot apply valid transformation to pointer parameter") return value def _from_valid_fullvalid(self, value): @@ -257,3 +281,6 @@ def _from_valid_rightvalid(self, value): value = self._from_valid_base(value) value = (value + self.valid[1] - ((value - self.valid[1]) ** 2 + 4).sqrt()) / 2 return value + + def __repr__(self): + return self.name diff --git a/src/caskade/warnings.py b/src/caskade/warnings.py new file mode 100644 index 0000000..ac5c763 --- /dev/null +++ b/src/caskade/warnings.py @@ -0,0 +1,17 @@ +from textwrap import dedent + + +class CaskadeWarning(Warning): + """Base warning for Caskade.""" + + +class InvalidValueWarning(CaskadeWarning): + """Warning for values which fall outside the valid range.""" + + def __init__(self, name, value, valid): + message = dedent( + f"""\ + Value {value.detach().cpu().tolist()} for parameter "{name}" is outside the valid range ({valid[0].detach().cpu().tolist() if valid[0] is not None else "-inf"}, {valid[1].detach().cpu().tolist() if valid[1] is not None else "inf"}). + Likely to cause errors or unexpected behavior!""" + ) + super().__init__(message) diff --git a/tests/test_base.py b/tests/test_base.py index dcf1d29..36b27cd 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,4 +1,4 @@ -from caskade import Node, test +from caskade import Node, test, GraphError, NodeConfigurationError import pytest @@ -14,6 +14,12 @@ def test_creation(): with pytest.raises(AttributeError): node.name = "newname" + with pytest.raises(NodeConfigurationError): + node2 = Node(1) + + with pytest.raises(NodeConfigurationError): + node2 = Node("test|test") + def test_link(): node1 = Node("node1") @@ -21,17 +27,17 @@ def test_link(): node1.link("subnode", node2) # Already linked - with pytest.raises(ValueError): + with pytest.raises(GraphError): node1.link("subnode", node2) # Double link - with pytest.raises(ValueError): + with pytest.raises(GraphError): node1.link("subnode2", node2) # Make a cycle node3 = Node("node3") node2.link("subnode", node3) - with pytest.raises(ValueError): + with pytest.raises(GraphError): node3.link("subnode", node1) assert "subnode" in node1._children diff --git a/tests/test_forward.py b/tests/test_forward.py index 6997730..28ebfc4 100644 --- a/tests/test_forward.py +++ b/tests/test_forward.py @@ -1,6 +1,15 @@ import torch -from caskade import Module, Param, forward, ValidContext +from caskade import ( + Module, + Param, + forward, + ValidContext, + FillDynamicParamsSequenceError, + FillDynamicParamsMappingError, + FillDynamicParamsTensorError, + ParamConfigurationError, +) import pytest @@ -56,19 +65,19 @@ def __call__(self, d=None, e=None, live_c=None): assert valid_result.shape == (2, 2) assert torch.all(valid_result == result).item() # Wrong number of params, too few - with pytest.raises(AssertionError): + with pytest.raises(FillDynamicParamsSequenceError): result = main1.testfun(1.0, params=params[:3]) - with pytest.raises(AssertionError): + with pytest.raises(FillDynamicParamsSequenceError): main1.to_valid(params[:3]) - with pytest.raises(AssertionError): + with pytest.raises(FillDynamicParamsSequenceError): main1.from_valid(params[:3]) # Wrong number of params, too many badparams = params + params + params - with pytest.raises(AssertionError): + with pytest.raises(FillDynamicParamsSequenceError): result = main1.testfun(1.0, params=badparams) - with pytest.raises(AssertionError): + with pytest.raises(FillDynamicParamsSequenceError): main1.to_valid(badparams) - with pytest.raises(AssertionError): + with pytest.raises(FillDynamicParamsSequenceError): main1.from_valid(badparams) # List by children @@ -98,10 +107,10 @@ def __call__(self, d=None, e=None, live_c=None): assert valid_result.shape == (2, 2) assert torch.all(valid_result == result).item() # Wrong number of params, too few - with pytest.raises(AssertionError): + with pytest.raises(FillDynamicParamsTensorError): result = main1.testfun(1.0, params[:-3]) # Wrong number of params, too many - with pytest.raises(AssertionError): + with pytest.raises(FillDynamicParamsTensorError): result = main1.testfun(1.0, torch.cat((params, params))) # Batched tensor as params @@ -133,11 +142,11 @@ def __call__(self, d=None, e=None, live_c=None): assert torch.all(valid_result == result).item() # Wrong name for params params = {"q": torch.ones((2, 2)), "TestSubSim": torch.tensor((3.0, 4.0, 1.0))} - with pytest.raises(ValueError): + with pytest.raises(FillDynamicParamsMappingError): result = main1.testfun(1.0, params=params) - with pytest.raises(ValueError): + with pytest.raises(FillDynamicParamsMappingError): main1.to_valid(params) - with pytest.raises(ValueError): + with pytest.raises(FillDynamicParamsMappingError): main1.from_valid(params) # Dict as params, sub element is list @@ -169,6 +178,13 @@ def __call__(self, d=None, e=None, live_c=None): valid_result = main1.testfun(1.0, params=main1.to_valid(params)) assert valid_result.shape == (2, 2) assert torch.all(valid_result == result).item() + # Missing param + params = { + "b": torch.ones((2, 2)), + "TestSubSim": {"d": torch.tensor(3.0), "e": torch.tensor(4.0)}, # , "f": torch.tensor(1.0) + } + with pytest.raises(FillDynamicParamsMappingError): + result = main1.testfun(1.0, params=params) # All params static main1.b = torch.ones((2, 2)) @@ -181,23 +197,15 @@ def __call__(self, d=None, e=None, live_c=None): # dynamic with no shape main1.b = None main1.b.shape = None - with pytest.raises(ValueError): + with pytest.raises(ParamConfigurationError): main1.testfun(1.0, params=torch.ones(4)) result = main1.testfun(1.0, params=[torch.ones((2, 2))]) assert result.shape == (2, 2) - # wrong number of params - with pytest.raises(AssertionError): - main1.testfun(1.0, params=[torch.ones((2, 2)), torch.tensor(3.0)]) - # wrong parameter type - with pytest.raises(ValueError): + with pytest.raises(TypeError): main1.testfun(1.0, params=None) - with pytest.raises(ValueError): + with pytest.raises(TypeError): main1.to_valid(None) - with pytest.raises(ValueError): + with pytest.raises(TypeError): main1.from_valid(None) - - # param key doesn't exist - with pytest.raises(ValueError): - main1.testfun(1.0, params={"q": torch.ones((2, 2))}) diff --git a/tests/test_module.py b/tests/test_module.py index fee6493..38bbeb2 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -1,4 +1,6 @@ -from caskade import Module, Param +from caskade import Module, Param, ActiveStateError + +import pytest def test_module_creation(): @@ -13,6 +15,20 @@ def test_module_creation(): assert m1.dynamic_params == (p1,) assert m2.dynamic_params == (p1,) + m3 = Module("test1") + assert m3.name == "test1_0" + + +def test_module_methods(): + + m1 = Module("test1") + + with pytest.raises(ActiveStateError): + m1.fill_params([1.0, 2.0, 3.0]) + + with pytest.raises(ActiveStateError): + m1.clear_params() + def test_module_del(): m1 = Module("deltest1") diff --git a/tests/test_param.py b/tests/test_param.py index cfed0ae..dd48bca 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -1,7 +1,13 @@ import pytest import torch -from caskade import Param +from caskade import ( + Param, + ActiveStateError, + ParamConfigurationError, + ParamTypeError, + InvalidValueWarning, +) def test_param_creation(): @@ -19,37 +25,44 @@ def test_param_creation(): p3 = Param("test", torch.ones((1, 2, 3))) # Cant update value when active - with pytest.raises(RuntimeError): + with pytest.raises(ActiveStateError): p3.active = True p3.value = 1.0 # Missmatch value and shape - with pytest.raises(AssertionError): + with pytest.raises(ParamConfigurationError): p4 = Param("test", 1.0, shape=(1, 2, 3)) # Cant set shape of pointer or function p5 = Param("test", p3) - with pytest.raises(RuntimeError): + with pytest.raises(ParamTypeError): p5.shape = (1, 2, 3) - with pytest.raises(ValueError): + with pytest.raises(ParamTypeError): p5.to_valid(1.0) - with pytest.raises(ValueError): + with pytest.raises(ParamTypeError): p5.from_valid(1.0) # Function parameter p6 = Param("test", lambda p: p["other"].value * 2) p6.link("other", p2) - with pytest.raises(RuntimeError): + with pytest.raises(ParamTypeError): p6.shape = (1, 2, 3) # Missing value and shape - with pytest.raises(ValueError): + with pytest.raises(ParamConfigurationError): p7 = Param("test", None, None) # Shape is not a tuple - with pytest.raises(ValueError): + with pytest.raises(ParamConfigurationError): p8 = Param("test", None, 7) + # Metadata + p9 = Param("test", 1.0, units="none", cyclic=True, valid=(0, 1)) + assert p9.units == "none" + assert p9.cyclic + assert p9.valid[0].item() == 0 + assert p9.valid[1].item() == 1 + def test_param_to(): p = Param("test", 1.0, valid=(0, 2)) @@ -127,3 +140,25 @@ def test_valid(): assert torch.all( p.from_valid(torch.linspace(-1e4, 1e4, 101)) <= 1 ), "from_valid should map to valid range" + + p.value = 0.5 + + with pytest.raises(ParamConfigurationError): + p.valid = None + with pytest.raises(ParamConfigurationError): + p.valid = (1, None) + with pytest.raises(ParamConfigurationError): + p.valid = (None, 1) + p.cyclic = False + with pytest.raises(ParamConfigurationError): + p.valid = (1, 0) + with pytest.raises(ParamConfigurationError): + p.valid = (0, 1, 2) + with pytest.raises(ParamConfigurationError): + p.valid = [0, 1] + with pytest.warns(InvalidValueWarning): + p.value = -1 + with pytest.warns(InvalidValueWarning): + p.valid = (0, None) + with pytest.warns(InvalidValueWarning): + p.valid = (None, -2)