Skip to content

Commit

Permalink
rerun newest black on repo (#3178)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjankowiak authored Feb 3, 2023
1 parent 8b6d331 commit 685c7ad
Show file tree
Hide file tree
Showing 71 changed files with 5 additions and 139 deletions.
4 changes: 0 additions & 4 deletions examples/air/air.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def __init__(
likelihood_sd=0.3,
use_cuda=False,
):

super().__init__()

self.num_steps = num_steps
Expand Down Expand Up @@ -127,7 +126,6 @@ def __init__(
self.cuda()

def prior(self, n, **kwargs):

state = ModelState(
x=torch.zeros(n, self.x_size, self.x_size, **self.options),
z_pres=torch.ones(n, self.z_pres_size, **self.options),
Expand All @@ -145,7 +143,6 @@ def prior(self, n, **kwargs):
return (z_where, z_pres), state.x

def prior_step(self, t, n, prev, z_pres_prior_p=default_z_pres_prior_p):

# Sample presence indicators.
z_pres = pyro.sample(
"z_pres_{}".format(t),
Expand Down Expand Up @@ -263,7 +260,6 @@ def guide(self, data, batch_size, **kwargs):
return z_where, z_pres

def guide_step(self, t, n, prev, inputs):

rnn_input = torch.cat(
(inputs["embed"], prev.z_where, prev.z_what, prev.z_pres), 1
)
Expand Down
2 changes: 0 additions & 2 deletions examples/air/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ def load_data():


def main(**kwargs):

args = argparse.Namespace(**kwargs)

if "save" in args:
Expand Down Expand Up @@ -229,7 +228,6 @@ def per_param_optim_args(param_name):
examples_to_viz = X[5:10]

for i in range(1, args.num_steps + 1):

loss = svi.step(
X, batch_size=args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i)
)
Expand Down
2 changes: 1 addition & 1 deletion examples/air/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
layers = []
in_sizes = [in_size] + out_sizes[0:-1]
sizes = list(zip(in_sizes, out_sizes))
for (i, o) in sizes[0:-1]:
for i, o in sizes[0:-1]:
layers.append(nn.Linear(i, o))
layers.append(non_linear_layer())
layers.append(nn.Linear(sizes[-1][0], sizes[-1][1]))
Expand Down
1 change: 1 addition & 0 deletions examples/contrib/funsor/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
# and randomly subsample data to size batch_size. To add jit support we
# silence some warnings and try to avoid dynamic program structure.


# Note that this is the "HMM" model in reference [1] (with the difference that
# in [1] the probabilities probs_x and probs_y are not MAP-regularized with
# Dirichlet and Beta distributions for any of the models)
Expand Down
1 change: 0 additions & 1 deletion examples/contrib/mue/FactorMuE.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def generate_data(small_test, include_stop, device):


def main(args):

# Load dataset.
if args.cpu_data or not args.cuda:
device = torch.device("cpu")
Expand Down
1 change: 0 additions & 1 deletion examples/contrib/mue/ProfileHMM.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def generate_data(small_test, include_stop, device):


def main(args):

pyro.set_rng_seed(args.rng_seed)

# Load dataset.
Expand Down
1 change: 0 additions & 1 deletion examples/contrib/oed/ab_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def true_ape(ns):


def main(num_vi_steps, num_bo_steps, seed):

pyro.set_rng_seed(seed)
pyro.clear_param_store()

Expand Down
1 change: 0 additions & 1 deletion examples/cvae/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def train(
early_stop_patience,
model_path,
):

# Train baseline
baseline_net = BaselineNet(500, 500)
baseline_net.to(device)
Expand Down
2 changes: 0 additions & 2 deletions examples/cvae/cvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def model(self, xs, ys=None):
pyro.module("generation_net", self)
batch_size = xs.shape[0]
with pyro.plate("data"):

# Prior network uses the baseline predictions as initial guess.
# This is the generative process with recurrent connection
with torch.no_grad():
Expand Down Expand Up @@ -130,7 +129,6 @@ def train(
model_path,
pre_trained_baseline_net,
):

# clear param store
pyro.clear_param_store()

Expand Down
2 changes: 0 additions & 2 deletions examples/cvae/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def visualize(
num_samples,
image_path=None,
):

# Load sample random data
datasets, _, dataset_sizes = get_data(
num_quadrant_inputs=num_quadrant_inputs, batch_size=num_images
Expand Down Expand Up @@ -121,7 +120,6 @@ def generate_table(
num_particles,
col_name,
):

# Load sample random data
datasets, dataloaders, dataset_sizes = get_data(
num_quadrant_inputs=num_quadrant_inputs, batch_size=32
Expand Down
2 changes: 0 additions & 2 deletions examples/dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ def model(
mini_batch_seq_lengths,
annealing_factor=1.0,
):

# this is the number of time steps we need to process in the mini-batch
T_max = mini_batch.size(1)

Expand Down Expand Up @@ -269,7 +268,6 @@ def guide(
mini_batch_seq_lengths,
annealing_factor=1.0,
):

# this is the number of time steps we need to process in the mini-batch
T_max = mini_batch.size(1)
# register all PyTorch (sub)modules with pyro
Expand Down
4 changes: 0 additions & 4 deletions examples/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def jit_prob(equation, *operands, **kwargs):
"""
key = "prob", equation, kwargs["plates"]
if key not in _CACHE:

# This simply wraps einsum for jit compilation.
def _einsum(*operands):
return einsum(equation, *operands, **kwargs)
Expand All @@ -61,7 +60,6 @@ def jit_logprob(equation, *operands, **kwargs):
"""
key = "logprob", equation, kwargs["plates"]
if key not in _CACHE:

# This simply wraps einsum for jit compilation.
def _einsum(*operands):
return einsum(
Expand All @@ -81,7 +79,6 @@ def jit_gradient(equation, *operands, **kwargs):
"""
key = "gradient", equation, kwargs["plates"]
if key not in _CACHE:

# This wraps einsum for jit compilation, but we will call backward on the result.
def _einsum(*operands):
return einsum(
Expand Down Expand Up @@ -114,7 +111,6 @@ def _jit_adjoint(equation, *operands, **kwargs):
backend = kwargs.pop("backend", "pyro.ops.einsum.torch_sample")
key = backend, equation, tuple(x.shape for x in operands), kwargs["plates"]
if key not in _CACHE:

# This wraps a complete adjoint algorithm call.
@ignore_jit_warnings()
def _forward_backward(*operands):
Expand Down
1 change: 1 addition & 0 deletions examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
# and randomly subsample data to size batch_size. To add jit support we
# silence some warnings and try to avoid dynamic program structure.


# Note that this is the "HMM" model in reference [1] (with the difference that
# in [1] the probabilities probs_x and probs_y are not MAP-regularized with
# Dirichlet and Beta distributions for any of the models)
Expand Down
3 changes: 0 additions & 3 deletions examples/mixed_hmm/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def _size(tensor):


def run_expt(args):

data_dir = args["folder"]
dataset = "seal" # args["dataset"]
seed = args["seed"]
Expand Down Expand Up @@ -79,7 +78,6 @@ def run_expt(args):
schedule_step_loss = True

for t in range(timesteps):

optimizer.zero_grad()
loss = loss_fn(model, guide)
loss.backward()
Expand Down Expand Up @@ -166,7 +164,6 @@ def closure():


if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("-g", "--group", default="none", type=str)
parser.add_argument("-i", "--individual", default="none", type=str)
Expand Down
4 changes: 0 additions & 4 deletions examples/mixed_hmm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def guide_generic(config):

N_c = config["sizes"]["group"]
with pyro.plate("group", N_c, dim=-1):

if config["group"]["random"] == "continuous":
pyro.sample(
"eps_g",
Expand All @@ -59,7 +58,6 @@ def guide_generic(config):
with pyro.plate("individual", N_s, dim=-2), poutine.mask(
mask=config["individual"]["mask"]
):

# individual-level random effects
if config["individual"]["random"] == "continuous":
pyro.sample(
Expand Down Expand Up @@ -158,7 +156,6 @@ def model_generic(config):

N_c = config["sizes"]["group"]
with pyro.plate("group", N_c, dim=-1):

# group-level random effects
if config["group"]["random"] == "discrete":
# group-level discrete effect
Expand All @@ -179,7 +176,6 @@ def model_generic(config):
with pyro.plate("individual", N_s, dim=-2), poutine.mask(
mask=config["individual"]["mask"]
):

# individual-level random effects
if config["individual"]["random"] == "discrete":
# individual-level discrete effect
Expand Down
1 change: 0 additions & 1 deletion examples/mixed_hmm/seal_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def download_seal_data(filename):


def prepare_seal(filename, random_effects):

if not os.path.exists(filename):
download_seal_data(filename)

Expand Down
1 change: 0 additions & 1 deletion examples/rsa/hyperbole.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def test_truth():


def main(args):

# test_truth()

pragmatic_marginal = pragmatic_listener(args.price)
Expand Down
1 change: 0 additions & 1 deletion examples/rsa/schelling_false.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def bob(preference, depth):


def main(args):

# Here Alice and Bob slightly prefer one location over the other a priori
shared_preference = torch.tensor([args.preference])

Expand Down
1 change: 0 additions & 1 deletion examples/rsa/search_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def sample_escape(tr, site):
)

def _fn(*args, **kwargs):

for i in range(int(1e6)):
assert (
not queue.empty()
Expand Down
1 change: 0 additions & 1 deletion examples/rsa/semantic_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@ def literal_listener_raw(utterance, qud):


def main(args):

mll = Marginal(literal_listener_raw, num_samples=args.num_samples)

def is_any_qud(world):
Expand Down
8 changes: 1 addition & 7 deletions examples/vae/ss_vae_M2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def __init__(
use_cuda=False,
aux_loss_multiplier=None,
):

super().__init__()

# initialize the class with all arguments provided to the constructor
Expand All @@ -68,7 +67,6 @@ def __init__(
self.setup_networks()

def setup_networks(self):

z_dim = self.z_dim
hidden_sizes = self.hidden_layers

Expand Down Expand Up @@ -127,7 +125,6 @@ def model(self, xs, ys=None):
batch_size = xs.size(0)
options = dict(dtype=xs.dtype, device=xs.device)
with pyro.plate("data"):

# sample the handwriting style from the constant prior distribution
prior_loc = torch.zeros(batch_size, self.z_dim, **options)
prior_scale = torch.ones(batch_size, self.z_dim, **options)
Expand Down Expand Up @@ -167,7 +164,6 @@ def guide(self, xs, ys=None):
"""
# inform Pyro that the variables in the batch of xs, ys are conditionally independent
with pyro.plate("data"):

# if the class label (the digit) is not supervised, sample
# (and score) the digit with the variational distribution
# q(y|x) = categorical(alpha(x))
Expand Down Expand Up @@ -245,7 +241,6 @@ def run_inference_for_epoch(data_loaders, losses, periodic_interval_batches):
# count the number of supervised batches seen in this epoch
ctr_sup = 0
for i in range(batches_per_epoch):

# whether this batch is supervised or not
is_supervised = (i % periodic_interval_batches == 1) and ctr_sup < sup_batches

Expand Down Expand Up @@ -277,7 +272,7 @@ def get_accuracy(data_loader, classifier_fn, batch_size):
predictions, actuals = [], []

# use the appropriate data loader
for (xs, ys) in data_loader:
for xs, ys in data_loader:
# use classification function to compute all predictions for each batch
predictions.append(classifier_fn(xs))
actuals.append(ys)
Expand Down Expand Up @@ -370,7 +365,6 @@ def main(args):

# run inference for a certain number of epochs
for i in range(0, args.num_epochs):

# get the losses for an epoch
epoch_losses_sup, epoch_losses_unsup = run_inference_for_epoch(
data_loaders, losses, periodic_interval_batches
Expand Down
2 changes: 0 additions & 2 deletions examples/vae/utils/custom_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,11 @@ def __init__(
else output_activation
)
else:

# we're going to have a bunch of separate layers we can spit out (a tuple of outputs)
out_layers = []

# multiple outputs? handle separately
for out_ix, out_size in enumerate(output_size):

# for a single output object, we create a linear layer and some weights
split_layer = []

Expand Down
1 change: 0 additions & 1 deletion examples/vae/utils/mnist_cached.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def target_transform(y):
], "invalid train/test option values"

if mode in ["sup", "unsup", "valid"]:

# transform the training data if transformations are provided
if transform is not None:
self.data = transform(self.data.float())
Expand Down
2 changes: 1 addition & 1 deletion pyro/contrib/funsor/handlers/enum_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,10 @@ def queue(
:param num_samples: optional number of extended traces for extend_fn to return
:returns: stochastic function decorated with poutine logic
"""

# TODO rewrite this to use purpose-built trace/replay handlers
def wrapper(wrapped):
def _fn(*args, **kwargs):

for i in range(max_tries):
assert (
not queue.empty()
Expand Down
2 changes: 0 additions & 2 deletions pyro/contrib/funsor/handlers/named_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __exit__(self, *args, **kwargs):

@staticmethod # only depends on the global _DIM_STACK state, not self
def _pyro_to_data(msg):

(funsor_value,) = msg["args"]
name_to_dim = msg["kwargs"].setdefault("name_to_dim", OrderedDict())
dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL)
Expand All @@ -82,7 +81,6 @@ def _pyro_to_data(msg):

@staticmethod # only depends on the global _DIM_STACK state, not self
def _pyro_to_funsor(msg):

if len(msg["args"]) == 2:
raw_value, output = msg["args"]
else:
Expand Down
1 change: 0 additions & 1 deletion pyro/contrib/funsor/handlers/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ def _genvalue(self, key, value_request):
)

def allocate(self, key_to_value_request):

# step 1: split into fresh and non-fresh
key_to_value = OrderedDict()
for key, value_request in tuple(key_to_value_request.items()):
Expand Down
Loading

0 comments on commit 685c7ad

Please sign in to comment.