Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Pix2Pix model #533

Merged
merged 18 commits into from
Mar 4, 2021
6 changes: 4 additions & 2 deletions pl_bolts/models/gans/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from pl_bolts.models.gans.basic.basic_gan_module import GAN # noqa: F401
from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN # noqa: F401
from pl_bolts.models.gans.basic.basic_gan_module import GAN
from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN
from pl_bolts.models.gans.pix2pix.pix2pix_module import Pix2Pix

__all__ = [
"GAN",
"DCGAN",
"Pix2Pix",
]
Empty file.
153 changes: 153 additions & 0 deletions pl_bolts/models/gans/pix2pix/components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import torch
from torch import nn


class UpSampleConv(nn.Module):

def __init__(
self,
in_channels,
out_channels,
kernel=4,
strides=2,
padding=1,
activation=True,
batchnorm=True,
dropout=False
):
super().__init__()
self.activation = activation
self.batchnorm = batchnorm
self.dropout = dropout

self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel, strides, padding)

if batchnorm:
self.bn = nn.BatchNorm2d(out_channels)

if activation:
self.act = nn.ReLU(True)

if dropout:
self.drop = nn.Dropout2d(0.5)

def forward(self, x):
x = self.deconv(x)
if self.batchnorm:
x = self.bn(x)

if self.dropout:
x = self.drop(x)
return x


class DownSampleConv(nn.Module):

def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True):
"""
Paper details:
- C64-C128-C256-C512-C512-C512-C512-C512
- All convolutions are 4×4 spatial filters applied with stride 2
- Convolutions in the encoder downsample by a factor of 2
"""
super().__init__()
self.activation = activation
self.batchnorm = batchnorm

self.conv = nn.Conv2d(in_channels, out_channels, kernel, strides, padding)

if batchnorm:
self.bn = nn.BatchNorm2d(out_channels)

if activation:
self.act = nn.LeakyReLU(0.2)

def forward(self, x):
x = self.conv(x)
if self.batchnorm:
x = self.bn(x)
if self.activation:
x = self.act(x)
return x


class Generator(nn.Module):

def __init__(self, in_channels, out_channels):
"""
Paper details:
- Encoder: C64-C128-C256-C512-C512-C512-C512-C512
- All convolutions are 4×4 spatial filters applied with stride 2
- Convolutions in the encoder downsample by a factor of 2
- Decoder: CD512-CD1024-CD1024-C1024-C1024-C512 -C256-C128
"""
super().__init__()

# encoder/donwsample convs
self.encoders = [
DownSampleConv(in_channels, 64, batchnorm=False), # bs x 64 x 128 x 128
DownSampleConv(64, 128), # bs x 128 x 64 x 64
DownSampleConv(128, 256), # bs x 256 x 32 x 32
DownSampleConv(256, 512), # bs x 512 x 16 x 16
DownSampleConv(512, 512), # bs x 512 x 8 x 8
DownSampleConv(512, 512), # bs x 512 x 4 x 4
DownSampleConv(512, 512), # bs x 512 x 2 x 2
DownSampleConv(512, 512, batchnorm=False), # bs x 512 x 1 x 1
]

# decoder/upsample convs
self.decoders = [
UpSampleConv(512, 512, dropout=True), # bs x 512 x 2 x 2
UpSampleConv(1024, 512, dropout=True), # bs x 512 x 4 x 4
UpSampleConv(1024, 512, dropout=True), # bs x 512 x 8 x 8
UpSampleConv(1024, 512), # bs x 512 x 16 x 16
UpSampleConv(1024, 256), # bs x 256 x 32 x 32
UpSampleConv(512, 128), # bs x 128 x 64 x 64
UpSampleConv(256, 64), # bs x 64 x 128 x 128
]
self.decoder_channels = [512, 512, 512, 512, 256, 128, 64]
self.final_conv = nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1)
self.tanh = nn.Tanh()

self.encoders = nn.ModuleList(self.encoders)
self.decoders = nn.ModuleList(self.decoders)

def forward(self, x):
skips_cons = []
for encoder in self.encoders:
x = encoder(x)

skips_cons.append(x)

skips_cons = list(reversed(skips_cons[:-1]))
decoders = self.decoders[:-1]

for decoder, skip in zip(decoders, skips_cons):
x = decoder(x)
# print(x.shape, skip.shape)
x = torch.cat((x, skip), axis=1)

x = self.decoders[-1](x)
# print(x.shape)
x = self.final_conv(x)
return self.tanh(x)


class PatchGAN(nn.Module):

def __init__(self, input_channels):
super().__init__()
self.d1 = DownSampleConv(input_channels, 64, batchnorm=False)
self.d2 = DownSampleConv(64, 128)
self.d3 = DownSampleConv(128, 256)
self.d4 = DownSampleConv(256, 512)
self.final = nn.Conv2d(512, 1, kernel_size=1)

def forward(self, x, y):
x = torch.cat([x, y], axis=1)
x0 = self.d1(x)
x1 = self.d2(x0)
x2 = self.d3(x1)
x3 = self.d4(x2)
xn = self.final(x3)
return xn
78 changes: 78 additions & 0 deletions pl_bolts/models/gans/pix2pix/pix2pix_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pytorch_lightning as pl
import torch
from torch import nn

from pl_bolts.models.gans.pix2pix.components import Generator, PatchGAN


def _weights_init(m):
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
if isinstance(m, nn.BatchNorm2d):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
torch.nn.init.constant_(m.bias, 0)


class Pix2Pix(pl.LightningModule):

def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=200):

super().__init__()
self.save_hyperparameters()

self.gen = Generator(in_channels, out_channels)
self.patch_gan = PatchGAN(in_channels + out_channels)

# intializing weights
self.gen = self.gen.apply(_weights_init)
self.patch_gan = self.patch_gan.apply(_weights_init)

self.adversarial_criterion = nn.BCEWithLogitsLoss()
self.recon_criterion = nn.L1Loss()

def _gen_step(self, real_images, conditioned_images):
# Pix2Pix has adversarial and a reconstruction loss
# First calculate the adversarial loss
fake_images = self.gen(conditioned_images)
disc_logits = self.patch_gan(fake_images, conditioned_images)
adversarial_loss = self.adversarial_criterion(disc_logits, torch.ones_like(disc_logits))

# calculate reconstruction loss
recon_loss = self.recon_criterion(fake_images, real_images)
lambda_recon = self.hparams.lambda_recon

return adversarial_loss + lambda_recon * recon_loss

def _disc_step(self, real_images, conditioned_images):
fake_images = self.gen(conditioned_images).detach()
fake_logits = self.patch_gan(fake_images, conditioned_images)

real_logits = self.patch_gan(real_images, conditioned_images)

fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
return (real_loss + fake_loss) / 2

def configure_optimizers(self):
lr = self.hparams.learning_rate
gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(self.patch_gan.parameters(), lr=lr)
return disc_opt, gen_opt

def training_step(self, batch, batch_idx, optimizer_idx):
real, condition = batch

loss = None
if optimizer_idx == 0:
loss = self._disc_step(real, condition)
self.log('PatchGAN Loss', loss)
elif optimizer_idx == 1:
loss = self._gen_step(real, condition)
self.log('Generator Loss', loss)

return loss


if __name__ == '__main__':
pix2pix = Pix2Pix(3, 3)
print(pix2pix(torch.randn(1, 3, 256, 256)).shape)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
torch>=1.6
pytorch-lightning>=1.1.1, <1.2
pytorch-lightning>=1.1.1, <1.2