Adds patchgan discriminator (#7319)
Part of #6676.

### Description

Adds a patchgan-style discriminator, both single scale and multiscale.
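
A minimal usage sketch of the two networks (illustrative only, not part of this diff; shapes assume the default 4x4 kernels and stride-2 downsampling):

```python
import torch

from monai.networks.nets import MultiScalePatchDiscriminator, PatchDiscriminator

# Single scale: the initial conv plus num_layers_d - 1 further stride-2 convs
# downsample a 256x256 input by roughly 2**num_layers_d before the patch
# predictions are produced.
net = PatchDiscriminator(spatial_dims=2, channels=64, in_channels=1, num_layers_d=3)
features = net(torch.randn(2, 1, 256, 256))
print(features[-1].shape)  # final patch prediction map; earlier entries are intermediate features

# Multi scale: num_d discriminators, where discriminator i uses num_layers_d * (i + 1) layers.
multi = MultiScalePatchDiscriminator(num_d=2, num_layers_d=3, spatial_dims=2, channels=64, in_channels=1)
outputs, intermediates = multi(torch.randn(2, 1, 256, 256))
print(len(outputs), len(intermediates))  # one prediction map and one feature list per discriminator
```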

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Mark Graham <markgraham539@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
marksgraham and KumoLiu authored Dec 14, 2023
1 parent 43bc023 commit b85a534
Showing 4 changed files with 435 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/source/networks.rst
@@ -756,6 +756,14 @@ Nets
.. autoclass:: VQVAE
:members:

`PatchGANDiscriminator`
~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: PatchDiscriminator
:members:

.. autoclass:: MultiScalePatchDiscriminator
:members:

Utilities
---------
.. automodule:: monai.networks.utils
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
@@ -55,6 +55,7 @@
from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
from .milmodel import MILModel
from .netadapter import NetAdapter
from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
from .quicknat import Quicknat
from .regressor import Regressor
from .regunet import GlobalNet, LocalNet, RegUNet
247 changes: 247 additions & 0 deletions monai/networks/nets/patchgan_discriminator.py
@@ -0,0 +1,247 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence

import torch
import torch.nn as nn

from monai.networks.blocks import Convolution
from monai.networks.layers import Act


class MultiScalePatchDiscriminator(nn.Sequential):
"""
Multi-scale Patch-GAN discriminator based on Pix2PixHD:
High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585)
    The multi-scale discriminator is made up of several PatchGAN discriminators that process the
    images at different spatial scales.
Args:
num_d: number of discriminators
        num_layers_d: number of convolution layers (Conv + activation + normalisation + [dropout]) in the first
            discriminator. Discriminator i has num_layers_d * (i + 1) layers, so each subsequent discriminator
            downsamples the input by an additional factor of 2**num_layers_d.
spatial_dims: number of spatial dimensions (1D, 2D etc.)
channels: number of filters in the first convolutional layer (doubled for each subsequent layer)
in_channels: number of input channels
out_channels: number of output channels in each discriminator
kernel_size: kernel size of the convolution layers
activation: activation layer type
norm: normalisation type
        bias: whether to have a bias term in convolution blocks.
dropout: probability of dropout applied, defaults to 0.
        minimum_size_im: minimum spatial size of the input image, used to check that the requested
            architecture will not downsample the input below a spatial size of 1.
last_conv_kernel_size: kernel size of the last convolutional layer.
"""

def __init__(
self,
num_d: int,
num_layers_d: int,
spatial_dims: int,
channels: int,
in_channels: int,
out_channels: int = 1,
kernel_size: int = 4,
activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
norm: str | tuple = "BATCH",
bias: bool = False,
dropout: float | tuple = 0.0,
minimum_size_im: int = 256,
last_conv_kernel_size: int = 1,
) -> None:
super().__init__()
self.num_d = num_d
self.num_layers_d = num_layers_d
self.num_channels = channels
self.padding = tuple([int((kernel_size - 1) / 2)] * spatial_dims)
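        # discriminator i_ downsamples its input by 2**(num_layers_d * (i_ + 1)); e.g. with
        # num_layers_d=3 and minimum_size_im=256, a third discriminator would give
        # 256 / 2**9 < 1 and fail the check below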
for i_ in range(self.num_d):
num_layers_d_i = self.num_layers_d * (i_ + 1)
output_size = float(minimum_size_im) / (2**num_layers_d_i)
            if output_size < 1:
                raise AssertionError(
                    f"Your image size is too small to use {i_ + 1} discriminators with num_layers = {num_layers_d_i}. "
                    "Please reduce num_layers_d, reduce num_d, or use bigger images."
                )
subnet_d = PatchDiscriminator(
spatial_dims=spatial_dims,
channels=self.num_channels,
in_channels=in_channels,
out_channels=out_channels,
num_layers_d=num_layers_d_i,
kernel_size=kernel_size,
activation=activation,
norm=norm,
bias=bias,
padding=self.padding,
dropout=dropout,
last_conv_kernel_size=last_conv_kernel_size,
)

self.add_module("discriminator_%d" % i_, subnet_d)

    def forward(self, i: torch.Tensor) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]:
        """
        Args:
            i: input tensor

        Returns:
            a list with the final output of each discriminator, and a list of lists with
            the intermediate features of each discriminator.
        """

out: list[torch.Tensor] = []
intermediate_features: list[list[torch.Tensor]] = []
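        # each child module is a PatchDiscriminator; the last element of its output
        # list is the final prediction map, the rest are intermediate features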
for disc in self.children():
out_d: list[torch.Tensor] = disc(i)
out.append(out_d[-1])
intermediate_features.append(out_d[:-1])

return out, intermediate_features


class PatchDiscriminator(nn.Sequential):
"""
Patch-GAN discriminator based on Pix2PixHD:
High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585)
Args:
spatial_dims: number of spatial dimensions (1D, 2D etc.)
channels: number of filters in the first convolutional layer (doubled for each subsequent layer)
in_channels: number of input channels
out_channels: number of output channels
num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the discriminator.
kernel_size: kernel size of the convolution layers
        activation: activation layer type and arguments. Defaults to LeakyReLU.
norm: feature normalization type and arguments. Defaults to batch norm.
bias: whether to have a bias term in convolution blocks. Defaults to False.
padding: padding to be applied to the convolutional layers
dropout: proportion of dropout applied, defaults to 0.
last_conv_kernel_size: kernel size of the last convolutional layer.
"""

def __init__(
self,
spatial_dims: int,
channels: int,
in_channels: int,
out_channels: int = 1,
num_layers_d: int = 3,
kernel_size: int = 4,
activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
norm: str | tuple = "BATCH",
bias: bool = False,
padding: int | Sequence[int] = 1,
dropout: float | tuple = 0.0,
last_conv_kernel_size: int | None = None,
) -> None:
super().__init__()
self.num_layers_d = num_layers_d
self.num_channels = channels
if last_conv_kernel_size is None:
last_conv_kernel_size = kernel_size

self.add_module(
"initial_conv",
Convolution(
spatial_dims=spatial_dims,
kernel_size=kernel_size,
in_channels=in_channels,
out_channels=channels,
act=activation,
bias=True,
norm=None,
dropout=dropout,
padding=padding,
strides=2,
),
)

input_channels = channels
output_channels = channels * 2

        # Intermediate layers: each halves the spatial size, except the last, which uses stride 1
for l_ in range(self.num_layers_d):
if l_ == self.num_layers_d - 1:
stride = 1
else:
stride = 2
layer = Convolution(
spatial_dims=spatial_dims,
kernel_size=kernel_size,
in_channels=input_channels,
out_channels=output_channels,
act=activation,
bias=bias,
norm=norm,
dropout=dropout,
padding=padding,
strides=stride,
)
self.add_module("%d" % l_, layer)
input_channels = output_channels
output_channels = output_channels * 2

# Final layer
self.add_module(
"final_conv",
Convolution(
spatial_dims=spatial_dims,
kernel_size=last_conv_kernel_size,
in_channels=input_channels,
out_channels=out_channels,
bias=True,
conv_only=True,
padding=int((last_conv_kernel_size - 1) / 2),
dropout=0.0,
strides=1,
),
)

self.apply(self.initialise_weights)

def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
"""
Args:
x: input tensor
Returns:
list of intermediate features, with the last element being the output.
"""
out = [x]
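        # pass each submodule the previous output, keeping every activation;
        # out[0] is the input itself and is dropped from the returned list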
for submodel in self.children():
intermediate_output = submodel(out[-1])
out.append(intermediate_output)

return out[1:]

    def initialise_weights(self, m: nn.Module) -> None:
        """
        Initialise the weights of convolution and batch-norm layers.

        Args:
            m: instance of torch.nn.Module (or of a class inheriting from torch.nn.Module)
        """
        classname = m.__class__.__name__
        if any(conv in classname for conv in ("Conv1d", "Conv2d", "Conv3d")):
            # DCGAN-style initialisation: zero-mean normal with a small standard deviation
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find("BatchNorm") != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
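
The intermediate features returned by both forward passes are what enables a Pix2PixHD-style feature-matching loss. A minimal sketch of such a loss (illustrative only; loss weighting and the training loop are omitted):

```python
import torch
import torch.nn.functional as F

from monai.networks.nets import MultiScalePatchDiscriminator

disc = MultiScalePatchDiscriminator(num_d=2, num_layers_d=3, spatial_dims=2, channels=32, in_channels=1)
real = torch.randn(1, 1, 256, 256)
fake = torch.randn(1, 1, 256, 256)  # stand-in for a generator output

_, real_feats = disc(real)
_, fake_feats = disc(fake)

# Pix2PixHD feature matching: L1 distance between real and fake intermediate
# features, averaged over every layer of every discriminator.
fm_loss = torch.stack(
    [
        F.l1_loss(f_fake, f_real.detach())
        for d_fake, d_real in zip(fake_feats, real_feats)
        for f_fake, f_real in zip(d_fake, d_real)
    ]
).mean()
print(fm_loss)
```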