Skip to content

Commit

Permalink
Get CI working with new black, torch, pytest (#3318)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored Feb 2, 2024
1 parent 1a11185 commit df34383
Show file tree
Hide file tree
Showing 50 changed files with 208 additions and 183 deletions.
8 changes: 5 additions & 3 deletions examples/capture_recapture/cjs.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,11 @@ def expose_fn(msg):
elbo = TraceTMC_ELBO(max_plate_nesting=1)
tmc_model = poutine.infer_config(
model,
lambda msg: {"num_samples": args.tmc_num_samples, "expand": False}
if msg["infer"].get("enumerate", None) == "parallel"
else {},
lambda msg: (
{"num_samples": args.tmc_num_samples, "expand": False}
if msg["infer"].get("enumerate", None) == "parallel"
else {}
),
) # noqa: E501
svi = SVI(tmc_model, guide, optim, elbo)
else:
Expand Down
8 changes: 5 additions & 3 deletions examples/contrib/funsor/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,9 +737,11 @@ def main(args):
elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2)
tmc_model = handlers.infer_config(
model,
lambda msg: {"num_samples": args.tmc_num_samples, "expand": False}
if msg["infer"].get("enumerate", None) == "parallel"
else {},
lambda msg: (
{"num_samples": args.tmc_num_samples, "expand": False}
if msg["infer"].get("enumerate", None) == "parallel"
else {}
),
) # noqa: E501
svi = infer.SVI(tmc_model, guide, optimizer, elbo)
else:
Expand Down
8 changes: 5 additions & 3 deletions examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,9 +677,11 @@ def main(args):
elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2)
tmc_model = poutine.infer_config(
model,
lambda msg: {"num_samples": args.tmc_num_samples, "expand": False}
if msg["infer"].get("enumerate", None) == "parallel"
else {},
lambda msg: (
{"num_samples": args.tmc_num_samples, "expand": False}
if msg["infer"].get("enumerate", None) == "parallel"
else {}
),
) # noqa: E501
svi = SVI(tmc_model, guide, optim, elbo)
else:
Expand Down
8 changes: 2 additions & 6 deletions examples/mixed_hmm/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,12 @@ def closure():
re_str = "g" + (
"n"
if args["group"] is None
else "d"
if args["group"] == "discrete"
else "c"
else "d" if args["group"] == "discrete" else "c"
)
re_str += "i" + (
"n"
if args["individual"] is None
else "d"
if args["individual"] == "discrete"
else "c"
else "d" if args["individual"] == "discrete" else "c"
)
results_filename = "expt_{}_{}_{}.json".format(
dataset, re_str, str(uuid.uuid4().hex)[0:5]
Expand Down
3 changes: 1 addition & 2 deletions pyro/contrib/autoname/autoname.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,7 @@ def _pyro_genname(msg):


@_make_handler(AutonameMessenger, __name__)
def autoname(fn=None, name=None):
...
def autoname(fn=None, name=None): ...


@singledispatch
Expand Down
1 change: 1 addition & 0 deletions pyro/contrib/bnn/hidden_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class HiddenLayer(TorchDistribution):
"Variational dropout and the local reparameterization trick."
Advances in Neural Information Processing Systems. 2015.
"""

has_rsample = True

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion pyro/contrib/cevae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def __init__(self, data):
super().__init__()
with torch.no_grad():
loc = data.mean(0)
scale = data.std(0)
scale = data.std(0, unbiased=False)
scale[~(scale > 0)] = 1.0
self.register_buffer("loc", loc)
self.register_buffer("inv_scale", scale.reciprocal())
Expand Down
8 changes: 5 additions & 3 deletions pyro/contrib/epidemiology/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,9 +680,11 @@ def transition(self, params, state, t):
coal_rate = R * (1.0 + 1.0 / k) / (tau_i * state["I"] + 1e-8)
pyro.factor(
"coalescent_{}".format(t),
self.coal_likelihood(coal_rate, t)
if t_is_observed
else torch.tensor(0.0),
(
self.coal_likelihood(coal_rate, t)
if t_is_observed
else torch.tensor(0.0)
),
)

# Update compartements with flows.
Expand Down
21 changes: 7 additions & 14 deletions pyro/contrib/funsor/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,15 @@


@_make_handler(EnumMessenger, __name__)
def enum(fn=None, first_available_dim=None):
...
def enum(fn=None, first_available_dim=None): ...


@_make_handler(MarkovMessenger, __name__)
def markov(fn=None, history=1, keep=False):
...
def markov(fn=None, history=1, keep=False): ...


@_make_handler(NamedMessenger, __name__)
def named(fn=None, first_available_dim=None):
...
def named(fn=None, first_available_dim=None): ...


@_make_handler(PlateMessenger, __name__)
Expand All @@ -47,20 +44,16 @@ def plate(
dim=None,
use_cuda=None,
device=None,
):
...
): ...


@_make_handler(ReplayMessenger, __name__)
def replay(fn=None, trace=None, params=None):
...
def replay(fn=None, trace=None, params=None): ...


@_make_handler(TraceMessenger, __name__)
def trace(fn=None, graph_type=None, param_only=None, pack_online=True):
...
def trace(fn=None, graph_type=None, param_only=None, pack_online=True): ...


@_make_handler(VectorizedMarkovMessenger, __name__)
def vectorized_markov(fn=None, name=None, size=None, dim=None, history=1):
...
def vectorized_markov(fn=None, name=None, size=None, dim=None, history=1): ...
6 changes: 3 additions & 3 deletions pyro/contrib/funsor/handlers/replay_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ class ReplayMessenger(OrigReplayMessenger):

def _pyro_sample(self, msg):
name = msg["name"]
msg[
"replay_active"
] = True # indicate replaying so importance weights can be scaled
msg["replay_active"] = (
True # indicate replaying so importance weights can be scaled
)
if self.trace is None:
return

Expand Down
8 changes: 5 additions & 3 deletions pyro/contrib/funsor/infer/trace_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
@copy_docs_from(_OrigTrace_ELBO)
class Trace_ELBO(ELBO):
def differentiable_loss(self, model, guide, *args, **kwargs):
with enum(), plate(
size=self.num_particles
) if self.num_particles > 1 else contextlib.ExitStack():
with enum(), (
plate(size=self.num_particles)
if self.num_particles > 1
else contextlib.ExitStack()
):
guide_tr = trace(config_enumerate(default="flat")(guide)).get_trace(
*args, **kwargs
)
Expand Down
28 changes: 16 additions & 12 deletions pyro/contrib/funsor/infer/traceenum_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,14 @@ def terms_from_trace(tr):
class TraceMarkovEnum_ELBO(ELBO):
def differentiable_loss(self, model, guide, *args, **kwargs):
# get batched, enumerated, to_funsor-ed traces from the guide and model
with plate(
size=self.num_particles
) if self.num_particles > 1 else contextlib.ExitStack(), enum(
first_available_dim=(-self.max_plate_nesting - 1)
if self.max_plate_nesting
else None
with (
plate(size=self.num_particles)
if self.num_particles > 1
else contextlib.ExitStack()
), enum(
first_available_dim=(
(-self.max_plate_nesting - 1) if self.max_plate_nesting else None
)
):
guide_tr = trace(guide).get_trace(*args, **kwargs)
model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)
Expand Down Expand Up @@ -170,12 +172,14 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
class TraceEnum_ELBO(ELBO):
def differentiable_loss(self, model, guide, *args, **kwargs):
# get batched, enumerated, to_funsor-ed traces from the guide and model
with plate(
size=self.num_particles
) if self.num_particles > 1 else contextlib.ExitStack(), enum(
first_available_dim=(-self.max_plate_nesting - 1)
if self.max_plate_nesting
else None
with (
plate(size=self.num_particles)
if self.num_particles > 1
else contextlib.ExitStack()
), enum(
first_available_dim=(
(-self.max_plate_nesting - 1) if self.max_plate_nesting else None
)
):
guide_tr = trace(guide).get_trace(*args, **kwargs)
model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)
Expand Down
14 changes: 8 additions & 6 deletions pyro/contrib/funsor/infer/tracetmc_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
@copy_docs_from(_OrigTraceTMC_ELBO)
class TraceTMC_ELBO(ELBO):
def differentiable_loss(self, model, guide, *args, **kwargs):
with plate(
size=self.num_particles
) if self.num_particles > 1 else contextlib.ExitStack(), enum(
first_available_dim=(-self.max_plate_nesting - 1)
if self.max_plate_nesting
else None
with (
plate(size=self.num_particles)
if self.num_particles > 1
else contextlib.ExitStack()
), enum(
first_available_dim=(
(-self.max_plate_nesting - 1) if self.max_plate_nesting else None
)
):
guide_tr = trace(guide).get_trace(*args, **kwargs)
model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)
Expand Down
30 changes: 15 additions & 15 deletions pyro/contrib/timeseries/lgssmgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def _get_init_dist(self):
covar[: self.full_gp_state_dim, : self.full_gp_state_dim] = block_diag_embed(
self.kernel.stationary_covariance()
)
covar[
self.full_gp_state_dim :, self.full_gp_state_dim :
] = self.init_noise_scale_sq.diag_embed()
covar[self.full_gp_state_dim :, self.full_gp_state_dim :] = (
self.init_noise_scale_sq.diag_embed()
)
return MultivariateNormal(loc, covar)

def _get_obs_dist(self):
Expand All @@ -134,23 +134,23 @@ def get_dist(self, duration=None):
trans_covar = self.z_trans_matrix.new_zeros(
self.full_state_dim, self.full_state_dim
)
trans_covar[
: self.full_gp_state_dim, : self.full_gp_state_dim
] = block_diag_embed(gp_process_covar)
trans_covar[
self.full_gp_state_dim :, self.full_gp_state_dim :
] = self.trans_noise_scale_sq.diag_embed()
trans_covar[: self.full_gp_state_dim, : self.full_gp_state_dim] = (
block_diag_embed(gp_process_covar)
)
trans_covar[self.full_gp_state_dim :, self.full_gp_state_dim :] = (
self.trans_noise_scale_sq.diag_embed()
)
trans_dist = MultivariateNormal(
trans_covar.new_zeros(self.full_state_dim), trans_covar
)

full_trans_mat = trans_covar.new_zeros(self.full_state_dim, self.full_state_dim)
full_trans_mat[
: self.full_gp_state_dim, : self.full_gp_state_dim
] = block_diag_embed(gp_trans_matrix)
full_trans_mat[
self.full_gp_state_dim :, self.full_gp_state_dim :
] = self.z_trans_matrix
full_trans_mat[: self.full_gp_state_dim, : self.full_gp_state_dim] = (
block_diag_embed(gp_trans_matrix)
)
full_trans_mat[self.full_gp_state_dim :, self.full_gp_state_dim :] = (
self.z_trans_matrix
)

return dist.GaussianHMM(
self._get_init_dist(),
Expand Down
1 change: 1 addition & 0 deletions pyro/contrib/tracking/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class EKFDistribution(TorchDistribution):
:param dt: time step
:type dt: torch.Tensor
"""

arg_constraints = {
"measurement_cov": constraints.positive_definite,
"P0": constraints.positive_definite,
Expand Down
4 changes: 3 additions & 1 deletion pyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@
LinearHMM,
)
from pyro.distributions.improper_uniform import ImproperUniform
from pyro.distributions.inverse_gamma import InverseGamma

if "InverseGamma" not in locals(): # Use PyTorch version if available.
from pyro.distributions.inverse_gamma import InverseGamma
from pyro.distributions.lkj import LKJ, LKJCorrCholesky
from pyro.distributions.log_normal_negative_binomial import LogNormalNegativeBinomial
from pyro.distributions.logistic import Logistic, SkewLogistic
Expand Down
4 changes: 1 addition & 3 deletions pyro/distributions/asymmetriclaplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,7 @@ def variance(self):
total = left + right
p = left / total
q = right / total
return (
p * left**2 + q * right**2 + p * q * total**2 + self.soft_scale**2
)
return p * left**2 + q * right**2 + p * q * total**2 + self.soft_scale**2


def _logerfc(x):
Expand Down
16 changes: 10 additions & 6 deletions pyro/distributions/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,11 @@ class ConditionalComposeTransformModule(

def __init__(self, transforms, cache_size: int = 0):
self.transforms = [
ConstantConditionalTransform(t)
if not isinstance(t, ConditionalTransform)
else t
(
ConstantConditionalTransform(t)
if not isinstance(t, ConditionalTransform)
else t
)
for t in transforms
]
super().__init__()
Expand Down Expand Up @@ -131,9 +133,11 @@ def __init__(self, base_dist, transforms):
else ConstantConditionalDistribution(base_dist)
)
self.transforms = [
t
if isinstance(t, ConditionalTransform)
else ConstantConditionalTransform(t)
(
t
if isinstance(t, ConditionalTransform)
else ConstantConditionalTransform(t)
)
for t in transforms
]

Expand Down
2 changes: 2 additions & 0 deletions pyro/distributions/conjugate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class BetaBinomial(TorchDistribution):
:param total_count: Number of Bernoulli trials.
:type total_count: float or torch.Tensor
"""

arg_constraints = {
"concentration1": constraints.positive,
"concentration0": constraints.positive,
Expand Down Expand Up @@ -150,6 +151,7 @@ class DirichletMultinomial(TorchDistribution):
:param bool is_sparse: Whether to assume value is mostly zero when computing
:meth:`log_prob`, which can speed up computation when data is sparse.
"""

arg_constraints = {
"concentration": constraints.independent(constraints.positive, 1),
"total_count": constraints.nonnegative_integer,
Expand Down
14 changes: 8 additions & 6 deletions pyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,14 @@ def check(self, value):
{}
""".format(
_name,
"alias of :class:`torch.distributions.constraints.{}`".format(_name)
if globals()[_name].__module__.startswith("torch")
else ".. autoclass:: {}".format(
_name
if type(globals()[_name]) is type
else type(globals()[_name]).__name__
(
"alias of :class:`torch.distributions.constraints.{}`".format(_name)
if globals()[_name].__module__.startswith("torch")
else ".. autoclass:: {}".format(
_name
if type(globals()[_name]) is type
else type(globals()[_name]).__name__
)
),
)
for _name in sorted(__all__)
Expand Down
1 change: 1 addition & 0 deletions pyro/distributions/grouped_normal_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class GroupedNormalNormal(TorchDistribution):
:param torch.LongTensor group_idx: Tensor of indices of shape `(num_data,)` linking each observation to one
of the `num_groups` groups that are specified in `prior_loc` and `prior_scale`.
"""

arg_constraints = {
"prior_loc": constraints.real,
"prior_scale": constraints.positive,
Expand Down
Loading

0 comments on commit df34383

Please sign in to comment.