
Add Pix2Pix model #533

Merged on Mar 4, 2021 (18 commits)
6 changes: 4 additions & 2 deletions pl_bolts/models/gans/__init__.py
@@ -1,7 +1,9 @@
-from pl_bolts.models.gans.basic.basic_gan_module import GAN # noqa: F401
-from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN # noqa: F401
+from pl_bolts.models.gans.basic.basic_gan_module import GAN
+from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN
+from pl_bolts.models.gans.pix2pix.pix2pix_module import Pix2Pix
 
 __all__ = [
     "GAN",
     "DCGAN",
+    "Pix2Pix",
 ]
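With the re-export in place, Pix2Pix becomes importable from the package root. A minimal usage sketch (the channel counts here are illustrative):

from pl_bolts.models.gans import Pix2Pix

# instantiate the LightningModule added in this PR
model = Pix2Pix(in_channels=3, out_channels=3)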
Empty file: pl_bolts/models/gans/pix2pix/__init__.py
168 changes: 168 additions & 0 deletions pl_bolts/models/gans/pix2pix/components.py
@@ -0,0 +1,168 @@
import torch
from torch import nn


def _center_crop(image, new_shape):
    h, w = image.shape[-2:]
    n_h, n_w = new_shape[-2:]
    cy, cx = h // 2, w // 2
    ymin, xmin = cy - n_h // 2, cx - n_w // 2
    ymax, xmax = ymin + n_h, xmin + n_w
    # index height before width to match the (..., H, W) tensor layout
    cropped_image = image[..., ymin:ymax, xmin:xmax]
    return cropped_image
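_center_crop trims a feature map to a (possibly smaller) spatial size so that encoder skip connections can be concatenated with decoder activations. A quick sanity check with dummy tensors (sizes chosen arbitrarily for illustration):

import torch

skip = torch.randn(1, 64, 33, 33)  # encoder feature map
x = torch.randn(1, 64, 32, 32)     # decoder feature map
cropped = _center_crop(skip, x.shape)
print(cropped.shape)  # torch.Size([1, 64, 32, 32])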


class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, use_dropout=False, use_bn=True):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.activation = nn.LeakyReLU(0.2)

if use_bn:
self.batchnorm = nn.BatchNorm2d(out_channels)
self.use_bn = use_bn

if use_dropout:
self.dropout = nn.Dropout()
self.use_dropout = use_dropout

def forward(self, x):
x = self.conv1(x)
if self.use_bn:
x = self.batchnorm(x)
if self.use_dropout:
x = self.dropout(x)
x = self.activation(x)
return x


class UpSampleConv(nn.Module):

def __init__(self, input_channels, use_dropout=False, use_bn=True):
super().__init__()
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
Contributor: Let's just confirm whether the upsampling should be done with nn.Upsample or nn.ConvTranspose2d; both work fine. I haven't read the Pix2Pix paper, so let me check.

Contributor Author: Hi, thank you for the review. In section 6 of the Pix2Pix paper, the authors mention that they upsample the tensors by a factor of 2, but they don't say whether a transposed convolution or an Upsample followed by a Conv layer is used.

Contributor: I quickly checked the paper and found that the PyTorch implementation linked from the authors' Lua implementation uses nn.ConvTranspose2d, so shall we follow that architecture unless someone has a strong opinion?

Contributor: Yes, I also confirmed that it is nn.ConvTranspose2d. I referred to the TensorFlow docs, which give a really nice implementation.

self.conv1 = nn.Conv2d(input_channels, input_channels // 2, kernel_size=2)
self.conv2 = nn.Conv2d(input_channels, input_channels // 2, kernel_size=3, padding=1)
self.conv3 = nn.Conv2d(input_channels // 2, input_channels // 2, kernel_size=2, padding=1)
if use_bn:
self.batchnorm = nn.BatchNorm2d(input_channels // 2)
self.use_bn = use_bn
self.activation = nn.ReLU()
if use_dropout:
self.dropout = nn.Dropout()
self.use_dropout = use_dropout

def forward(self, x, skip_con_x):

x = self.upsample(x)
x = self.conv1(x)
skip_con_x = _center_crop(skip_con_x, x.shape)
        x = torch.cat([x, skip_con_x], dim=1)
x = self.conv2(x)
if self.use_bn:
x = self.batchnorm(x)
if self.use_dropout:
x = self.dropout(x)
x = self.activation(x)
x = self.conv3(x)
if self.use_bn:
x = self.batchnorm(x)
if self.use_dropout:
x = self.dropout(x)
x = self.activation(x)
return x
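As discussed in the thread above, the reference implementation upsamples with nn.ConvTranspose2d rather than nn.Upsample followed by convolutions. A minimal sketch of such a block (an illustration of the reviewers' suggestion, not part of this PR; the class name is hypothetical):

import torch
from torch import nn


class TransposedUpSampleConv(nn.Module):

    def __init__(self, in_channels, out_channels, use_dropout=False, use_bn=True):
        super().__init__()
        # kernel_size=4, stride=2, padding=1 exactly doubles H and W
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.batchnorm = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
        self.dropout = nn.Dropout() if use_dropout else nn.Identity()
        self.activation = nn.ReLU()

    def forward(self, x, skip_con_x):
        x = self.activation(self.dropout(self.batchnorm(self.up(x))))
        # concatenate the encoder skip connection along the channel axis
        return torch.cat([x, skip_con_x], dim=1)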


class DownSampleConv(nn.Module):

def __init__(self, in_channels, use_dropout=False, use_bn=False):
super().__init__()
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

if use_bn:
self.batchnorm = nn.BatchNorm2d(in_channels * 2)
self.use_bn = use_bn

if use_dropout:
self.dropout = nn.Dropout()
self.use_dropout = use_dropout

self.conv_block1 = ConvBlock(in_channels, in_channels * 2, use_dropout, use_bn)
self.conv_block2 = ConvBlock(in_channels * 2, in_channels * 2, use_dropout, use_bn)

def forward(self, x):
x = self.conv_block1(x)
x = self.conv_block2(x)
x = self.maxpool(x)
return x
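Each DownSampleConv doubles the channel count and halves the spatial resolution. A quick shape check (dummy sizes chosen for illustration):

import torch

block = DownSampleConv(16)
x = torch.randn(1, 16, 64, 64)
print(block(x).shape)  # torch.Size([1, 32, 32, 32])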


class Generator(nn.Module):
def __init__(self, in_channels, out_channels, hidden_channels=32, depth=6):
super().__init__()

self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)

self.conv_final = nn.Conv2d(hidden_channels,
out_channels,
kernel_size=1)
self.depth = depth

self.contracting_layers = []
self.expanding_layers = []
self.sigmoid = nn.Sigmoid()

# encoding/contracting path of the Generator
for i in range(depth):
            down_sample_conv = DownSampleConv(hidden_channels * 2 ** i, use_dropout=(i < 3))
self.contracting_layers.append(down_sample_conv)

# Upsampling/Expanding path of the Generator
for i in range(depth):
upsample_conv = UpSampleConv(hidden_channels * 2 ** (i + 1))
self.expanding_layers.append(upsample_conv)

self.contracting_layers = nn.ModuleList(self.contracting_layers)
self.expanding_layers = nn.ModuleList(self.expanding_layers)

def forward(self, x):
depth = self.depth
contractive_x = []

x = self.conv1(x)
contractive_x.append(x)

for i in range(depth):
x = self.contracting_layers[i](x)
contractive_x.append(x)

for i in range(depth - 1, -1, -1):
x = self.expanding_layers[i](x, contractive_x[i])
x = self.conv_final(x)

return self.sigmoid(x)


class PatchGAN(nn.Module):

def __init__(self, input_channels, hidden_channels=8):
super().__init__()
self.conv1 = nn.Conv2d(input_channels, hidden_channels, kernel_size=1)
self.contract1 = DownSampleConv(hidden_channels, use_bn=False)
self.contract2 = DownSampleConv(hidden_channels * 2)
self.contract3 = DownSampleConv(hidden_channels * 4)
self.contract4 = DownSampleConv(hidden_channels * 8)
self.final = nn.Conv2d(hidden_channels * 16, 1, kernel_size=1)

def forward(self, x, y):
        x = torch.cat([x, y], dim=1)
x0 = self.conv1(x)
x1 = self.contract1(x0)
x2 = self.contract2(x1)
x3 = self.contract3(x2)
x4 = self.contract4(x3)
xn = self.final(x4)
return xn
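An end-to-end shape check for the two components, as a sketch (the 256x256 RGB inputs are assumed purely for illustration):

import torch

gen = Generator(in_channels=3, out_channels=3, hidden_channels=32, depth=6)
disc = PatchGAN(input_channels=6, hidden_channels=8)

condition = torch.randn(1, 3, 256, 256)
fake = gen(condition)           # torch.Size([1, 3, 256, 256])
logits = disc(fake, condition)  # torch.Size([1, 1, 16, 16]): one logit per patch, not per image
print(fake.shape, logits.shape)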
78 changes: 78 additions & 0 deletions pl_bolts/models/gans/pix2pix/pix2pix_module.py
@@ -0,0 +1,78 @@
import pytorch_lightning as pl
import torch
from torch import nn

from pl_bolts.models.gans.pix2pix.components import Generator, PatchGAN


def _weights_init(m):
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
if isinstance(m, nn.BatchNorm2d):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
torch.nn.init.constant_(m.bias, 0)


class Pix2Pix(pl.LightningModule):
def __init__(self,
in_channels,
out_channels,
hidden_channels=32,
depth=6,
learning_rate=0.0002,
lambda_recon=200):

super().__init__()
self.save_hyperparameters()

self.gen = Generator(in_channels, out_channels, hidden_channels, depth)
self.patch_gan = PatchGAN(in_channels + out_channels, hidden_channels=8)

        # initializing weights
self.gen = self.gen.apply(_weights_init)
self.patch_gan = self.patch_gan.apply(_weights_init)

self.adversarial_criterion = nn.BCEWithLogitsLoss()
self.recon_criterion = nn.L1Loss()

def _gen_step(self, real_images, conditioned_images):
        # Pix2Pix has an adversarial loss and a reconstruction loss
        # first, calculate the adversarial loss
fake_images = self.gen(conditioned_images)
disc_logits = self.patch_gan(fake_images, conditioned_images)
adversarial_loss = self.adversarial_criterion(disc_logits, torch.ones_like(disc_logits))

# calculate reconstruction loss
recon_loss = self.recon_criterion(fake_images, real_images)
lambda_recon = self.hparams.lambda_recon

return adversarial_loss + lambda_recon * recon_loss

def _disc_step(self, real_images, conditioned_images):
fake_images = self.gen(conditioned_images).detach()
fake_logits = self.patch_gan(fake_images, conditioned_images)

real_logits = self.patch_gan(real_images, conditioned_images)

fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
return (real_loss + fake_loss) / 2

def configure_optimizers(self):
lr = self.hparams.learning_rate
gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(self.patch_gan.parameters(), lr=lr)
return disc_opt, gen_opt

def training_step(self, batch, batch_idx, optimizer_idx):
real, conditioned = batch

loss = None
if optimizer_idx == 0:
loss = self._disc_step(real, conditioned)
self.log('PatchGAN Loss', loss)
elif optimizer_idx == 1:
loss = self._gen_step(real, conditioned)
self.log('Generator Loss', loss)

return loss
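A minimal training sketch, assuming a DataLoader that yields (real, conditioned) image pairs and the PyTorch Lightning version this PR targets (multiple optimizers dispatched via optimizer_idx); the random tensors below are a stand-in dataset, not part of the PR:

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

from pl_bolts.models.gans import Pix2Pix

# stand-in data: eight random (real, conditioned) 256x256 RGB pairs
real = torch.rand(8, 3, 256, 256)
conditioned = torch.rand(8, 3, 256, 256)
loader = DataLoader(TensorDataset(real, conditioned), batch_size=4)

model = Pix2Pix(in_channels=3, out_channels=3)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, loader)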