diff --git a/pyro/contrib/cevae/__init__.py b/pyro/contrib/cevae/__init__.py index ee5483e582..842a4e2a53 100644 --- a/pyro/contrib/cevae/__init__.py +++ b/pyro/contrib/cevae/__init__.py @@ -574,7 +574,12 @@ def fit( self.whiten = PreWhitener(x) dataset = TensorDataset(x, t, y) - dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + generator=torch.Generator(device=x.device), + ) logger.info("Training with {} minibatches per epoch".format(len(dataloader))) num_steps = num_epochs * len(dataloader) optim = ClippedAdam(