Skip to content

Commit

Permalink
Add Pix2Pix model (#533)
Browse files Browse the repository at this point in the history
* initial commit

* added components for pix2pix

* pix2pix model created

* Update pl_bolts/models/gans/pix2pix/components.py

Co-authored-by: Aditya Oke <47158509+oke-aditya@users.noreply.github.com>

* yapf consistency

* removed tilde

* added training step and some reformatting

* center_crop torchvision and dropout value 0.5

* renamed variables

* update

* updated requirements file

* refactored and need to fix generator UPSAMPLE

* removed torchvision

* refactored

* updated

* yapf

Co-authored-by: Aniket Maurya <aniket.maurya@gdn-commerce.com>
Co-authored-by: Aditya Oke <47158509+oke-aditya@users.noreply.github.com>
  • Loading branch information
3 people authored Mar 4, 2021
1 parent 8f49ff9 commit 53f5370
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 3 deletions.
6 changes: 4 additions & 2 deletions pl_bolts/models/gans/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from pl_bolts.models.gans.basic.basic_gan_module import GAN # noqa: F401
from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN # noqa: F401
from pl_bolts.models.gans.basic.basic_gan_module import GAN
from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN
from pl_bolts.models.gans.pix2pix.pix2pix_module import Pix2Pix

__all__ = [
"GAN",
"DCGAN",
"Pix2Pix",
]
Empty file.
153 changes: 153 additions & 0 deletions pl_bolts/models/gans/pix2pix/components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import torch
from torch import nn


class UpSampleConv(nn.Module):

def __init__(
self,
in_channels,
out_channels,
kernel=4,
strides=2,
padding=1,
activation=True,
batchnorm=True,
dropout=False
):
super().__init__()
self.activation = activation
self.batchnorm = batchnorm
self.dropout = dropout

self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel, strides, padding)

if batchnorm:
self.bn = nn.BatchNorm2d(out_channels)

if activation:
self.act = nn.ReLU(True)

if dropout:
self.drop = nn.Dropout2d(0.5)

def forward(self, x):
x = self.deconv(x)
if self.batchnorm:
x = self.bn(x)

if self.dropout:
x = self.drop(x)
return x


class DownSampleConv(nn.Module):

def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True):
"""
Paper details:
- C64-C128-C256-C512-C512-C512-C512-C512
- All convolutions are 4×4 spatial filters applied with stride 2
- Convolutions in the encoder downsample by a factor of 2
"""
super().__init__()
self.activation = activation
self.batchnorm = batchnorm

self.conv = nn.Conv2d(in_channels, out_channels, kernel, strides, padding)

if batchnorm:
self.bn = nn.BatchNorm2d(out_channels)

if activation:
self.act = nn.LeakyReLU(0.2)

def forward(self, x):
x = self.conv(x)
if self.batchnorm:
x = self.bn(x)
if self.activation:
x = self.act(x)
return x


class Generator(nn.Module):

def __init__(self, in_channels, out_channels):
"""
Paper details:
- Encoder: C64-C128-C256-C512-C512-C512-C512-C512
- All convolutions are 4×4 spatial filters applied with stride 2
- Convolutions in the encoder downsample by a factor of 2
- Decoder: CD512-CD1024-CD1024-C1024-C1024-C512 -C256-C128
"""
super().__init__()

# encoder/donwsample convs
self.encoders = [
DownSampleConv(in_channels, 64, batchnorm=False), # bs x 64 x 128 x 128
DownSampleConv(64, 128), # bs x 128 x 64 x 64
DownSampleConv(128, 256), # bs x 256 x 32 x 32
DownSampleConv(256, 512), # bs x 512 x 16 x 16
DownSampleConv(512, 512), # bs x 512 x 8 x 8
DownSampleConv(512, 512), # bs x 512 x 4 x 4
DownSampleConv(512, 512), # bs x 512 x 2 x 2
DownSampleConv(512, 512, batchnorm=False), # bs x 512 x 1 x 1
]

# decoder/upsample convs
self.decoders = [
UpSampleConv(512, 512, dropout=True), # bs x 512 x 2 x 2
UpSampleConv(1024, 512, dropout=True), # bs x 512 x 4 x 4
UpSampleConv(1024, 512, dropout=True), # bs x 512 x 8 x 8
UpSampleConv(1024, 512), # bs x 512 x 16 x 16
UpSampleConv(1024, 256), # bs x 256 x 32 x 32
UpSampleConv(512, 128), # bs x 128 x 64 x 64
UpSampleConv(256, 64), # bs x 64 x 128 x 128
]
self.decoder_channels = [512, 512, 512, 512, 256, 128, 64]
self.final_conv = nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1)
self.tanh = nn.Tanh()

self.encoders = nn.ModuleList(self.encoders)
self.decoders = nn.ModuleList(self.decoders)

def forward(self, x):
skips_cons = []
for encoder in self.encoders:
x = encoder(x)

skips_cons.append(x)

skips_cons = list(reversed(skips_cons[:-1]))
decoders = self.decoders[:-1]

for decoder, skip in zip(decoders, skips_cons):
x = decoder(x)
# print(x.shape, skip.shape)
x = torch.cat((x, skip), axis=1)

x = self.decoders[-1](x)
# print(x.shape)
x = self.final_conv(x)
return self.tanh(x)


class PatchGAN(nn.Module):

def __init__(self, input_channels):
super().__init__()
self.d1 = DownSampleConv(input_channels, 64, batchnorm=False)
self.d2 = DownSampleConv(64, 128)
self.d3 = DownSampleConv(128, 256)
self.d4 = DownSampleConv(256, 512)
self.final = nn.Conv2d(512, 1, kernel_size=1)

def forward(self, x, y):
x = torch.cat([x, y], axis=1)
x0 = self.d1(x)
x1 = self.d2(x0)
x2 = self.d3(x1)
x3 = self.d4(x2)
xn = self.final(x3)
return xn
78 changes: 78 additions & 0 deletions pl_bolts/models/gans/pix2pix/pix2pix_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pytorch_lightning as pl
import torch
from torch import nn

from pl_bolts.models.gans.pix2pix.components import Generator, PatchGAN


def _weights_init(m):
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
if isinstance(m, nn.BatchNorm2d):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
torch.nn.init.constant_(m.bias, 0)


class Pix2Pix(pl.LightningModule):

def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=200):

super().__init__()
self.save_hyperparameters()

self.gen = Generator(in_channels, out_channels)
self.patch_gan = PatchGAN(in_channels + out_channels)

# intializing weights
self.gen = self.gen.apply(_weights_init)
self.patch_gan = self.patch_gan.apply(_weights_init)

self.adversarial_criterion = nn.BCEWithLogitsLoss()
self.recon_criterion = nn.L1Loss()

def _gen_step(self, real_images, conditioned_images):
# Pix2Pix has adversarial and a reconstruction loss
# First calculate the adversarial loss
fake_images = self.gen(conditioned_images)
disc_logits = self.patch_gan(fake_images, conditioned_images)
adversarial_loss = self.adversarial_criterion(disc_logits, torch.ones_like(disc_logits))

# calculate reconstruction loss
recon_loss = self.recon_criterion(fake_images, real_images)
lambda_recon = self.hparams.lambda_recon

return adversarial_loss + lambda_recon * recon_loss

def _disc_step(self, real_images, conditioned_images):
fake_images = self.gen(conditioned_images).detach()
fake_logits = self.patch_gan(fake_images, conditioned_images)

real_logits = self.patch_gan(real_images, conditioned_images)

fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
return (real_loss + fake_loss) / 2

def configure_optimizers(self):
lr = self.hparams.learning_rate
gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(self.patch_gan.parameters(), lr=lr)
return disc_opt, gen_opt

def training_step(self, batch, batch_idx, optimizer_idx):
real, condition = batch

loss = None
if optimizer_idx == 0:
loss = self._disc_step(real, condition)
self.log('PatchGAN Loss', loss)
elif optimizer_idx == 1:
loss = self._gen_step(real, condition)
self.log('Generator Loss', loss)

return loss


if __name__ == '__main__':
pix2pix = Pix2Pix(3, 3)
print(pix2pix(torch.randn(1, 3, 256, 256)).shape)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
torch>=1.6
pytorch-lightning>=1.1.1, <1.2
pytorch-lightning>=1.1.1, <1.2

0 comments on commit 53f5370

Please sign in to comment.