diff --git a/pyro/contrib/autoname/autoname.py b/pyro/contrib/autoname/autoname.py index 3112b35dd7..4771ddd51b 100644 --- a/pyro/contrib/autoname/autoname.py +++ b/pyro/contrib/autoname/autoname.py @@ -153,9 +153,9 @@ def _pyro_genname(msg): msg["stop"] = True -_handler_name, _handler = _make_handler(AutonameMessenger) -_handler.__module__ = __name__ -locals()[_handler_name] = _handler +@_make_handler(AutonameMessenger, __name__) +def autoname(fn=None, name=None): + ... @singledispatch diff --git a/pyro/contrib/funsor/handlers/__init__.py b/pyro/contrib/funsor/handlers/__init__.py index 80d01740f9..992f027715 100644 --- a/pyro/contrib/funsor/handlers/__init__.py +++ b/pyro/contrib/funsor/handlers/__init__.py @@ -21,17 +21,46 @@ from .replay_messenger import ReplayMessenger from .trace_messenger import TraceMessenger -_msngrs = [ - EnumMessenger, - MarkovMessenger, - NamedMessenger, - PlateMessenger, - ReplayMessenger, - TraceMessenger, - VectorizedMarkovMessenger, -] - -for _msngr_cls in _msngrs: - _handler_name, _handler = _make_handler(_msngr_cls) - _handler.__module__ = __name__ - locals()[_handler_name] = _handler + +@_make_handler(EnumMessenger, __name__) +def enum(fn=None, first_available_dim=None): + ... + + +@_make_handler(MarkovMessenger, __name__) +def markov(fn=None, history=1, keep=False): + ... + + +@_make_handler(NamedMessenger, __name__) +def named(fn=None, first_available_dim=None): + ... + + +@_make_handler(PlateMessenger, __name__) +def plate( + fn=None, + name=None, + size=None, + subsample_size=None, + subsample=None, + dim=None, + use_cuda=None, + device=None, +): + ... + + +@_make_handler(ReplayMessenger, __name__) +def replay(fn=None, trace=None, params=None): + ... + + +@_make_handler(TraceMessenger, __name__) +def trace(fn=None, graph_type=None, param_only=None, pack_online=True): + ... + + +@_make_handler(VectorizedMarkovMessenger, __name__) +def vectorized_markov(fn=None, name=None, size=None, dim=None, history=1): + ... diff --git a/pyro/poutine/handlers.py b/pyro/poutine/handlers.py index 699d5c76c3..af4f36ccb2 100644 --- a/pyro/poutine/handlers.py +++ b/pyro/poutine/handlers.py @@ -51,7 +51,6 @@ import collections import functools -import re from pyro.poutine import util @@ -79,68 +78,130 @@ # Begin primitive operations ############################################ -_msngrs = [ - BlockMessenger, - BroadcastMessenger, - CollapseMessenger, - ConditionMessenger, - DoMessenger, - EnumMessenger, - EscapeMessenger, - InferConfigMessenger, - LiftMessenger, - MaskMessenger, - ReparamMessenger, - ReplayMessenger, - ScaleMessenger, - SeedMessenger, - TraceMessenger, - UnconditionMessenger, - SubstituteMessenger, -] - - -def _make_handler(msngr_cls): - _re1 = re.compile("(.)([A-Z][a-z]+)") - _re2 = re.compile("([a-z0-9])([A-Z])") - - def handler(fn=None, *args, **kwargs): - if fn is not None and not ( - callable(fn) or isinstance(fn, collections.abc.Iterable) - ): - raise ValueError( - "{} is not callable, did you mean to pass it as a keyword arg?".format( - fn + +def _make_handler(msngr_cls, module=None): + def handler_decorator(func): + def handler(fn=None, *args, **kwargs): + if fn is not None and not ( + callable(fn) or isinstance(fn, collections.abc.Iterable) + ): + raise ValueError( + f"{fn} is not callable, did you mean to pass it as a keyword arg?" ) + msngr = msngr_cls(*args, **kwargs) + return ( + functools.update_wrapper(msngr(fn), fn, updated=()) + if fn is not None + else msngr ) - msngr = msngr_cls(*args, **kwargs) - return ( - functools.update_wrapper(msngr(fn), fn, updated=()) - if fn is not None - else msngr - ) - # handler names from messenger names: strip Messenger suffix, convert CamelCase to snake_case - handler_name = _re2.sub( - r"\1_\2", _re1.sub(r"\1_\2", msngr_cls.__name__.split("Messenger")[0]) - ).lower() - handler.__doc__ = ( - """Convenient wrapper of :class:`~pyro.poutine.{}.{}` \n\n""".format( - handler_name + "_messenger", msngr_cls.__name__ + handler.__doc__ = ( + """Convenient wrapper of :class:`~pyro.poutine.{}.{}` \n\n""".format( + func.__name__ + "_messenger", msngr_cls.__name__ + ) + + (msngr_cls.__doc__ if msngr_cls.__doc__ else "") ) - + (msngr_cls.__doc__ if msngr_cls.__doc__ else "") - ) - handler.__name__ = handler_name - return handler_name, handler + handler.__name__ = func.__name__ + if module is not None: + handler.__module__ = module + return handler + + return handler_decorator + + +@_make_handler(BlockMessenger) +def block( + fn=None, + hide_fn=None, + expose_fn=None, + hide_all=True, + expose_all=False, + hide=None, + expose=None, + hide_types=None, + expose_types=None, +): + ... + + +@_make_handler(BroadcastMessenger) +def broadcast(fn=None): + ... + + +@_make_handler(CollapseMessenger) +def collapse(fn=None, *args, **kwargs): + ... + + +@_make_handler(ConditionMessenger) +def condition(fn, data): + ... + + +@_make_handler(DoMessenger) +def do(fn, data): + ... + + +@_make_handler(EnumMessenger) +def enum(fn=None, first_available_dim=None): + ... + + +@_make_handler(EscapeMessenger) +def escape(fn, escape_fn): + ... + + +@_make_handler(InferConfigMessenger) +def infer_config(fn, config_fn): + ... + + +@_make_handler(LiftMessenger) +def lift(fn, prior): + ... + + +@_make_handler(MaskMessenger) +def mask(fn, mask): + ... + + +@_make_handler(ReparamMessenger) +def reparam(fn, config): + ... + + +@_make_handler(ReplayMessenger) +def replay(fn=None, trace=None, params=None): + ... + + +@_make_handler(ScaleMessenger) +def scale(fn, scale): + ... + + +@_make_handler(SeedMessenger) +def seed(fn, rng_seed): + ... + + +@_make_handler(TraceMessenger) +def trace(fn=None, graph_type=None, param_only=None): + ... + +@_make_handler(UnconditionMessenger) +def uncondition(fn=None): + ... -trace = None # flake8 -escape = None # flake8 -for _msngr_cls in _msngrs: - _handler_name, _handler = _make_handler(_msngr_cls) - _handler.__module__ = __name__ - locals()[_handler_name] = _handler +@_make_handler(SubstituteMessenger) +def substitute(fn, data): + ... #########################################