-
Notifications
You must be signed in to change notification settings - Fork 652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
🐞fix support for non-square images #204
Merged
Merged
Changes from 1 commit
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,16 +12,18 @@ | |
|
||
|
||
import math | ||
from typing import Tuple | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from torch import Tensor, nn | ||
|
||
|
||
class Encoder(nn.Module): | ||
"""Encoder Network. | ||
|
||
Args: | ||
input_size (int): Size of input image | ||
input_size (Tuple[int, int]): Size of input image | ||
latent_vec_size (int): Size of latent vector z | ||
num_input_channels (int): Number of input channels in the image | ||
n_features (int): Number of features per convolution layer | ||
|
@@ -31,7 +33,7 @@ class Encoder(nn.Module): | |
|
||
def __init__( | ||
self, | ||
input_size: int, | ||
input_size: Tuple[int, int], | ||
latent_vec_size: int, | ||
num_input_channels: int, | ||
n_features: int, | ||
|
@@ -40,13 +42,14 @@ def __init__( | |
): | ||
super().__init__() | ||
|
||
assert input_size % 16 == 0, "Input size should be a multiple of 16" | ||
assert input_size[0] % 16 == 0 and input_size[1] % 16 == 0, "Input size should be a multiple of 16" | ||
|
||
self.input_layers = nn.Sequential() | ||
|
||
self.padding = self._compute_padding(input_size) | ||
self.input_layers.add_module( | ||
f"initial-conv-{num_input_channels}-{n_features}", | ||
nn.Conv2d(num_input_channels, n_features, kernel_size=4, stride=2, padding=1, bias=False), | ||
nn.Conv2d(num_input_channels, n_features, kernel_size=4, stride=2, padding=self.padding, bias=False), | ||
) | ||
self.input_layers.add_module(f"initial-relu-{n_features}", nn.LeakyReLU(0.2, inplace=True)) | ||
|
||
|
@@ -63,8 +66,8 @@ def __init__( | |
|
||
# Create pyramid features to reach latent vector | ||
self.pyramid_features = nn.Sequential() | ||
input_size = input_size // 2 | ||
while input_size > 4: | ||
pyramid_dim = min(*input_size) // 2 # Use the smaller dimension to create pyramid. | ||
while pyramid_dim > 4: | ||
in_features = n_features | ||
out_features = n_features * 2 | ||
self.pyramid_features.add_module( | ||
|
@@ -74,14 +77,33 @@ def __init__( | |
self.pyramid_features.add_module(f"pyramid-{out_features}-batchnorm", nn.BatchNorm2d(out_features)) | ||
self.pyramid_features.add_module(f"pyramid-{out_features}-relu", nn.LeakyReLU(0.2, inplace=True)) | ||
n_features = out_features | ||
input_size = input_size // 2 | ||
pyramid_dim = pyramid_dim // 2 | ||
|
||
# Final conv | ||
if add_final_conv_layer: | ||
self.final_conv_layer = nn.Conv2d( | ||
n_features, latent_vec_size, kernel_size=4, stride=1, padding=0, bias=False | ||
n_features, | ||
latent_vec_size, | ||
kernel_size=4, | ||
stride=1, | ||
padding=0, | ||
bias=False, | ||
) | ||
|
||
def _compute_padding(self, input_size: Tuple[int, int]) -> Tuple[int, int]: | ||
"""Compute required padding from input size. | ||
|
||
Args: | ||
input_size (Tuple[int, int]): Input size | ||
|
||
Returns: | ||
Tuple[int, int]: Padding for each dimension | ||
""" | ||
# find the largest dimension | ||
l_dim = 2 ** math.ceil(math.log(max(*input_size), 2)) | ||
padding = math.ceil((l_dim - input_size[0]) // 2 + 1), math.ceil((l_dim - input_size[1]) // 2 + 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This won't work for odd image size, e.g. 221 |
||
return padding | ||
|
||
def forward(self, input_tensor: Tensor): | ||
"""Return latent vectors.""" | ||
|
||
|
@@ -98,7 +120,7 @@ class Decoder(nn.Module): | |
"""Decoder Network. | ||
|
||
Args: | ||
input_size (int): Size of input image | ||
input_size (Tuple[int, int]): Size of input image | ||
latent_vec_size (int): Size of latent vector z | ||
num_input_channels (int): Number of input channels in the image | ||
n_features (int): Number of features per convolution layer | ||
|
@@ -107,39 +129,58 @@ class Decoder(nn.Module): | |
""" | ||
|
||
def __init__( | ||
self, input_size: int, latent_vec_size: int, num_input_channels: int, n_features: int, extra_layers: int = 0 | ||
self, | ||
input_size: Tuple[int, int], | ||
latent_vec_size: int, | ||
num_input_channels: int, | ||
n_features: int, | ||
extra_layers: int = 0, | ||
): | ||
super().__init__() | ||
assert input_size % 16 == 0, "Input size should be a multiple of 16" | ||
assert input_size[0] % 16 == 0 and input_size[1] % 16 == 0, "Input size should be a multiple of 16" | ||
|
||
self.latent_input = nn.Sequential() | ||
|
||
# Calculate input channel size to recreate inverse pyramid | ||
exp_factor = int(math.log(input_size // 4, 2)) - 1 | ||
exp_factor = math.ceil(math.log(min(input_size) // 2, 2)) - 2 | ||
n_input_features = n_features * (2**exp_factor) | ||
|
||
# CNN layer for latent vector input | ||
self.latent_input.add_module( | ||
f"initial-{latent_vec_size}-{n_input_features}-convt", | ||
nn.ConvTranspose2d(latent_vec_size, n_input_features, kernel_size=4, stride=1, padding=0, bias=False), | ||
nn.ConvTranspose2d( | ||
latent_vec_size, | ||
n_input_features, | ||
kernel_size=4, | ||
stride=1, | ||
padding=0, | ||
bias=False, | ||
), | ||
) | ||
self.latent_input.add_module(f"initial-{n_input_features}-batchnorm", nn.BatchNorm2d(n_input_features)) | ||
self.latent_input.add_module(f"initial-{n_input_features}-relu", nn.ReLU(True)) | ||
|
||
# Create inverse pyramid | ||
self.inverse_pyramid = nn.Sequential() | ||
input_size = input_size // 2 | ||
while input_size > 4: | ||
pyramid_dim = min(*input_size) // 2 # Use the smaller dimension to create pyramid. | ||
while pyramid_dim > 4: | ||
in_features = n_input_features | ||
out_features = n_input_features // 2 | ||
self.inverse_pyramid.add_module( | ||
f"pyramid-{in_features}-{out_features}-convt", | ||
nn.ConvTranspose2d(in_features, out_features, kernel_size=4, stride=2, padding=1, bias=False), | ||
nn.ConvTranspose2d( | ||
in_features, | ||
out_features, | ||
kernel_size=4, | ||
stride=2, | ||
padding=1, | ||
bias=False, | ||
), | ||
) | ||
self.inverse_pyramid.add_module(f"pyramid-{out_features}-batchnorm", nn.BatchNorm2d(out_features)) | ||
self.inverse_pyramid.add_module(f"pyramid-{out_features}-relu", nn.ReLU(True)) | ||
n_input_features = out_features | ||
input_size = input_size // 2 | ||
pyramid_dim = pyramid_dim // 2 | ||
|
||
# Extra Layers | ||
self.extra_layers = nn.Sequential() | ||
|
@@ -159,7 +200,14 @@ def __init__( | |
self.final_layers = nn.Sequential() | ||
self.final_layers.add_module( | ||
f"final-{n_input_features}-{num_input_channels}-convt", | ||
nn.ConvTranspose2d(n_input_features, num_input_channels, kernel_size=4, stride=2, padding=1, bias=False), | ||
nn.ConvTranspose2d( | ||
n_input_features, | ||
num_input_channels, | ||
kernel_size=4, | ||
stride=2, | ||
padding=1, | ||
bias=False, | ||
), | ||
) | ||
self.final_layers.add_module(f"final-{num_input_channels}-tanh", nn.Tanh()) | ||
|
||
|
@@ -178,13 +226,13 @@ class Discriminator(nn.Module): | |
Made of only one encoder layer which takes x and x_hat to produce a score. | ||
|
||
Args: | ||
input_size (int): Input image size. | ||
input_size (Tuple[int,int]): Input image size. | ||
num_input_channels (int): Number of image channels. | ||
n_features (int): Number of feature maps in each convolution layer. | ||
extra_layers (int, optional): Add extra intermediate layers. Defaults to 0. | ||
""" | ||
|
||
def __init__(self, input_size: int, num_input_channels: int, n_features: int, extra_layers: int = 0): | ||
def __init__(self, input_size: Tuple[int, int], num_input_channels: int, n_features: int, extra_layers: int = 0): | ||
super().__init__() | ||
encoder = Encoder(input_size, 1, num_input_channels, n_features, extra_layers) | ||
layers = [] | ||
|
@@ -212,7 +260,7 @@ class Generator(nn.Module): | |
Made of an encoder-decoder-encoder architecture. | ||
|
||
Args: | ||
input_size (int): Size of input data. | ||
input_size (Tuple[int,int]): Size of input data. | ||
latent_vec_size (int): Dimension of latent vector produced between the first encoder-decoder. | ||
num_input_channels (int): Number of channels in input image. | ||
n_features (int): Number of feature maps in each convolution layer. | ||
|
@@ -222,7 +270,7 @@ class Generator(nn.Module): | |
|
||
def __init__( | ||
self, | ||
input_size: int, | ||
input_size: Tuple[int, int], | ||
latent_vec_size: int, | ||
num_input_channels: int, | ||
n_features: int, | ||
|
@@ -250,7 +298,7 @@ class GanomalyModel(nn.Module): | |
"""Ganomaly Model. | ||
|
||
Args: | ||
input_size (int): Input dimension of a square tensor. | ||
input_size (Tuple[int,int]): Input dimension. | ||
num_input_channels (int): Number of input channels. | ||
n_features (int): Number of features layers in the CNNs. | ||
latent_vec_size (int): Size of autoencoder latent vector. | ||
|
@@ -263,7 +311,7 @@ class GanomalyModel(nn.Module): | |
|
||
def __init__( | ||
self, | ||
input_size: int, | ||
input_size: Tuple[int, int], | ||
num_input_channels: int, | ||
n_features: int, | ||
latent_vec_size: int, | ||
|
@@ -348,7 +396,12 @@ def get_generator_loss(self, images: Tensor) -> Tensor: | |
pred_fake, _ = self.discriminator(fake) | ||
|
||
error_enc = self.loss_enc(latent_i, latent_o) | ||
error_con = self.loss_con(images, fake) | ||
|
||
# Pad input image to match generated image dimension | ||
padding = self.generator.encoder1.padding[::-1] | ||
padded_images = F.pad(images, pad=[padding[0] - 1, padding[0] - 1, padding[1] - 1, padding[1] - 1]) | ||
error_con = self.loss_con(padded_images, fake) | ||
|
||
error_adv = self.loss_adv(pred_real, pred_fake) | ||
|
||
loss_generator = error_adv * self.wadv + error_con * self.wcon + error_enc * self.wenc | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess this is no longer needed now that we use padding to make sure that the network can process any image size.