-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial commit * added components for pix2pix * pix2pix model created * Update pl_bolts/models/gans/pix2pix/components.py Co-authored-by: Aditya Oke <47158509+oke-aditya@users.noreply.github.com> * yapf consistency * removed tilde * added training step and some reformatting * center_crop torchvision and dropout value 0.5 * renamed variables * update * updated requirements file * refactored and need to fix generator UPSAMPLE * removed torchvision * refactored * updated * yapf Co-authored-by: Aniket Maurya <aniket.maurya@gdn-commerce.com> Co-authored-by: Aditya Oke <47158509+oke-aditya@users.noreply.github.com>
- Loading branch information
1 parent
8f49ff9
commit 53f5370
Showing
5 changed files
with
236 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
from pl_bolts.models.gans.basic.basic_gan_module import GAN # noqa: F401 | ||
from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN # noqa: F401 | ||
from pl_bolts.models.gans.basic.basic_gan_module import GAN | ||
from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN | ||
from pl_bolts.models.gans.pix2pix.pix2pix_module import Pix2Pix | ||
|
||
__all__ = [ | ||
"GAN", | ||
"DCGAN", | ||
"Pix2Pix", | ||
] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import torch | ||
from torch import nn | ||
|
||
|
||
class UpSampleConv(nn.Module):
    """Decoder block: transposed convolution with optional BatchNorm, Dropout and ReLU.

    With the defaults ``kernel=4, strides=2, padding=1`` the transposed
    convolution doubles the height and width of its input.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the transposed convolution.
        kernel: kernel size of the transposed convolution.
        strides: stride of the transposed convolution.
        padding: padding of the transposed convolution.
        activation: if ``True``, finish the block with an in-place ReLU.
        batchnorm: if ``True``, apply ``BatchNorm2d`` after the transposed convolution.
        dropout: if ``True``, apply ``Dropout2d(0.5)`` after the normalisation step.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel=4,
        strides=2,
        padding=1,
        activation=True,
        batchnorm=True,
        dropout=False
    ):
        super().__init__()
        # The flags are stored so that ``forward`` knows which optional layers exist.
        self.activation = activation
        self.batchnorm = batchnorm
        self.dropout = dropout

        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel, strides, padding)

        if batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)

        if activation:
            self.act = nn.ReLU(True)

        if dropout:
            self.drop = nn.Dropout2d(0.5)

    def forward(self, x):
        """Run deconvolution -> BatchNorm -> Dropout -> ReLU, skipping disabled stages."""
        x = self.deconv(x)
        if self.batchnorm:
            x = self.bn(x)

        if self.dropout:
            x = self.drop(x)

        # ``self.act`` used to be constructed but never applied, which left the
        # decoder without any non-linearity. It is applied last so the block
        # follows the Convolution-BatchNorm-Dropout-ReLU ordering.
        if self.activation:
            x = self.act(x)
        return x
|
||
|
||
class DownSampleConv(nn.Module):
    """Encoder block: a strided convolution, optionally followed by BatchNorm and LeakyReLU(0.2)."""

    def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True):
        """
        Paper details:
        - C64-C128-C256-C512-C512-C512-C512-C512
        - All convolutions are 4×4 spatial filters applied with stride 2
        - Convolutions in the encoder downsample by a factor of 2
        """
        super().__init__()
        # Remember which optional stages were requested; ``forward`` consults these flags.
        self.activation = activation
        self.batchnorm = batchnorm

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel,
            stride=strides,
            padding=padding,
        )

        if self.batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)

        if self.activation:
            self.act = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        """Apply the convolution, then whichever of BatchNorm / LeakyReLU are enabled."""
        out = self.conv(x)
        out = self.bn(out) if self.batchnorm else out
        return self.act(out) if self.activation else out
|
||
|
||
class Generator(nn.Module):
    """Encoder/decoder generator with skip connections, built from DownSampleConv and UpSampleConv."""

    def __init__(self, in_channels, out_channels):
        """
        Paper details:
        - Encoder: C64-C128-C256-C512-C512-C512-C512-C512
        - All convolutions are 4×4 spatial filters applied with stride 2
        - Convolutions in the encoder downsample by a factor of 2
        - Decoder: CD512-CD1024-CD1024-C1024-C1024-C512 -C256-C128
        """
        super().__init__()

        # Encoder / downsample convs. Shape comments are for a 256 x 256 input.
        down_blocks = [
            DownSampleConv(in_channels, 64, batchnorm=False),  # bs x 64 x 128 x 128
            DownSampleConv(64, 128),  # bs x 128 x 64 x 64
            DownSampleConv(128, 256),  # bs x 256 x 32 x 32
            DownSampleConv(256, 512),  # bs x 512 x 16 x 16
            DownSampleConv(512, 512),  # bs x 512 x 8 x 8
            DownSampleConv(512, 512),  # bs x 512 x 4 x 4
            DownSampleConv(512, 512),  # bs x 512 x 2 x 2
            DownSampleConv(512, 512, batchnorm=False),  # bs x 512 x 1 x 1
        ]

        # Decoder / upsample convs. From the second block on, the input width is
        # doubled because the previous output is concatenated with a skip tensor.
        up_blocks = [
            UpSampleConv(512, 512, dropout=True),  # bs x 512 x 2 x 2
            UpSampleConv(1024, 512, dropout=True),  # bs x 512 x 4 x 4
            UpSampleConv(1024, 512, dropout=True),  # bs x 512 x 8 x 8
            UpSampleConv(1024, 512),  # bs x 512 x 16 x 16
            UpSampleConv(1024, 256),  # bs x 256 x 32 x 32
            UpSampleConv(512, 128),  # bs x 128 x 64 x 64
            UpSampleConv(256, 64),  # bs x 64 x 128 x 128
        ]
        self.decoder_channels = [512, 512, 512, 512, 256, 128, 64]
        self.final_conv = nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

        # Registered last so the sub-module order matches creation of final_conv / tanh first.
        self.encoders = nn.ModuleList(down_blocks)
        self.decoders = nn.ModuleList(up_blocks)

    def forward(self, x):
        """Encode ``x``, decode with skip connections, and squash the result with tanh."""
        features = []
        for down in self.encoders:
            x = down(x)
            features.append(x)

        # Skip the bottleneck (last) output and walk the rest deepest-first.
        skips = features[-2::-1]

        # Every decoder except the last has its output concatenated with a skip
        # tensor along the channel dimension; ``zip`` stops at the shorter sequence.
        for up, skip in zip(self.decoders[:-1], skips):
            x = torch.cat((up(x), skip), dim=1)

        x = self.decoders[-1](x)
        return self.tanh(self.final_conv(x))
|
||
|
||
class PatchGAN(nn.Module):
    """Discriminator: four DownSampleConv stages followed by a 1x1 convolution to one channel."""

    def __init__(self, input_channels):
        """Build the stages; ``input_channels`` is the channel count of the concatenated input pair."""
        super().__init__()
        self.d1 = DownSampleConv(input_channels, 64, batchnorm=False)
        self.d2 = DownSampleConv(64, 128)
        self.d3 = DownSampleConv(128, 256)
        self.d4 = DownSampleConv(256, 512)
        self.final = nn.Conv2d(512, 1, kernel_size=1)

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` on the channel dimension and return the resulting map."""
        out = torch.cat([x, y], dim=1)
        for stage in (self.d1, self.d2, self.d3, self.d4, self.final):
            out = stage(out)
        return out
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import pytorch_lightning as pl | ||
import torch | ||
from torch import nn | ||
|
||
from pl_bolts.models.gans.pix2pix.components import Generator, PatchGAN | ||
|
||
|
||
def _weights_init(m):
    """Initialise one module in place; intended to be passed to ``nn.Module.apply``.

    - ``Conv2d`` / ``ConvTranspose2d``: weights drawn from N(0.0, 0.02).
    - ``BatchNorm2d``: scale drawn from N(1.0, 0.02), bias set to 0.

    Any other module type is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # The scale must be centred on 1.0: a zero-mean scale would shrink every
        # normalised activation to ~0.02 of its size and flip signs at random.
        # ``affine=False`` layers have no weight/bias, hence the None checks.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
|
||
|
||
class Pix2Pix(pl.LightningModule):
    """Pix2Pix conditional GAN: a ``Generator`` trained against a ``PatchGAN`` discriminator.

    Args:
        in_channels: channels of the conditioning image fed to the generator.
        out_channels: channels of the generated image (and of the real target).
        learning_rate: learning rate shared by both Adam optimizers.
        lambda_recon: weight of the L1 reconstruction term in the generator loss.
    """

    def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=200):

        super().__init__()
        # Makes the constructor arguments available through ``self.hparams``.
        self.save_hyperparameters()

        self.gen = Generator(in_channels, out_channels)
        # The discriminator sees an image concatenated with its conditioning image.
        self.patch_gan = PatchGAN(in_channels + out_channels)

        # initializing weights
        self.gen = self.gen.apply(_weights_init)
        self.patch_gan = self.patch_gan.apply(_weights_init)

        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.recon_criterion = nn.L1Loss()

    def forward(self, x):
        """Translate conditioning images ``x`` with the generator.

        Without this method, calling the module directly would fall through to
        the unimplemented base ``forward`` and raise ``NotImplementedError``.
        """
        return self.gen(x)

    def _gen_step(self, real_images, conditioned_images):
        """Return the generator loss: adversarial term plus weighted L1 reconstruction term."""
        # Pix2Pix has adversarial and a reconstruction loss
        # First calculate the adversarial loss
        fake_images = self.gen(conditioned_images)
        disc_logits = self.patch_gan(fake_images, conditioned_images)
        # The generator is rewarded when the discriminator labels its output as real (ones).
        adversarial_loss = self.adversarial_criterion(disc_logits, torch.ones_like(disc_logits))

        # calculate reconstruction loss
        recon_loss = self.recon_criterion(fake_images, real_images)
        lambda_recon = self.hparams.lambda_recon

        return adversarial_loss + lambda_recon * recon_loss

    def _disc_step(self, real_images, conditioned_images):
        """Return the discriminator loss: mean of the real-pair and fake-pair BCE terms."""
        # ``detach`` keeps discriminator gradients from flowing into the generator.
        fake_images = self.gen(conditioned_images).detach()
        fake_logits = self.patch_gan(fake_images, conditioned_images)

        real_logits = self.patch_gan(real_images, conditioned_images)

        fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
        real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
        return (real_loss + fake_loss) / 2

    def configure_optimizers(self):
        """Return ``(discriminator_optimizer, generator_optimizer)``.

        The order matters: ``training_step`` treats optimizer index 0 as the
        discriminator and index 1 as the generator.
        """
        lr = self.hparams.learning_rate
        gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr)
        disc_opt = torch.optim.Adam(self.patch_gan.parameters(), lr=lr)
        return disc_opt, gen_opt

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute and log the loss for the optimizer selected by ``optimizer_idx``.

        ``batch`` is unpacked as ``(real, condition)``. Returns ``None`` for any
        optimizer index other than 0 or 1.
        """
        real, condition = batch

        loss = None
        if optimizer_idx == 0:
            loss = self._disc_step(real, condition)
            self.log('PatchGAN Loss', loss)
        elif optimizer_idx == 1:
            loss = self._gen_step(real, condition)
            self.log('Generator Loss', loss)

        return loss
|
||
|
||
if __name__ == '__main__':
    # Smoke check: push one random 3-channel 256 x 256 image through the
    # generator and print the output shape. The generator sub-module is called
    # directly so the check does not rely on the LightningModule's ``forward``.
    pix2pix = Pix2Pix(3, 3)
    print(pix2pix.gen(torch.randn(1, 3, 256, 256)).shape)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
torch>=1.6 | ||
pytorch-lightning>=1.1.1, <1.2 | ||
pytorch-lightning>=1.1.1, <1.2 |