diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 23049358395b9..68197e32b887b 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -79,6 +79,7 @@ def __init__(self, b2: float = 0.999, batch_size: int = 64, **kwargs): super().__init__() + self.latent_dim = latent_dim self.lr = lr self.b1 = b1 @@ -90,9 +91,7 @@ def __init__(self, self.generator = Generator(latent_dim=self.latent_dim, img_shape=mnist_shape) self.discriminator = Discriminator(img_shape=mnist_shape) - # cache for generated images - self.generated_imgs = None - self.last_imgs = None + self.validation_z = torch.randn(8, self.latent_dim) def forward(self, z): return self.generator(z) @@ -102,21 +101,21 @@ def adversarial_loss(self, y_hat, y): def training_step(self, batch, batch_idx, optimizer_idx): imgs, _ = batch - self.last_imgs = imgs + + # sample noise + z = torch.randn(imgs.shape[0], self.latent_dim) + z = z.type_as(imgs) # train generator if optimizer_idx == 0: - # sample noise - z = torch.randn(imgs.shape[0], self.latent_dim) - z = z.type_as(imgs) # generate images self.generated_imgs = self(z) # log sampled images - # sample_imgs = self.generated_imgs[:6] - # grid = torchvision.utils.make_grid(sample_imgs) - # self.logger.experiment.add_image('generated_images', grid, 0) + sample_imgs = self.generated_imgs[:6] + grid = torchvision.utils.make_grid(sample_imgs) + self.logger.experiment.add_image('generated_images', grid, 0) # ground truth result (ie: all fake) # put on GPU because we created this tensor inside training_loop @@ -124,7 +123,7 @@ def training_step(self, batch, batch_idx, optimizer_idx): valid = valid.type_as(imgs) # adversarial loss is binary cross-entropy - g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid) + g_loss = self.adversarial_loss(self.discriminator(self(z)), valid) 
tqdm_dict = {'g_loss': g_loss} output = OrderedDict({ 'loss': g_loss, @@ -145,10 +144,10 @@ def training_step(self, batch, batch_idx, optimizer_idx): # how well can it label as fake? fake = torch.zeros(imgs.size(0), 1) - fake = fake.type_as(fake) + fake = fake.type_as(imgs) fake_loss = self.adversarial_loss( - self.discriminator(self.generated_imgs.detach()), fake) + self.discriminator(self(z).detach()), fake) # discriminator loss is the average of these d_loss = (real_loss + fake_loss) / 2 @@ -176,8 +175,7 @@ def train_dataloader(self): return DataLoader(dataset, batch_size=self.batch_size) def on_epoch_end(self): - z = torch.randn(8, self.latent_dim) - z = z.type_as(self.last_imgs) + z = self.validation_z.type_as(self.generator.model[0].weight) # log sampled images sample_imgs = self(z) @@ -185,15 +183,17 @@ def on_epoch_end(self): self.logger.experiment.add_image('generated_images', grid, self.current_epoch) -def main(hparams): +def main(args): # ------------------------ # 1 INIT LIGHTNING MODEL # ------------------------ - model = GAN(hparams) + model = GAN(**vars(args)) # ------------------------ # 2 INIT TRAINER # ------------------------ + # If you use distributed training, PyTorch recommends using DistributedDataParallel. + # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel trainer = Trainer() # ------------------------