Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get CI working with new black, torch, pytest #3318

Merged
merged 9 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/capture_recapture/cjs.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,11 @@ def expose_fn(msg):
elbo = TraceTMC_ELBO(max_plate_nesting=1)
tmc_model = poutine.infer_config(
model,
lambda msg: {"num_samples": args.tmc_num_samples, "expand": False}
if msg["infer"].get("enumerate", None) == "parallel"
else {},
lambda msg: (
{"num_samples": args.tmc_num_samples, "expand": False}
if msg["infer"].get("enumerate", None) == "parallel"
else {}
),
) # noqa: E501
svi = SVI(tmc_model, guide, optim, elbo)
else:
Expand Down
8 changes: 5 additions & 3 deletions examples/contrib/funsor/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,9 +737,11 @@ def main(args):
elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2)
tmc_model = handlers.infer_config(
model,
lambda msg: {"num_samples": args.tmc_num_samples, "expand": False}
if msg["infer"].get("enumerate", None) == "parallel"
else {},
lambda msg: (
{"num_samples": args.tmc_num_samples, "expand": False}
if msg["infer"].get("enumerate", None) == "parallel"
else {}
),
) # noqa: E501
svi = infer.SVI(tmc_model, guide, optimizer, elbo)
else:
Expand Down
8 changes: 5 additions & 3 deletions examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,9 +677,11 @@ def main(args):
elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2)
tmc_model = poutine.infer_config(
model,
lambda msg: {"num_samples": args.tmc_num_samples, "expand": False}
if msg["infer"].get("enumerate", None) == "parallel"
else {},
lambda msg: (
{"num_samples": args.tmc_num_samples, "expand": False}
if msg["infer"].get("enumerate", None) == "parallel"
else {}
),
) # noqa: E501
svi = SVI(tmc_model, guide, optim, elbo)
else:
Expand Down
8 changes: 2 additions & 6 deletions examples/mixed_hmm/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,12 @@ def closure():
re_str = "g" + (
"n"
if args["group"] is None
else "d"
if args["group"] == "discrete"
else "c"
else "d" if args["group"] == "discrete" else "c"
)
re_str += "i" + (
"n"
if args["individual"] is None
else "d"
if args["individual"] == "discrete"
else "c"
else "d" if args["individual"] == "discrete" else "c"
)
results_filename = "expt_{}_{}_{}.json".format(
dataset, re_str, str(uuid.uuid4().hex)[0:5]
Expand Down
3 changes: 1 addition & 2 deletions pyro/contrib/autoname/autoname.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,7 @@ def _pyro_genname(msg):


@_make_handler(AutonameMessenger, __name__)
def autoname(fn=None, name=None):
...
def autoname(fn=None, name=None): ...


@singledispatch
Expand Down
1 change: 1 addition & 0 deletions pyro/contrib/bnn/hidden_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class HiddenLayer(TorchDistribution):
"Variational dropout and the local reparameterization trick."
Advances in Neural Information Processing Systems. 2015.
"""

has_rsample = True

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion pyro/contrib/cevae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def __init__(self, data):
super().__init__()
with torch.no_grad():
loc = data.mean(0)
scale = data.std(0)
scale = data.std(0, unbiased=False)
scale[~(scale > 0)] = 1.0
self.register_buffer("loc", loc)
self.register_buffer("inv_scale", scale.reciprocal())
Expand Down
8 changes: 5 additions & 3 deletions pyro/contrib/epidemiology/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,9 +680,11 @@ def transition(self, params, state, t):
coal_rate = R * (1.0 + 1.0 / k) / (tau_i * state["I"] + 1e-8)
pyro.factor(
"coalescent_{}".format(t),
self.coal_likelihood(coal_rate, t)
if t_is_observed
else torch.tensor(0.0),
(
self.coal_likelihood(coal_rate, t)
if t_is_observed
else torch.tensor(0.0)
),
)

# Update compartements with flows.
Expand Down
21 changes: 7 additions & 14 deletions pyro/contrib/funsor/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,15 @@


@_make_handler(EnumMessenger, __name__)
def enum(fn=None, first_available_dim=None):
...
def enum(fn=None, first_available_dim=None): ...


@_make_handler(MarkovMessenger, __name__)
def markov(fn=None, history=1, keep=False):
...
def markov(fn=None, history=1, keep=False): ...


@_make_handler(NamedMessenger, __name__)
def named(fn=None, first_available_dim=None):
...
def named(fn=None, first_available_dim=None): ...


@_make_handler(PlateMessenger, __name__)
Expand All @@ -47,20 +44,16 @@ def plate(
dim=None,
use_cuda=None,
device=None,
):
...
): ...


@_make_handler(ReplayMessenger, __name__)
def replay(fn=None, trace=None, params=None):
...
def replay(fn=None, trace=None, params=None): ...


@_make_handler(TraceMessenger, __name__)
def trace(fn=None, graph_type=None, param_only=None, pack_online=True):
...
def trace(fn=None, graph_type=None, param_only=None, pack_online=True): ...


@_make_handler(VectorizedMarkovMessenger, __name__)
def vectorized_markov(fn=None, name=None, size=None, dim=None, history=1):
...
def vectorized_markov(fn=None, name=None, size=None, dim=None, history=1): ...
6 changes: 3 additions & 3 deletions pyro/contrib/funsor/handlers/replay_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ class ReplayMessenger(OrigReplayMessenger):

def _pyro_sample(self, msg):
name = msg["name"]
msg[
"replay_active"
] = True # indicate replaying so importance weights can be scaled
msg["replay_active"] = (
True # indicate replaying so importance weights can be scaled
)
if self.trace is None:
return

Expand Down
8 changes: 5 additions & 3 deletions pyro/contrib/funsor/infer/trace_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
@copy_docs_from(_OrigTrace_ELBO)
class Trace_ELBO(ELBO):
def differentiable_loss(self, model, guide, *args, **kwargs):
with enum(), plate(
size=self.num_particles
) if self.num_particles > 1 else contextlib.ExitStack():
with enum(), (
plate(size=self.num_particles)
if self.num_particles > 1
else contextlib.ExitStack()
):
guide_tr = trace(config_enumerate(default="flat")(guide)).get_trace(
*args, **kwargs
)
Expand Down
28 changes: 16 additions & 12 deletions pyro/contrib/funsor/infer/traceenum_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,14 @@ def terms_from_trace(tr):
class TraceMarkovEnum_ELBO(ELBO):
def differentiable_loss(self, model, guide, *args, **kwargs):
# get batched, enumerated, to_funsor-ed traces from the guide and model
with plate(
size=self.num_particles
) if self.num_particles > 1 else contextlib.ExitStack(), enum(
first_available_dim=(-self.max_plate_nesting - 1)
if self.max_plate_nesting
else None
with (
plate(size=self.num_particles)
if self.num_particles > 1
else contextlib.ExitStack()
), enum(
first_available_dim=(
(-self.max_plate_nesting - 1) if self.max_plate_nesting else None
)
):
guide_tr = trace(guide).get_trace(*args, **kwargs)
model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)
Expand Down Expand Up @@ -170,12 +172,14 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
class TraceEnum_ELBO(ELBO):
def differentiable_loss(self, model, guide, *args, **kwargs):
# get batched, enumerated, to_funsor-ed traces from the guide and model
with plate(
size=self.num_particles
) if self.num_particles > 1 else contextlib.ExitStack(), enum(
first_available_dim=(-self.max_plate_nesting - 1)
if self.max_plate_nesting
else None
with (
plate(size=self.num_particles)
if self.num_particles > 1
else contextlib.ExitStack()
), enum(
first_available_dim=(
(-self.max_plate_nesting - 1) if self.max_plate_nesting else None
)
):
guide_tr = trace(guide).get_trace(*args, **kwargs)
model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)
Expand Down
14 changes: 8 additions & 6 deletions pyro/contrib/funsor/infer/tracetmc_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
@copy_docs_from(_OrigTraceTMC_ELBO)
class TraceTMC_ELBO(ELBO):
def differentiable_loss(self, model, guide, *args, **kwargs):
with plate(
size=self.num_particles
) if self.num_particles > 1 else contextlib.ExitStack(), enum(
first_available_dim=(-self.max_plate_nesting - 1)
if self.max_plate_nesting
else None
with (
plate(size=self.num_particles)
if self.num_particles > 1
else contextlib.ExitStack()
), enum(
first_available_dim=(
(-self.max_plate_nesting - 1) if self.max_plate_nesting else None
)
):
guide_tr = trace(guide).get_trace(*args, **kwargs)
model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)
Expand Down
30 changes: 15 additions & 15 deletions pyro/contrib/timeseries/lgssmgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def _get_init_dist(self):
covar[: self.full_gp_state_dim, : self.full_gp_state_dim] = block_diag_embed(
self.kernel.stationary_covariance()
)
covar[
self.full_gp_state_dim :, self.full_gp_state_dim :
] = self.init_noise_scale_sq.diag_embed()
covar[self.full_gp_state_dim :, self.full_gp_state_dim :] = (
self.init_noise_scale_sq.diag_embed()
)
return MultivariateNormal(loc, covar)

def _get_obs_dist(self):
Expand All @@ -134,23 +134,23 @@ def get_dist(self, duration=None):
trans_covar = self.z_trans_matrix.new_zeros(
self.full_state_dim, self.full_state_dim
)
trans_covar[
: self.full_gp_state_dim, : self.full_gp_state_dim
] = block_diag_embed(gp_process_covar)
trans_covar[
self.full_gp_state_dim :, self.full_gp_state_dim :
] = self.trans_noise_scale_sq.diag_embed()
trans_covar[: self.full_gp_state_dim, : self.full_gp_state_dim] = (
block_diag_embed(gp_process_covar)
)
trans_covar[self.full_gp_state_dim :, self.full_gp_state_dim :] = (
self.trans_noise_scale_sq.diag_embed()
)
trans_dist = MultivariateNormal(
trans_covar.new_zeros(self.full_state_dim), trans_covar
)

full_trans_mat = trans_covar.new_zeros(self.full_state_dim, self.full_state_dim)
full_trans_mat[
: self.full_gp_state_dim, : self.full_gp_state_dim
] = block_diag_embed(gp_trans_matrix)
full_trans_mat[
self.full_gp_state_dim :, self.full_gp_state_dim :
] = self.z_trans_matrix
full_trans_mat[: self.full_gp_state_dim, : self.full_gp_state_dim] = (
block_diag_embed(gp_trans_matrix)
)
full_trans_mat[self.full_gp_state_dim :, self.full_gp_state_dim :] = (
self.z_trans_matrix
)

return dist.GaussianHMM(
self._get_init_dist(),
Expand Down
1 change: 1 addition & 0 deletions pyro/contrib/tracking/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class EKFDistribution(TorchDistribution):
:param dt: time step
:type dt: torch.Tensor
"""

arg_constraints = {
"measurement_cov": constraints.positive_definite,
"P0": constraints.positive_definite,
Expand Down
4 changes: 3 additions & 1 deletion pyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@
LinearHMM,
)
from pyro.distributions.improper_uniform import ImproperUniform
from pyro.distributions.inverse_gamma import InverseGamma

if "InverseGamma" not in locals(): # Use PyTorch version if available.
from pyro.distributions.inverse_gamma import InverseGamma
from pyro.distributions.lkj import LKJ, LKJCorrCholesky
from pyro.distributions.log_normal_negative_binomial import LogNormalNegativeBinomial
from pyro.distributions.logistic import Logistic, SkewLogistic
Expand Down
4 changes: 1 addition & 3 deletions pyro/distributions/asymmetriclaplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,7 @@ def variance(self):
total = left + right
p = left / total
q = right / total
return (
p * left**2 + q * right**2 + p * q * total**2 + self.soft_scale**2
)
return p * left**2 + q * right**2 + p * q * total**2 + self.soft_scale**2


def _logerfc(x):
Expand Down
16 changes: 10 additions & 6 deletions pyro/distributions/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,11 @@ class ConditionalComposeTransformModule(

def __init__(self, transforms, cache_size: int = 0):
self.transforms = [
ConstantConditionalTransform(t)
if not isinstance(t, ConditionalTransform)
else t
(
ConstantConditionalTransform(t)
if not isinstance(t, ConditionalTransform)
else t
)
for t in transforms
]
super().__init__()
Expand Down Expand Up @@ -131,9 +133,11 @@ def __init__(self, base_dist, transforms):
else ConstantConditionalDistribution(base_dist)
)
self.transforms = [
t
if isinstance(t, ConditionalTransform)
else ConstantConditionalTransform(t)
(
t
if isinstance(t, ConditionalTransform)
else ConstantConditionalTransform(t)
)
for t in transforms
]

Expand Down
2 changes: 2 additions & 0 deletions pyro/distributions/conjugate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class BetaBinomial(TorchDistribution):
:param total_count: Number of Bernoulli trials.
:type total_count: float or torch.Tensor
"""

arg_constraints = {
"concentration1": constraints.positive,
"concentration0": constraints.positive,
Expand Down Expand Up @@ -150,6 +151,7 @@ class DirichletMultinomial(TorchDistribution):
:param bool is_sparse: Whether to assume value is mostly zero when computing
:meth:`log_prob`, which can speed up computation when data is sparse.
"""

arg_constraints = {
"concentration": constraints.independent(constraints.positive, 1),
"total_count": constraints.nonnegative_integer,
Expand Down
14 changes: 8 additions & 6 deletions pyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,14 @@ def check(self, value):
{}
""".format(
_name,
"alias of :class:`torch.distributions.constraints.{}`".format(_name)
if globals()[_name].__module__.startswith("torch")
else ".. autoclass:: {}".format(
_name
if type(globals()[_name]) is type
else type(globals()[_name]).__name__
(
"alias of :class:`torch.distributions.constraints.{}`".format(_name)
if globals()[_name].__module__.startswith("torch")
else ".. autoclass:: {}".format(
_name
if type(globals()[_name]) is type
else type(globals()[_name]).__name__
)
),
)
for _name in sorted(__all__)
Expand Down
1 change: 1 addition & 0 deletions pyro/distributions/grouped_normal_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class GroupedNormalNormal(TorchDistribution):
:param torch.LongTensor group_idx: Tensor of indices of shape `(num_data,)` linking each observation to one
of the `num_groups` groups that are specified in `prior_loc` and `prior_scale`.
"""

arg_constraints = {
"prior_loc": constraints.real,
"prior_scale": constraints.positive,
Expand Down
Loading
Loading