From df34383cfe8144234c0b24a212464531883eaf6e Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 2 Feb 2024 15:58:53 -0800 Subject: [PATCH] Get CI working with new black, torch, pytest (#3318) --- examples/capture_recapture/cjs.py | 8 ++- examples/contrib/funsor/hmm.py | 8 ++- examples/hmm.py | 8 ++- examples/mixed_hmm/experiment.py | 8 +-- pyro/contrib/autoname/autoname.py | 3 +- pyro/contrib/bnn/hidden_layer.py | 1 + pyro/contrib/cevae/__init__.py | 2 +- pyro/contrib/epidemiology/models.py | 8 ++- pyro/contrib/funsor/handlers/__init__.py | 21 +++---- .../funsor/handlers/replay_messenger.py | 6 +- pyro/contrib/funsor/infer/trace_elbo.py | 8 ++- pyro/contrib/funsor/infer/traceenum_elbo.py | 28 +++++---- pyro/contrib/funsor/infer/tracetmc_elbo.py | 14 +++-- pyro/contrib/timeseries/lgssmgp.py | 30 +++++----- pyro/contrib/tracking/distributions.py | 1 + pyro/distributions/__init__.py | 4 +- pyro/distributions/asymmetriclaplace.py | 4 +- pyro/distributions/conditional.py | 16 +++-- pyro/distributions/conjugate.py | 2 + pyro/distributions/constraints.py | 14 +++-- pyro/distributions/grouped_normal_normal.py | 1 + pyro/distributions/hmm.py | 1 + pyro/distributions/inverse_gamma.py | 2 + pyro/distributions/lkj.py | 1 + .../log_normal_negative_binomial.py | 1 + pyro/distributions/one_one_matching.py | 1 + pyro/distributions/one_two_matching.py | 1 + pyro/distributions/sine_skewed.py | 8 ++- pyro/distributions/stable.py | 1 + pyro/distributions/torch.py | 3 + pyro/distributions/transforms/basic.py | 2 + .../transforms/block_autoregressive.py | 1 + pyro/distributions/transforms/cholesky.py | 2 + pyro/distributions/transforms/permute.py | 1 + pyro/distributions/transforms/power.py | 1 + pyro/distributions/transforms/softplus.py | 1 + pyro/infer/autoguide/gaussian.py | 6 +- pyro/infer/rws.py | 9 ++- pyro/nn/module.py | 6 +- pyro/params/param_store.py | 18 +++--- pyro/poutine/enum_messenger.py | 18 +++--- pyro/poutine/handlers.py | 60 +++++++------------ pyro/poutine/runtime.py | 6 +- pyro/poutine/trace_struct.py | 8 ++- pyro/primitives.py | 6 +- tests/contrib/funsor/test_tmc.py | 8 ++- tests/infer/test_smcfilter.py | 4 +- tests/infer/test_tmc.py | 17 +++--- tests/ops/test_welford.py | 1 + tests/perf/test_benchmark.py | 2 + 50 files changed, 208 insertions(+), 183 deletions(-) diff --git a/examples/capture_recapture/cjs.py b/examples/capture_recapture/cjs.py index a0923f2500..193b5631c5 100644 --- a/examples/capture_recapture/cjs.py +++ b/examples/capture_recapture/cjs.py @@ -324,9 +324,11 @@ def expose_fn(msg): elbo = TraceTMC_ELBO(max_plate_nesting=1) tmc_model = poutine.infer_config( model, - lambda msg: {"num_samples": args.tmc_num_samples, "expand": False} - if msg["infer"].get("enumerate", None) == "parallel" - else {}, + lambda msg: ( + {"num_samples": args.tmc_num_samples, "expand": False} + if msg["infer"].get("enumerate", None) == "parallel" + else {} + ), ) # noqa: E501 svi = SVI(tmc_model, guide, optim, elbo) else: diff --git a/examples/contrib/funsor/hmm.py b/examples/contrib/funsor/hmm.py index 6756057b51..6df3e87aa1 100644 --- a/examples/contrib/funsor/hmm.py +++ b/examples/contrib/funsor/hmm.py @@ -737,9 +737,11 @@ def main(args): elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2) tmc_model = handlers.infer_config( model, - lambda msg: {"num_samples": args.tmc_num_samples, "expand": False} - if msg["infer"].get("enumerate", None) == "parallel" - else {}, + lambda msg: ( + {"num_samples": args.tmc_num_samples, "expand": False} + if msg["infer"].get("enumerate", None) == "parallel" + else {} + ), ) # noqa: E501 svi = infer.SVI(tmc_model, guide, optimizer, elbo) else: diff --git a/examples/hmm.py b/examples/hmm.py index 0c0c4418e1..25dccfda35 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -677,9 +677,11 @@ def main(args): elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2) tmc_model = poutine.infer_config( model, - lambda msg: {"num_samples": args.tmc_num_samples, "expand": False} - if msg["infer"].get("enumerate", None) == "parallel" - else {}, + lambda msg: ( + {"num_samples": args.tmc_num_samples, "expand": False} + if msg["infer"].get("enumerate", None) == "parallel" + else {} + ), ) # noqa: E501 svi = SVI(tmc_model, guide, optim, elbo) else: diff --git a/examples/mixed_hmm/experiment.py b/examples/mixed_hmm/experiment.py index 65f900c9d1..460a5122a2 100644 --- a/examples/mixed_hmm/experiment.py +++ b/examples/mixed_hmm/experiment.py @@ -143,16 +143,12 @@ def closure(): re_str = "g" + ( "n" if args["group"] is None - else "d" - if args["group"] == "discrete" - else "c" + else "d" if args["group"] == "discrete" else "c" ) re_str += "i" + ( "n" if args["individual"] is None - else "d" - if args["individual"] == "discrete" - else "c" + else "d" if args["individual"] == "discrete" else "c" ) results_filename = "expt_{}_{}_{}.json".format( dataset, re_str, str(uuid.uuid4().hex)[0:5] diff --git a/pyro/contrib/autoname/autoname.py b/pyro/contrib/autoname/autoname.py index 4771ddd51b..f52aec84fb 100644 --- a/pyro/contrib/autoname/autoname.py +++ b/pyro/contrib/autoname/autoname.py @@ -154,8 +154,7 @@ def _pyro_genname(msg): @_make_handler(AutonameMessenger, __name__) -def autoname(fn=None, name=None): - ... +def autoname(fn=None, name=None): ... @singledispatch diff --git a/pyro/contrib/bnn/hidden_layer.py b/pyro/contrib/bnn/hidden_layer.py index 43000e820e..85c6a786aa 100644 --- a/pyro/contrib/bnn/hidden_layer.py +++ b/pyro/contrib/bnn/hidden_layer.py @@ -55,6 +55,7 @@ class HiddenLayer(TorchDistribution): "Variational dropout and the local reparameterization trick." Advances in Neural Information Processing Systems. 2015. """ + has_rsample = True def __init__( diff --git a/pyro/contrib/cevae/__init__.py b/pyro/contrib/cevae/__init__.py index 842a4e2a53..dd6da388a8 100644 --- a/pyro/contrib/cevae/__init__.py +++ b/pyro/contrib/cevae/__init__.py @@ -271,7 +271,7 @@ def __init__(self, data): super().__init__() with torch.no_grad(): loc = data.mean(0) - scale = data.std(0) + scale = data.std(0, unbiased=False) scale[~(scale > 0)] = 1.0 self.register_buffer("loc", loc) self.register_buffer("inv_scale", scale.reciprocal()) diff --git a/pyro/contrib/epidemiology/models.py b/pyro/contrib/epidemiology/models.py index ddfb515a9d..cf2a8513c2 100644 --- a/pyro/contrib/epidemiology/models.py +++ b/pyro/contrib/epidemiology/models.py @@ -680,9 +680,11 @@ def transition(self, params, state, t): coal_rate = R * (1.0 + 1.0 / k) / (tau_i * state["I"] + 1e-8) pyro.factor( "coalescent_{}".format(t), - self.coal_likelihood(coal_rate, t) - if t_is_observed - else torch.tensor(0.0), + ( + self.coal_likelihood(coal_rate, t) + if t_is_observed + else torch.tensor(0.0) + ), ) # Update compartements with flows. diff --git a/pyro/contrib/funsor/handlers/__init__.py b/pyro/contrib/funsor/handlers/__init__.py index 992f027715..a05bc55517 100644 --- a/pyro/contrib/funsor/handlers/__init__.py +++ b/pyro/contrib/funsor/handlers/__init__.py @@ -23,18 +23,15 @@ @_make_handler(EnumMessenger, __name__) -def enum(fn=None, first_available_dim=None): - ... +def enum(fn=None, first_available_dim=None): ... @_make_handler(MarkovMessenger, __name__) -def markov(fn=None, history=1, keep=False): - ... +def markov(fn=None, history=1, keep=False): ... @_make_handler(NamedMessenger, __name__) -def named(fn=None, first_available_dim=None): - ... +def named(fn=None, first_available_dim=None): ... @_make_handler(PlateMessenger, __name__) @@ -47,20 +44,16 @@ def plate( dim=None, use_cuda=None, device=None, -): - ... +): ... @_make_handler(ReplayMessenger, __name__) -def replay(fn=None, trace=None, params=None): - ... +def replay(fn=None, trace=None, params=None): ... @_make_handler(TraceMessenger, __name__) -def trace(fn=None, graph_type=None, param_only=None, pack_online=True): - ... +def trace(fn=None, graph_type=None, param_only=None, pack_online=True): ... @_make_handler(VectorizedMarkovMessenger, __name__) -def vectorized_markov(fn=None, name=None, size=None, dim=None, history=1): - ... +def vectorized_markov(fn=None, name=None, size=None, dim=None, history=1): ... diff --git a/pyro/contrib/funsor/handlers/replay_messenger.py b/pyro/contrib/funsor/handlers/replay_messenger.py index 2f6a76033a..35cb5d9c79 100644 --- a/pyro/contrib/funsor/handlers/replay_messenger.py +++ b/pyro/contrib/funsor/handlers/replay_messenger.py @@ -14,9 +14,9 @@ class ReplayMessenger(OrigReplayMessenger): def _pyro_sample(self, msg): name = msg["name"] - msg[ - "replay_active" - ] = True # indicate replaying so importance weights can be scaled + msg["replay_active"] = ( + True # indicate replaying so importance weights can be scaled + ) if self.trace is None: return diff --git a/pyro/contrib/funsor/infer/trace_elbo.py b/pyro/contrib/funsor/infer/trace_elbo.py index e91787732f..7b574d144a 100644 --- a/pyro/contrib/funsor/infer/trace_elbo.py +++ b/pyro/contrib/funsor/infer/trace_elbo.py @@ -18,9 +18,11 @@ @copy_docs_from(_OrigTrace_ELBO) class Trace_ELBO(ELBO): def differentiable_loss(self, model, guide, *args, **kwargs): - with enum(), plate( - size=self.num_particles - ) if self.num_particles > 1 else contextlib.ExitStack(): + with enum(), ( + plate(size=self.num_particles) + if self.num_particles > 1 + else contextlib.ExitStack() + ): guide_tr = trace(config_enumerate(default="flat")(guide)).get_trace( *args, **kwargs ) diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index 0680aa2ceb..95b2562b54 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -93,12 +93,14 @@ def terms_from_trace(tr): class TraceMarkovEnum_ELBO(ELBO): def differentiable_loss(self, model, guide, *args, **kwargs): # get batched, enumerated, to_funsor-ed traces from the guide and model - with plate( - size=self.num_particles - ) if self.num_particles > 1 else contextlib.ExitStack(), enum( - first_available_dim=(-self.max_plate_nesting - 1) - if self.max_plate_nesting - else None + with ( + plate(size=self.num_particles) + if self.num_particles > 1 + else contextlib.ExitStack() + ), enum( + first_available_dim=( + (-self.max_plate_nesting - 1) if self.max_plate_nesting else None + ) ): guide_tr = trace(guide).get_trace(*args, **kwargs) model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) @@ -170,12 +172,14 @@ def differentiable_loss(self, model, guide, *args, **kwargs): class TraceEnum_ELBO(ELBO): def differentiable_loss(self, model, guide, *args, **kwargs): # get batched, enumerated, to_funsor-ed traces from the guide and model - with plate( - size=self.num_particles - ) if self.num_particles > 1 else contextlib.ExitStack(), enum( - first_available_dim=(-self.max_plate_nesting - 1) - if self.max_plate_nesting - else None + with ( + plate(size=self.num_particles) + if self.num_particles > 1 + else contextlib.ExitStack() + ), enum( + first_available_dim=( + (-self.max_plate_nesting - 1) if self.max_plate_nesting else None + ) ): guide_tr = trace(guide).get_trace(*args, **kwargs) model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) diff --git a/pyro/contrib/funsor/infer/tracetmc_elbo.py b/pyro/contrib/funsor/infer/tracetmc_elbo.py index 7cf4ba805e..4698cb922f 100644 --- a/pyro/contrib/funsor/infer/tracetmc_elbo.py +++ b/pyro/contrib/funsor/infer/tracetmc_elbo.py @@ -16,12 +16,14 @@ @copy_docs_from(_OrigTraceTMC_ELBO) class TraceTMC_ELBO(ELBO): def differentiable_loss(self, model, guide, *args, **kwargs): - with plate( - size=self.num_particles - ) if self.num_particles > 1 else contextlib.ExitStack(), enum( - first_available_dim=(-self.max_plate_nesting - 1) - if self.max_plate_nesting - else None + with ( + plate(size=self.num_particles) + if self.num_particles > 1 + else contextlib.ExitStack() + ), enum( + first_available_dim=( + (-self.max_plate_nesting - 1) if self.max_plate_nesting else None + ) ): guide_tr = trace(guide).get_trace(*args, **kwargs) model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) diff --git a/pyro/contrib/timeseries/lgssmgp.py b/pyro/contrib/timeseries/lgssmgp.py index 640a257dbb..3daeaa5b5b 100644 --- a/pyro/contrib/timeseries/lgssmgp.py +++ b/pyro/contrib/timeseries/lgssmgp.py @@ -109,9 +109,9 @@ def _get_init_dist(self): covar[: self.full_gp_state_dim, : self.full_gp_state_dim] = block_diag_embed( self.kernel.stationary_covariance() ) - covar[ - self.full_gp_state_dim :, self.full_gp_state_dim : - ] = self.init_noise_scale_sq.diag_embed() + covar[self.full_gp_state_dim :, self.full_gp_state_dim :] = ( + self.init_noise_scale_sq.diag_embed() + ) return MultivariateNormal(loc, covar) def _get_obs_dist(self): @@ -134,23 +134,23 @@ def get_dist(self, duration=None): trans_covar = self.z_trans_matrix.new_zeros( self.full_state_dim, self.full_state_dim ) - trans_covar[ - : self.full_gp_state_dim, : self.full_gp_state_dim - ] = block_diag_embed(gp_process_covar) - trans_covar[ - self.full_gp_state_dim :, self.full_gp_state_dim : - ] = self.trans_noise_scale_sq.diag_embed() + trans_covar[: self.full_gp_state_dim, : self.full_gp_state_dim] = ( + block_diag_embed(gp_process_covar) + ) + trans_covar[self.full_gp_state_dim :, self.full_gp_state_dim :] = ( + self.trans_noise_scale_sq.diag_embed() + ) trans_dist = MultivariateNormal( trans_covar.new_zeros(self.full_state_dim), trans_covar ) full_trans_mat = trans_covar.new_zeros(self.full_state_dim, self.full_state_dim) - full_trans_mat[ - : self.full_gp_state_dim, : self.full_gp_state_dim - ] = block_diag_embed(gp_trans_matrix) - full_trans_mat[ - self.full_gp_state_dim :, self.full_gp_state_dim : - ] = self.z_trans_matrix + full_trans_mat[: self.full_gp_state_dim, : self.full_gp_state_dim] = ( + block_diag_embed(gp_trans_matrix) + ) + full_trans_mat[self.full_gp_state_dim :, self.full_gp_state_dim :] = ( + self.z_trans_matrix + ) return dist.GaussianHMM( self._get_init_dist(), diff --git a/pyro/contrib/tracking/distributions.py b/pyro/contrib/tracking/distributions.py index ab2c1a8230..e2be7582c7 100644 --- a/pyro/contrib/tracking/distributions.py +++ b/pyro/contrib/tracking/distributions.py @@ -27,6 +27,7 @@ class EKFDistribution(TorchDistribution): :param dt: time step :type dt: torch.Tensor """ + arg_constraints = { "measurement_cov": constraints.positive_definite, "P0": constraints.positive_definite, diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 71e2f8939c..b648d0d66a 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -48,7 +48,9 @@ LinearHMM, ) from pyro.distributions.improper_uniform import ImproperUniform -from pyro.distributions.inverse_gamma import InverseGamma + +if "InverseGamma" not in locals(): # Use PyTorch version if available. + from pyro.distributions.inverse_gamma import InverseGamma from pyro.distributions.lkj import LKJ, LKJCorrCholesky from pyro.distributions.log_normal_negative_binomial import LogNormalNegativeBinomial from pyro.distributions.logistic import Logistic, SkewLogistic diff --git a/pyro/distributions/asymmetriclaplace.py b/pyro/distributions/asymmetriclaplace.py index e0c86341d9..de64e9fe68 100644 --- a/pyro/distributions/asymmetriclaplace.py +++ b/pyro/distributions/asymmetriclaplace.py @@ -199,9 +199,7 @@ def variance(self): total = left + right p = left / total q = right / total - return ( - p * left**2 + q * right**2 + p * q * total**2 + self.soft_scale**2 - ) + return p * left**2 + q * right**2 + p * q * total**2 + self.soft_scale**2 def _logerfc(x): diff --git a/pyro/distributions/conditional.py b/pyro/distributions/conditional.py index f337962c14..bc3548ed99 100644 --- a/pyro/distributions/conditional.py +++ b/pyro/distributions/conditional.py @@ -82,9 +82,11 @@ class ConditionalComposeTransformModule( def __init__(self, transforms, cache_size: int = 0): self.transforms = [ - ConstantConditionalTransform(t) - if not isinstance(t, ConditionalTransform) - else t + ( + ConstantConditionalTransform(t) + if not isinstance(t, ConditionalTransform) + else t + ) for t in transforms ] super().__init__() @@ -131,9 +133,11 @@ def __init__(self, base_dist, transforms): else ConstantConditionalDistribution(base_dist) ) self.transforms = [ - t - if isinstance(t, ConditionalTransform) - else ConstantConditionalTransform(t) + ( + t + if isinstance(t, ConditionalTransform) + else ConstantConditionalTransform(t) + ) for t in transforms ] diff --git a/pyro/distributions/conjugate.py b/pyro/distributions/conjugate.py index 8a0c3f8471..ffb37f3148 100644 --- a/pyro/distributions/conjugate.py +++ b/pyro/distributions/conjugate.py @@ -47,6 +47,7 @@ class BetaBinomial(TorchDistribution): :param total_count: Number of Bernoulli trials. :type total_count: float or torch.Tensor """ + arg_constraints = { "concentration1": constraints.positive, "concentration0": constraints.positive, @@ -150,6 +151,7 @@ class DirichletMultinomial(TorchDistribution): :param bool is_sparse: Whether to assume value is mostly zero when computing :meth:`log_prob`, which can speed up computation when data is sparse. """ + arg_constraints = { "concentration": constraints.independent(constraints.positive, 1), "total_count": constraints.nonnegative_integer, diff --git a/pyro/distributions/constraints.py b/pyro/distributions/constraints.py index 1aaf0656aa..3f8026f2e0 100644 --- a/pyro/distributions/constraints.py +++ b/pyro/distributions/constraints.py @@ -158,12 +158,14 @@ def check(self, value): {} """.format( _name, - "alias of :class:`torch.distributions.constraints.{}`".format(_name) - if globals()[_name].__module__.startswith("torch") - else ".. autoclass:: {}".format( - _name - if type(globals()[_name]) is type - else type(globals()[_name]).__name__ + ( + "alias of :class:`torch.distributions.constraints.{}`".format(_name) + if globals()[_name].__module__.startswith("torch") + else ".. autoclass:: {}".format( + _name + if type(globals()[_name]) is type + else type(globals()[_name]).__name__ + ) ), ) for _name in sorted(__all__) diff --git a/pyro/distributions/grouped_normal_normal.py b/pyro/distributions/grouped_normal_normal.py index eaf80142fe..05d488ee4a 100644 --- a/pyro/distributions/grouped_normal_normal.py +++ b/pyro/distributions/grouped_normal_normal.py @@ -48,6 +48,7 @@ class GroupedNormalNormal(TorchDistribution): :param torch.LongTensor group_idx: Tensor of indices of shape `(num_data,)` linking each observation to one of the `num_groups` groups that are specified in `prior_loc` and `prior_scale`. """ + arg_constraints = { "prior_loc": constraints.real, "prior_scale": constraints.positive, diff --git a/pyro/distributions/hmm.py b/pyro/distributions/hmm.py index 2112d9979b..9e5d714aa4 100644 --- a/pyro/distributions/hmm.py +++ b/pyro/distributions/hmm.py @@ -1003,6 +1003,7 @@ class LinearHMM(HiddenMarkovModel): This is required when sampling from homogeneous HMMs whose parameters are not expanded along the time axis. """ + arg_constraints = {} support = constraints.independent(constraints.real, 2) has_rsample = True diff --git a/pyro/distributions/inverse_gamma.py b/pyro/distributions/inverse_gamma.py index 6235aa3000..b739ce00df 100644 --- a/pyro/distributions/inverse_gamma.py +++ b/pyro/distributions/inverse_gamma.py @@ -7,6 +7,7 @@ from pyro.distributions.torch import Gamma, TransformedDistribution +# DEPRECATED in favor of torch.distributions.InverseGamma. class InverseGamma(TransformedDistribution): r""" Creates an inverse-gamma distribution parameterized by @@ -18,6 +19,7 @@ class InverseGamma(TransformedDistribution): :param torch.Tensor concentration: the concentration parameter (i.e. alpha). :param torch.Tensor rate: the rate parameter (i.e. beta). """ + arg_constraints = { "concentration": constraints.positive, "rate": constraints.positive, diff --git a/pyro/distributions/lkj.py b/pyro/distributions/lkj.py index 720c1eb0cf..a7ba243d92 100644 --- a/pyro/distributions/lkj.py +++ b/pyro/distributions/lkj.py @@ -42,6 +42,7 @@ class LKJ(TransformedDistribution): [1] `Generating random correlation matrices based on vines and extended onion method`, Daniel Lewandowski, Dorota Kurowicka, Harry Joe """ + arg_constraints = {"concentration": constraints.positive} support = constraints.corr_matrix diff --git a/pyro/distributions/log_normal_negative_binomial.py b/pyro/distributions/log_normal_negative_binomial.py index 0b1b23b8e1..836d5db786 100644 --- a/pyro/distributions/log_normal_negative_binomial.py +++ b/pyro/distributions/log_normal_negative_binomial.py @@ -66,6 +66,7 @@ class LogNormalNegativeBinomial(TorchDistribution): :param int num_quad_points: Number of quadrature points used to compute the (approximate) `log_prob`. Defaults to 8. """ + arg_constraints = { "total_count": constraints.greater_than_eq(0), "logits": constraints.real, diff --git a/pyro/distributions/one_one_matching.py b/pyro/distributions/one_one_matching.py index 89400bc0c4..cb6e776b28 100644 --- a/pyro/distributions/one_one_matching.py +++ b/pyro/distributions/one_one_matching.py @@ -77,6 +77,7 @@ class OneOneMatching(TorchDistribution): :param int bp_iters: Optional number of belief propagation iterations. If unspecified or ``None`` expensive exact algorithms will be used. """ + arg_constraints = {"logits": constraints.real} has_enumerate_support = True diff --git a/pyro/distributions/one_two_matching.py b/pyro/distributions/one_two_matching.py index cf2bd7a6b2..c1ff43247b 100644 --- a/pyro/distributions/one_two_matching.py +++ b/pyro/distributions/one_two_matching.py @@ -78,6 +78,7 @@ class OneTwoMatching(TorchDistribution): :param int bp_iters: Optional number of belief propagation iterations. If unspecified or ``None`` expensive exact algorithms will be used. """ + arg_constraints = {"logits": constraints.real} has_enumerate_support = True diff --git a/pyro/distributions/sine_skewed.py b/pyro/distributions/sine_skewed.py index e705e7de64..202c00c0d6 100644 --- a/pyro/distributions/sine_skewed.py +++ b/pyro/distributions/sine_skewed.py @@ -114,9 +114,11 @@ def __repr__(self): [ "{}: {}".format( p, - getattr(self, p) - if getattr(self, p).numel() == 1 - else getattr(self, p).size(), + ( + getattr(self, p) + if getattr(self, p).numel() == 1 + else getattr(self, p).size() + ), ) for p in self.arg_constraints.keys() ] diff --git a/pyro/distributions/stable.py b/pyro/distributions/stable.py index 6f3563a30a..0b2ec5b9c0 100644 --- a/pyro/distributions/stable.py +++ b/pyro/distributions/stable.py @@ -144,6 +144,7 @@ def model(): :param str coords: Either "S0" (default) to use Nolan's continuous S0 parametrization, or "S" to use the discontinuous parameterization. """ + has_rsample = True arg_constraints = { "stability": constraints.interval(0, 2), # half-open (0, 2] diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index 805b1b83b6..902602de1a 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -385,5 +385,8 @@ def _cat_docstrings(*docstrings): _name ) for _name in sorted(__all__) + # Work around sphinx autodoc error in case two InverseGamma's are defined: + # "duplicate object description of pyro.distributions.InverseGamma" + if _name != "InverseGamma" ] ) diff --git a/pyro/distributions/transforms/basic.py b/pyro/distributions/transforms/basic.py index 80ac844c53..420cf7cd05 100644 --- a/pyro/distributions/transforms/basic.py +++ b/pyro/distributions/transforms/basic.py @@ -16,6 +16,7 @@ class ELUTransform(Transform): r""" Bijective transform via the mapping :math:`y = \text{ELU}(x)`. """ + domain = constraints.real codomain = constraints.positive bijective = True @@ -52,6 +53,7 @@ class LeakyReLUTransform(Transform): r""" Bijective transform via the mapping :math:`y = \text{LeakyReLU}(x)`. """ + domain = constraints.real codomain = constraints.real bijective = True diff --git a/pyro/distributions/transforms/block_autoregressive.py b/pyro/distributions/transforms/block_autoregressive.py index 1978e10656..bdafb05d5b 100644 --- a/pyro/distributions/transforms/block_autoregressive.py +++ b/pyro/distributions/transforms/block_autoregressive.py @@ -69,6 +69,7 @@ class BlockAutoregressive(TransformModule): [arXiv:1904.04676] """ + domain = constraints.real_vector codomain = constraints.real_vector bijective = True diff --git a/pyro/distributions/transforms/cholesky.py b/pyro/distributions/transforms/cholesky.py index 39321d1145..c174c40d4a 100644 --- a/pyro/distributions/transforms/cholesky.py +++ b/pyro/distributions/transforms/cholesky.py @@ -24,6 +24,7 @@ class CholeskyTransform(Transform): Transform via the mapping :math:`y = safe_cholesky(x)`, where `x` is a positive definite matrix. """ + bijective = True domain = constraints.positive_definite codomain = constraints.lower_cholesky @@ -51,6 +52,7 @@ class CorrMatrixCholeskyTransform(CholeskyTransform): Transform via the mapping :math:`y = safe_cholesky(x)`, where `x` is a correlation matrix. """ + bijective = True domain = constraints.corr_matrix # TODO: change corr_cholesky_constraint to corr_cholesky when the latter is availabler diff --git a/pyro/distributions/transforms/permute.py b/pyro/distributions/transforms/permute.py index 748b567443..d7a520927f 100644 --- a/pyro/distributions/transforms/permute.py +++ b/pyro/distributions/transforms/permute.py @@ -43,6 +43,7 @@ class Permute(Transform): :type dim: int """ + bijective = True volume_preserving = True diff --git a/pyro/distributions/transforms/power.py b/pyro/distributions/transforms/power.py index eb76476080..62dcce821f 100644 --- a/pyro/distributions/transforms/power.py +++ b/pyro/distributions/transforms/power.py @@ -17,6 +17,7 @@ class PositivePowerTransform(Transform): .. warning:: The Jacobian is typically zero or infinite at the origin. """ + domain = constraints.real codomain = constraints.real bijective = True diff --git a/pyro/distributions/transforms/softplus.py b/pyro/distributions/transforms/softplus.py index 857398a42c..1b92c8d3b3 100644 --- a/pyro/distributions/transforms/softplus.py +++ b/pyro/distributions/transforms/softplus.py @@ -15,6 +15,7 @@ class SoftplusTransform(Transform): r""" Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`. """ + domain = constraints.real codomain = constraints.positive bijective = True diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index a6e7bb7bfc..de873471ea 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -375,9 +375,9 @@ def _setup_prototype(self, *args, **kwargs): u_stop = u_start + u_index.size(-1) v_start = local_offsets[v] v_stop = v_start + v_index.size(-1) - index2[ - ..., u_start:u_stop, v_start:v_stop - ] = self._dense_size * u_index.unsqueeze(-1) + v_index.unsqueeze(-2) + index2[..., u_start:u_stop, v_start:v_stop] = ( + self._dense_size * u_index.unsqueeze(-1) + v_index.unsqueeze(-2) + ) self._dense_scatter[d] = index1.reshape(-1), index2.reshape(-1) diff --git a/pyro/infer/rws.py b/pyro/infer/rws.py index 6ec7f28c08..c897cb895d 100644 --- a/pyro/infer/rws.py +++ b/pyro/infer/rws.py @@ -203,9 +203,12 @@ def _loss(self, model, guide, args, kwargs): phi_loss = ( sleep_phi_loss if self.insomnia == 0 - else wake_phi_loss - if self.insomnia == 1 - else self.insomnia * wake_phi_loss + (1.0 - self.insomnia) * sleep_phi_loss + else ( + wake_phi_loss + if self.insomnia == 1 + else self.insomnia * wake_phi_loss + + (1.0 - self.insomnia) * sleep_phi_loss + ) ) return wake_theta_loss, phi_loss diff --git a/pyro/nn/module.py b/pyro/nn/module.py index b3dd27e21e..cc38517ac4 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -488,9 +488,9 @@ def __getattr__(self, name): unconstrained_value = torch.nn.Parameter( unconstrained_value ) - _PYRO_PARAM_STORE._params[ - fullname - ] = unconstrained_value + _PYRO_PARAM_STORE._params[fullname] = ( + unconstrained_value + ) _PYRO_PARAM_STORE._param_to_name[ unconstrained_value ] = fullname diff --git a/pyro/params/param_store.py b/pyro/params/param_store.py index 4e2ba9e74d..99946e5821 100644 --- a/pyro/params/param_store.py +++ b/pyro/params/param_store.py @@ -60,15 +60,15 @@ def __init__(self) -> None: """ initialize ParamStore data structures """ - self._params: Dict[ - str, torch.Tensor - ] = {} # dictionary from param name to param - self._param_to_name: Dict[ - torch.Tensor, str - ] = {} # dictionary from unconstrained param to param name - self._constraints: Dict[ - str, constraints.Constraint - ] = {} # dictionary from param name to constraint object + self._params: Dict[str, torch.Tensor] = ( + {} + ) # dictionary from param name to param + self._param_to_name: Dict[torch.Tensor, str] = ( + {} + ) # dictionary from unconstrained param to param name + self._constraints: Dict[str, constraints.Constraint] = ( + {} + ) # dictionary from param name to constraint object def clear(self) -> None: """ diff --git a/pyro/poutine/enum_messenger.py b/pyro/poutine/enum_messenger.py index fe87f76d7f..408669e900 100644 --- a/pyro/poutine/enum_messenger.py +++ b/pyro/poutine/enum_messenger.py @@ -154,15 +154,15 @@ def __init__(self, first_available_dim: Optional[int] = None) -> None: def __enter__(self) -> Self: if self.first_available_dim is not None: _ENUM_ALLOCATOR.set_first_available_dim(self.first_available_dim) - self._markov_depths: Dict[ - str, int - ] = {} # site name -> depth (nonnegative integer) - self._param_dims: Dict[ - str, Dict[int, int] - ] = {} # site name -> (enum dim -> unique id) - self._value_dims: Dict[ - str, Dict[int, int] - ] = {} # site name -> (enum dim -> unique id) + self._markov_depths: Dict[str, int] = ( + {} + ) # site name -> depth (nonnegative integer) + self._param_dims: Dict[str, Dict[int, int]] = ( + {} + ) # site name -> (enum dim -> unique id) + self._value_dims: Dict[str, Dict[int, int]] = ( + {} + ) # site name -> (enum dim -> unique id) return super().__enter__() @ignore_jit_warnings() diff --git a/pyro/poutine/handlers.py b/pyro/poutine/handlers.py index f0116b9be2..2b93fda743 100644 --- a/pyro/poutine/handlers.py +++ b/pyro/poutine/handlers.py @@ -126,88 +126,71 @@ def block( expose=None, hide_types=None, expose_types=None, -): - ... +): ... @_make_handler(BroadcastMessenger) -def broadcast(fn=None): - ... +def broadcast(fn=None): ... @_make_handler(CollapseMessenger) -def collapse(fn=None, *args, **kwargs): - ... +def collapse(fn=None, *args, **kwargs): ... @_make_handler(ConditionMessenger) -def condition(fn, data): - ... +def condition(fn, data): ... @_make_handler(DoMessenger) -def do(fn, data): - ... +def do(fn, data): ... @_make_handler(EnumMessenger) -def enum(fn=None, first_available_dim=None): - ... +def enum(fn=None, first_available_dim=None): ... @_make_handler(EscapeMessenger) -def escape(fn, escape_fn): - ... +def escape(fn, escape_fn): ... @_make_handler(InferConfigMessenger) -def infer_config(fn, config_fn): - ... +def infer_config(fn, config_fn): ... @_make_handler(LiftMessenger) -def lift(fn, prior): - ... +def lift(fn, prior): ... @_make_handler(MaskMessenger) -def mask(fn, mask): - ... +def mask(fn, mask): ... @_make_handler(ReparamMessenger) -def reparam(fn, config): - ... +def reparam(fn, config): ... @_make_handler(ReplayMessenger) -def replay(fn=None, trace=None, params=None): - ... +def replay(fn=None, trace=None, params=None): ... @_make_handler(ScaleMessenger) -def scale(fn, scale): - ... +def scale(fn, scale): ... @_make_handler(SeedMessenger) -def seed(fn, rng_seed): - ... +def seed(fn, rng_seed): ... @_make_handler(TraceMessenger) -def trace(fn=None, graph_type=None, param_only=None): - ... +def trace(fn=None, graph_type=None, param_only=None): ... @_make_handler(UnconditionMessenger) -def uncondition(fn=None): - ... +def uncondition(fn=None): ... @_make_handler(SubstituteMessenger) -def substitute(fn, data): - ... +def substitute(fn, data): ... ######################################### @@ -289,8 +272,7 @@ def markov( keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None, -) -> MarkovMessenger: - ... +) -> MarkovMessenger: ... @overload @@ -300,8 +282,7 @@ def markov( keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None, -) -> MarkovMessenger: - ... +) -> MarkovMessenger: ... @overload @@ -311,8 +292,7 @@ def markov( keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None, -) -> Callable[_P, _T]: - ... +) -> Callable[_P, _T]: ... def markov( diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index ce4e405752..793b31f465 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -309,15 +309,13 @@ def am_i_wrapped() -> bool: @overload def effectful( fn: None = ..., type: Optional[str] = ... -) -> Callable[[Callable[_P, _T]], Callable[..., Optional[_T]]]: - ... +) -> Callable[[Callable[_P, _T]], Callable[..., Optional[_T]]]: ... @overload def effectful( fn: Callable[_P, _T] = ..., type: Optional[str] = ... -) -> Callable[..., Optional[_T]]: - ... +) -> Callable[..., Optional[_T]]: ... def effectful( diff --git a/pyro/poutine/trace_struct.py b/pyro/poutine/trace_struct.py index 7b7e286747..ccfd2a78c7 100644 --- a/pyro/poutine/trace_struct.py +++ b/pyro/poutine/trace_struct.py @@ -559,9 +559,11 @@ def _format_table(rows: List[List[Optional[str]]]) -> str: else: cols[j].append(cell) cols = [ - [""] * (width - len(col)) + col - if direction == "r" - else col + [""] * (width - len(col)) + ( + [""] * (width - len(col)) + col + if direction == "r" + else col + [""] * (width - len(col)) + ) for width, col, direction in zip(column_widths, cols, "rrl") ] justified_rows.append(sum(cols, [])) diff --git a/pyro/primitives.py b/pyro/primitives.py index 320a7ca822..577450b1d6 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -493,9 +493,9 @@ def module( mod_name = _name if _name in target_state_dict.keys(): if not is_param: - deep_getattr(nn_module, mod_name)._parameters[ - param_name - ] = target_state_dict[_name] + deep_getattr(nn_module, mod_name)._parameters[param_name] = ( + target_state_dict[_name] + ) else: nn_module._parameters[mod_name] = target_state_dict[_name] # type: ignore[assignment] diff --git a/tests/contrib/funsor/test_tmc.py b/tests/contrib/funsor/test_tmc.py index 26c45c7b1c..81da5a653e 100644 --- a/tests/contrib/funsor/test_tmc.py +++ b/tests/contrib/funsor/test_tmc.py @@ -170,9 +170,11 @@ def nonfactorized_guide(reparameterized): guide = ( factorized_guide if guide_type == "factorized" - else nonfactorized_guide - if guide_type == "nonfactorized" - else lambda *args: None + else ( + nonfactorized_guide + if guide_type == "nonfactorized" + else lambda *args: None + ) ) tmc_guide = infer.config_enumerate( guide, diff --git a/tests/infer/test_smcfilter.py b/tests/infer/test_smcfilter.py index 0b095275ee..62431a17fe 100644 --- a/tests/infer/test_smcfilter.py +++ b/tests/infer/test_smcfilter.py @@ -267,8 +267,6 @@ def step(self, state, datum): smc.step(datum) expected = hmm.filter(data[: 1 + t]) actual = smc.get_empirical()["z"] - assert_close( - actual.variance**0.5, expected.variance**0.5, atol=0.1, rtol=0.5 - ) + assert_close(actual.variance**0.5, expected.variance**0.5, atol=0.1, rtol=0.5) sigma = actual.variance.max().item() ** 0.5 assert_close(actual.mean, expected.mean, atol=3 * sigma) diff --git a/tests/infer/test_tmc.py b/tests/infer/test_tmc.py index 0c29792704..eddf64f9a0 100644 --- a/tests/infer/test_tmc.py +++ b/tests/infer/test_tmc.py @@ -139,10 +139,13 @@ def nonfactorized_guide(reparameterized): guide = ( factorized_guide if guide_type == "factorized" - else nonfactorized_guide - if guide_type == "nonfactorized" - else poutine.block( - model, hide_fn=lambda msg: msg["type"] == "sample" and msg["is_observed"] + else ( + nonfactorized_guide + if guide_type == "nonfactorized" + else poutine.block( + model, + hide_fn=lambda msg: msg["type"] == "sample" and msg["is_observed"], + ) ) ) flat_num_samples = num_samples ** min(depth, 2) # don't use too many, expensive @@ -255,9 +258,9 @@ def nonfactorized_guide(reparameterized): guide = ( factorized_guide if guide_type == "factorized" - else nonfactorized_guide - if guide_type == "nonfactorized" - else lambda *args: None + else ( + nonfactorized_guide if guide_type == "nonfactorized" else lambda *args: None + ) ) tmc_guide = config_enumerate( guide, diff --git a/tests/ops/test_welford.py b/tests/ops/test_welford.py index d345c7a75f..b0eadc8166 100644 --- a/tests/ops/test_welford.py +++ b/tests/ops/test_welford.py @@ -10,6 +10,7 @@ from tests.common import assert_equal +@pytest.mark.filterwarnings("ignore:.*degrees of freedom is <= 0") @pytest.mark.parametrize("n_samples,dim_size", [(1000, 1), (1000, 7), (1, 1)]) @pytest.mark.init(rng_seed=7) def test_welford_diagonal(n_samples, dim_size): diff --git a/tests/perf/test_benchmark.py b/tests/perf/test_benchmark.py index 280a6f2963..3eeb7250a2 100644 --- a/tests/perf/test_benchmark.py +++ b/tests/perf/test_benchmark.py @@ -20,6 +20,8 @@ from pyro.infer.mcmc.hmc import HMC from pyro.infer.mcmc.nuts import NUTS +pytestmark = pytest.mark.stage("benchmark") + Model = namedtuple("TestModel", ["model", "model_args", "model_id"])