From 7ea256cb769c5cf12129a4aed77724fa4f3b8136 Mon Sep 17 00:00:00 2001 From: Martin Jankowiak Date: Thu, 24 Aug 2023 15:07:26 +0000 Subject: [PATCH 1/2] use generator arg --- pyro/contrib/cevae/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyro/contrib/cevae/__init__.py b/pyro/contrib/cevae/__init__.py index ee5483e582..ed1111e2ba 100644 --- a/pyro/contrib/cevae/__init__.py +++ b/pyro/contrib/cevae/__init__.py @@ -574,7 +574,8 @@ def fit( self.whiten = PreWhitener(x) dataset = TensorDataset(x, t, y) - dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, + generator=torch.Generator(device=x.device)) logger.info("Training with {} minibatches per epoch".format(len(dataloader))) num_steps = num_epochs * len(dataloader) optim = ClippedAdam( From 97c1792ee78612ea7480b1388a89e7aeef685538 Mon Sep 17 00:00:00 2001 From: Martin Jankowiak Date: Thu, 24 Aug 2023 15:11:13 +0000 Subject: [PATCH 2/2] black --- pyro/contrib/cevae/__init__.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pyro/contrib/cevae/__init__.py b/pyro/contrib/cevae/__init__.py index ed1111e2ba..842a4e2a53 100644 --- a/pyro/contrib/cevae/__init__.py +++ b/pyro/contrib/cevae/__init__.py @@ -574,8 +574,12 @@ def fit( self.whiten = PreWhitener(x) dataset = TensorDataset(x, t, y) - dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, - generator=torch.Generator(device=x.device)) + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + generator=torch.Generator(device=x.device), + ) logger.info("Training with {} minibatches per epoch".format(len(dataloader))) num_steps = num_epochs * len(dataloader) optim = ClippedAdam(