Bugfix/fix gan example (#2019)
* 🐛 fixed the fake-sample type assignment and the hparams arg

* fixed GAN example to work with dp, ddp, ddp_cpu (a sketch of the device-handling pattern follows below)

* Update generative_adversarial_net.py

Co-authored-by: William Falcon <waf2107@columbia.edu>
Artem Lobantsev and williamFalcon authored May 31, 2020
1 parent 0e37e8c commit 55fdfe3
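
Below is a minimal sketch of the device-handling pattern these fixes rely on; the helper name and shapes are illustrative, not the example's code. Tensors created inside training_step are moved with .type_as(batch) so the same step runs unchanged on CPU, a single GPU, dp, ddp, or ddp_cpu, and no batch-dependent tensors need to be cached on self.

import torch

def make_step_tensors(imgs: torch.Tensor, latent_dim: int = 100):
    # noise is sampled fresh each step and placed on the batch's device/dtype
    z = torch.randn(imgs.shape[0], latent_dim).type_as(imgs)
    # adversarial-loss targets also follow the batch
    valid = torch.ones(imgs.size(0), 1).type_as(imgs)
    # the bug fixed here: this previously read fake.type_as(fake), a no-op
    # that left the zeros tensor on the CPU
    fake = torch.zeros(imgs.size(0), 1).type_as(imgs)
    return z, valid, fake

With the noise and label tensors following imgs, the removed self.last_imgs / cached-image state is no longer needed for device placement.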
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions pl_examples/domain_templates/generative_adversarial_net.py
@@ -79,6 +79,7 @@ def __init__(self,
b2: float = 0.999,
batch_size: int = 64, **kwargs):
super().__init__()

self.latent_dim = latent_dim
self.lr = lr
self.b1 = b1
@@ -90,9 +91,7 @@ def __init__(self,
self.generator = Generator(latent_dim=self.latent_dim, img_shape=mnist_shape)
self.discriminator = Discriminator(img_shape=mnist_shape)

# cache for generated images
self.generated_imgs = None
self.last_imgs = None
self.validation_z = torch.randn(8, self.latent_dim)

def forward(self, z):
return self.generator(z)
@@ -102,29 +101,29 @@ def adversarial_loss(self, y_hat, y):

def training_step(self, batch, batch_idx, optimizer_idx):
imgs, _ = batch
self.last_imgs = imgs

# sample noise
z = torch.randn(imgs.shape[0], self.latent_dim)
z = z.type_as(imgs)

# train generator
if optimizer_idx == 0:
# sample noise
z = torch.randn(imgs.shape[0], self.latent_dim)
z = z.type_as(imgs)

# generate images
self.generated_imgs = self(z)

# log sampled images
# sample_imgs = self.generated_imgs[:6]
# grid = torchvision.utils.make_grid(sample_imgs)
# self.logger.experiment.add_image('generated_images', grid, 0)
sample_imgs = self.generated_imgs[:6]
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image('generated_images', grid, 0)

# ground truth result (ie: all fake)
# put on GPU because we created this tensor inside training_loop
valid = torch.ones(imgs.size(0), 1)
valid = valid.type_as(imgs)

# adversarial loss is binary cross-entropy
g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
tqdm_dict = {'g_loss': g_loss}
output = OrderedDict({
'loss': g_loss,
@@ -145,10 +144,10 @@ def training_step(self, batch, batch_idx, optimizer_idx):

# how well can it label as fake?
fake = torch.zeros(imgs.size(0), 1)
fake = fake.type_as(fake)
fake = fake.type_as(imgs)

fake_loss = self.adversarial_loss(
self.discriminator(self.generated_imgs.detach()), fake)
self.discriminator(self(z).detach()), fake)

# discriminator loss is the average of these
d_loss = (real_loss + fake_loss) / 2
@@ -176,24 +175,25 @@ def train_dataloader(self):
return DataLoader(dataset, batch_size=self.batch_size)

def on_epoch_end(self):
z = torch.randn(8, self.latent_dim)
z = z.type_as(self.last_imgs)
z = self.validation_z.type_as(self.generator.model[0].weight)

# log sampled images
sample_imgs = self(z)
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image('generated_images', grid, self.current_epoch)


def main(hparams):
def main(args):
# ------------------------
# 1 INIT LIGHTNING MODEL
# ------------------------
model = GAN(hparams)
model = GAN(**vars(args))

# ------------------------
# 2 INIT TRAINER
# ------------------------
    # If using distributed training, PyTorch recommends DistributedDataParallel.
# See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
trainer = Trainer()

# ------------------------
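
For context, a hedged sketch of how the reworked main() could be driven end to end. GAN(**vars(args)) expands each parsed flag into a keyword argument of the new __init__ signature; the --gpus and --distributed_backend flags and the Trainer arguments shown here are assumptions based on the Lightning API around the time of this commit (later releases expose accelerator/strategy instead) and are not part of the diff.

from argparse import ArgumentParser

from pytorch_lightning import Trainer

from pl_examples.domain_templates.generative_adversarial_net import GAN


def run(args):
    # every parsed flag becomes a keyword argument of GAN.__init__;
    # extra flags are absorbed by **kwargs
    model = GAN(**vars(args))
    # gpus / distributed_backend are assumptions for the Lightning version of this commit
    trainer = Trainer(gpus=args.gpus, distributed_backend=args.distributed_backend)
    trainer.fit(model)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--gpus', type=int, default=2)
    parser.add_argument('--distributed_backend', type=str, default='ddp')
    parser.add_argument('--latent_dim', type=int, default=100)
    parser.add_argument('--lr', type=float, default=0.0002)
    run(parser.parse_args())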
