From 9f445498db55b4f5116392eb958802f7e5ecece8 Mon Sep 17 00:00:00 2001
From: "vanessagd.2395"
Date: Sat, 24 Jun 2023 16:33:29 +0200
Subject: [PATCH] CAMP-TUM network contributions: Quicknat and Daf3D (#6306)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fixes #5921

### Description

@Al3xand1a and @ge96lip implemented the DAF3D [1] and QuickNAT [2] networks and tested them on open-source and local datasets. We used the PyTorch implementations available in [3] and [4] as baselines. We are quite confident about the implementation, but feel free to contact us if you find errors. For any questions, send us an email at ge45qix@mytum.de.

We have some questions about the contribution:

For QuickNAT:
1) We added the sequential class file to the networks folder because we were not sure where it belongs.
2) How should we include the squeeze-and-excitation requirement (SSE and CSSE) if it comes from a GitHub release?
https://github.com/ai-med/nn-common-modules/releases/download/v1.1/nn_common_modules-1.3-py3-none-any.whl

For DAF3D:
3) Are the overridden blocks fine as they are, or do they have to be more flexible?

For both:
We ran `./runtests.sh --quick --unittests --disttests`, but the error we get is not related to our changes, so we ran our unit tests independently and they pass. The same holds for the documentation.

[1] Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound. Yi Wang, Haoran Dou, Xiaowei Hu, Lei Zhu, Xin Yang, Ming Xu, Jing Qin, Pheng-Ann Heng, Tianfu Wang, and Dong Ni. IEEE Transactions on Medical Imaging, 2019.
[2] Roy, A. G., Conjeti, S., Navab, N., Wachinger, C., & Alzheimer's Disease Neuroimaging Initiative. (2019). QuickNAT: A fully convolutional network for quick and accurate segmentation of neuroanatomy. NeuroImage, 186, 713-727.
[3] https://github.com/ai-med/quickNAT_pytorch
[4] https://github.com/wulalago/DAF3D

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.
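### Example usage

A minimal sketch of how the two new networks can be instantiated, based on the constructor arguments and shapes used in the unit tests added by this PR; the exact tensor sizes are illustrative only. Note that DAF3D relies on MONAI's feature pyramid utilities, so `torchvision` needs to be installed (mirroring the skip condition in `tests/test_daf3d.py`).

```python
import torch

from monai.networks.nets import DAF3D, Quicknat

# DAF3D is a 3D network; in eval mode it returns only the final prediction map
daf3d = DAF3D(in_channels=1, out_channels=1)
daf3d.eval()
with torch.no_grad():
    out_3d = daf3d(torch.randn(1, 1, 32, 32, 64))  # -> (1, 1, 32, 32, 64)

# QuickNAT is a 2D network; se_block="None" disables the squeeze-and-excitation layer
quicknat = Quicknat(num_classes=33, num_channels=1, num_filters=64, se_block="None")
quicknat.eval()
with torch.no_grad():
    out_2d = quicknat(torch.randn(1, 1, 32, 32))  # -> (1, 33, 32, 32)
```

In training mode DAF3D instead returns the list of supervised signals (SLFs, attentive features and the final prediction) described in the docstring, which is what the weighted deep-supervision loss consumes.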
--------- Signed-off-by: ge96lip <73938628+ge96lip@users.noreply.github.com> Signed-off-by: vanessagd.2395 Signed-off-by: Al3xand1a <98582325+Al3xand1a@users.noreply.github.com> Co-authored-by: Alexandra Marquardt Co-authored-by: Carlotta Co-authored-by: Alexandra Marquardt Co-authored-by: Vanessa Co-authored-by: ge96lip <73938628+ge96lip@users.noreply.github.com> Co-authored-by: “Vanessa <“vanessa.gonzalezduque@ls2n.fr”> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alexandra Marquardt Co-authored-by: Al3xand1a <98582325+Al3xand1a@users.noreply.github.com> Co-authored-by: Alexandra Marquardt --- .gitignore | 4 + monai/networks/blocks/denseblock.py | 2 +- monai/networks/nets/__init__.py | 2 + monai/networks/nets/daf3d.py | 574 ++++++++++++++++++++++++++++ monai/networks/nets/quicknat.py | 439 +++++++++++++++++++++ tests/test_daf3d.py | 62 +++ tests/test_quicknat.py | 57 +++ 7 files changed, 1139 insertions(+), 1 deletion(-) create mode 100644 monai/networks/nets/daf3d.py create mode 100644 monai/networks/nets/quicknat.py create mode 100644 tests/test_daf3d.py create mode 100644 tests/test_quicknat.py diff --git a/.gitignore b/.gitignore index b74875a74d8..c6aea2258b4 100644 --- a/.gitignore +++ b/.gitignore @@ -151,3 +151,7 @@ tests/testing_data/CT_2D_head_moving.mha # profiling results *.prof runs + +*.gz + +*.pth diff --git a/monai/networks/blocks/denseblock.py b/monai/networks/blocks/denseblock.py index ecccab9d5a7..8c67584f5f3 100644 --- a/monai/networks/blocks/denseblock.py +++ b/monai/networks/blocks/denseblock.py @@ -11,7 +11,7 @@ from __future__ import annotations -from collections.abc import Sequence +from typing import Sequence import torch import torch.nn as nn diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 95ddad78420..a0c86281724 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -17,6 +17,7 @@ from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus from .classifier import Classifier, Critic, Discriminator +from .daf3d import DAF3D from .densenet import ( DenseNet, Densenet, @@ -51,6 +52,7 @@ from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet from .milmodel import MILModel from .netadapter import NetAdapter +from .quicknat import Quicknat from .regressor import Regressor from .regunet import GlobalNet, LocalNet, RegUNet from .resnet import ( diff --git a/monai/networks/nets/daf3d.py b/monai/networks/nets/daf3d.py new file mode 100644 index 00000000000..5a83cdc6006 --- /dev/null +++ b/monai/networks/nets/daf3d.py @@ -0,0 +1,574 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from __future__ import annotations + +from collections import OrderedDict +from collections.abc import Callable, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from monai.networks.blocks import ADN +from monai.networks.blocks.aspp import SimpleASPP +from monai.networks.blocks.backbone_fpn_utils import BackboneWithFPN +from monai.networks.blocks.convolutions import Convolution +from monai.networks.blocks.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork +from monai.networks.layers.factories import Conv, Norm +from monai.networks.nets.resnet import ResNet, ResNetBottleneck + +__all__ = [ + "AttentionModule", + "Daf3dASPP", + "Daf3dResNetBottleneck", + "Daf3dResNetDilatedBottleneck", + "Daf3dResNet", + "Daf3dBackbone", + "Daf3dFPN", + "Daf3dBackboneWithFPN", + "DAF3D", +] + + +class AttentionModule(nn.Module): + """ + Attention Module as described in 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound' + . Returns refined single layer feature (SLF) and attentive map + + Args: + spatial_dims: dimension of inputs. + in_channels: number of input channels (channels of slf and mlf). + out_channels: number of output channels (channels of attentive map and refined slf). + norm: normalization type. + act: activation type. + """ + + def __init__( + self, + spatial_dims, + in_channels, + out_channels, + norm=("group", {"num_groups": 32, "num_channels": 64}), + act="PRELU", + ): + super().__init__() + + self.attentive_map = nn.Sequential( + Convolution(spatial_dims, in_channels, out_channels, kernel_size=1, norm=norm, act=act), + Convolution(spatial_dims, out_channels, out_channels, kernel_size=3, padding=1, norm=norm, act=act), + Convolution( + spatial_dims, out_channels, out_channels, kernel_size=3, padding=1, adn_ordering="A", act="SIGMOID" + ), + ) + self.refine = nn.Sequential( + Convolution(spatial_dims, in_channels, out_channels, kernel_size=1, norm=norm, act=act), + Convolution(spatial_dims, out_channels, out_channels, kernel_size=3, padding=1, norm=norm, act=act), + Convolution(spatial_dims, out_channels, out_channels, kernel_size=3, padding=1, norm=norm, act=act), + ) + + def forward(self, slf, mlf): + att = self.attentive_map(torch.cat((slf, mlf), 1)) + out = self.refine(torch.cat((slf, att * mlf), 1)) + return (out, att) + + +class Daf3dASPP(SimpleASPP): + """ + Atrous Spatial Pyramid Pooling module as used in 'Deep Attentive Features for Prostate Segmentation in + 3D Transrectal Ultrasound' . Core functionality as in SimpleASPP, but after each + layerwise convolution a group normalization is added. Further weight initialization for convolutions is provided in + _init_weight(). Additional possibility to specify the number of final output channels. + + Args: + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. + in_channels: number of input channels. + conv_out_channels: number of output channels of each atrous conv. + out_channels: number of output channels of final convolution. + If None, uses len(kernel_sizes) * conv_out_channels + kernel_sizes: a sequence of four convolutional kernel sizes. + Defaults to (1, 3, 3, 3) for four (dilated) convolutions. + dilations: a sequence of four convolutional dilation parameters. + Defaults to (1, 2, 4, 6) for four (dilated) convolutions. + norm_type: final kernel-size-one convolution normalization type. + Defaults to batch norm. + acti_type: final kernel-size-one convolution activation type. + Defaults to leaky ReLU. 
+ bias: whether to have a bias term in convolution blocks. Defaults to False. + According to `Performance Tuning Guide `_, + if a conv layer is directly followed by a batch norm layer, bias should be False. + + Raises: + ValueError: When ``kernel_sizes`` length differs from ``dilations``. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + conv_out_channels: int, + out_channels: int | None = None, + kernel_sizes: Sequence[int] = (1, 3, 3, 3), + dilations: Sequence[int] = (1, 2, 4, 6), + norm_type: tuple | str | None = "BATCH", + acti_type: tuple | str | None = "LEAKYRELU", + bias: bool = False, + ) -> None: + super().__init__( + spatial_dims, in_channels, conv_out_channels, kernel_sizes, dilations, norm_type, acti_type, bias + ) + + # add normalization after each atrous convolution, initializes weights + new_convs = nn.ModuleList() + for _conv in self.convs: + tmp_conv = Convolution(1, 1, 1) + tmp_conv.conv = _conv + tmp_conv.adn = ADN(ordering="N", norm=norm_type, norm_dim=1) + tmp_conv = self._init_weight(tmp_conv) + new_convs.append(tmp_conv) + self.convs = new_convs + + # change final convolution to different out_channels + if out_channels is None: + out_channels = len(kernel_sizes) * conv_out_channels + + self.conv_k1 = Convolution( + spatial_dims=3, + in_channels=len(kernel_sizes) * conv_out_channels, + out_channels=out_channels, + kernel_size=1, + norm=norm_type, + act=acti_type, + ) + + def _init_weight(self, conv): + for m in conv.modules(): + if isinstance(m, nn.Conv3d): # true for conv.conv + torch.nn.init.kaiming_normal_(m.weight) + return conv + + +class Daf3dResNetBottleneck(ResNetBottleneck): + """ + ResNetBottleneck block as used in 'Deep Attentive Features for Prostate Segmentation in 3D + Transrectal Ultrasound' . + Instead of Batch Norm Group Norm is used, instead of ReLU PReLU activation is used. + Initial expansion is 2 instead of 4 and second convolution uses groups. + + Args: + in_planes: number of input channels. + planes: number of output channels (taking expansion into account). + spatial_dims: number of spatial dimensions of the input image. + stride: stride to use for second conv layer. + downsample: which downsample layer to use. + """ + + expansion = 2 + + def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None): + norm_type: Callable = Norm[Norm.GROUP, spatial_dims] + conv_type: Callable = Conv[Conv.CONV, spatial_dims] + + # in case downsample uses batch norm, change to group norm + if isinstance(downsample, nn.Sequential): + downsample = nn.Sequential( + conv_type(in_planes, planes * self.expansion, kernel_size=1, stride=stride, bias=False), + norm_type(num_groups=32, num_channels=planes * self.expansion), + ) + + super().__init__(in_planes, planes, spatial_dims, stride, downsample) + + # change norm from batch to group norm + self.bn1 = norm_type(num_groups=32, num_channels=planes) + self.bn2 = norm_type(num_groups=32, num_channels=planes) + self.bn3 = norm_type(num_groups=32, num_channels=planes * self.expansion) + + # adapt second convolution to work with groups + self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, stride=stride, groups=32, bias=False) + + # adapt activation function + self.relu = nn.PReLU() # type: ignore + + +class Daf3dResNetDilatedBottleneck(Daf3dResNetBottleneck): + """ + ResNetDilatedBottleneck as used in 'Deep Attentive Features for Prostate Segmentation in 3D + Transrectal Ultrasound' . + Same as Daf3dResNetBottleneck but dilation of 2 is used in second convolution. 
+ Args: + in_planes: number of input channels. + planes: number of output channels (taking expansion into account). + spatial_dims: number of spatial dimensions of the input image. + stride: stride to use for second conv layer. + downsample: which downsample layer to use. + """ + + def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None): + super().__init__(in_planes, planes, spatial_dims, stride, downsample) + + # add dilation in second convolution + conv_type: Callable = Conv[Conv.CONV, spatial_dims] + self.conv2 = conv_type( + planes, planes, kernel_size=3, stride=stride, padding=2, dilation=2, groups=32, bias=False + ) + + +class Daf3dResNet(ResNet): + """ + ResNet as used in 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound' + . + Uses two Daf3dResNetBottleneck blocks followed by two Daf3dResNetDilatedBottleneck blocks. + + Args: + layers: how many layers to use. + block_inplanes: determine the size of planes at each step. Also tunable with widen_factor. + spatial_dims: number of spatial dimensions of the input image. + n_input_channels: number of input channels for first convolutional layer. + conv1_t_size: size of first convolution layer, determines kernel and padding. + conv1_t_stride: stride of first convolution layer. + no_max_pool: bool argument to determine if to use maxpool layer. + shortcut_type: which downsample block to use. Options are 'A', 'B', default to 'B'. + - 'A': using `self._downsample_basic_block`. + - 'B': kernel_size 1 conv + norm. + widen_factor: widen output for each layer. + num_classes: number of output (classifications). + feed_forward: whether to add the FC layer for the output, default to `True`. + bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`. 
+ + """ + + def __init__( + self, + layers: list[int], + block_inplanes: list[int], + spatial_dims: int = 3, + n_input_channels: int = 3, + conv1_t_size: tuple[int] | int = 7, + conv1_t_stride: tuple[int] | int = 1, + no_max_pool: bool = False, + shortcut_type: str = "B", + widen_factor: float = 1.0, + num_classes: int = 400, + feed_forward: bool = True, + bias_downsample: bool = True, # for backwards compatibility (also see PR #5477) + ): + super().__init__( + ResNetBottleneck, + layers, + block_inplanes, + spatial_dims, + n_input_channels, + conv1_t_size, + conv1_t_stride, + no_max_pool, + shortcut_type, + widen_factor, + num_classes, + feed_forward, + bias_downsample, + ) + + self.in_planes = 64 + + conv_type: Callable = Conv[Conv.CONV, spatial_dims] + norm_type: Callable = Norm[Norm.GROUP, spatial_dims] + + # adapt first convolution to work with new in_planes + self.conv1 = conv_type( + n_input_channels, self.in_planes, kernel_size=7, stride=(1, 2, 2), padding=(3, 3, 3), bias=False + ) + self.bn1 = norm_type(32, 64) + self.relu = nn.PReLU() # type: ignore + + # adapt layers to our needs + self.layer1 = self._make_layer(Daf3dResNetBottleneck, block_inplanes[0], layers[0], spatial_dims, shortcut_type) + self.layer2 = self._make_layer( + Daf3dResNetBottleneck, block_inplanes[1], layers[1], spatial_dims, shortcut_type, stride=(1, 2, 2) # type: ignore + ) + self.layer3 = self._make_layer( + Daf3dResNetDilatedBottleneck, block_inplanes[2], layers[2], spatial_dims, shortcut_type, stride=1 + ) + self.layer4 = self._make_layer( + Daf3dResNetDilatedBottleneck, block_inplanes[3], layers[3], spatial_dims, shortcut_type, stride=1 + ) + + +class Daf3dBackbone(nn.Module): + """ + Backbone for 3D Feature Pyramid Network in DAF3D module based on 'Deep Attentive Features for Prostate Segmentation in + 3D Transrectal Ultrasound' . + + Args: + n_input_channels: number of input channels for the first convolution. + """ + + def __init__(self, n_input_channels): + super().__init__() + net = Daf3dResNet( + layers=[3, 4, 6, 3], + block_inplanes=[128, 256, 512, 1024], + n_input_channels=n_input_channels, + num_classes=2, + bias_downsample=False, + ) + net_modules = list(net.children()) + self.layer0 = nn.Sequential(*net_modules[:3]) + self.layer1 = nn.Sequential(*net_modules[3:5]) + self.layer2 = net_modules[5] + self.layer3 = net_modules[6] + self.layer4 = net_modules[7] + + def forward(self, x): + layer0 = self.layer0(x) + layer1 = self.layer1(layer0) + layer2 = self.layer2(layer1) + layer3 = self.layer3(layer2) + layer4 = self.layer4(layer3) + return layer4 + + +class Daf3dFPN(FeaturePyramidNetwork): + """ + Feature Pyramid Network as used in 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound' + . + Omits 3x3x3 convolution of layer_blocks and interpolates resulting feature maps to be the same size as + feature map with highest resolution. + + Args: + spatial_dims: 2D or 3D images + in_channels_list: number of channels for each feature map that is passed to the module + out_channels: number of channels of the FPN representation + extra_blocks: if provided, extra operations will be performed. 
+ It is expected to take the fpn features, the original + features and the names of the original features as input, and returns + a new list of feature maps and their corresponding names + """ + + def __init__( + self, + spatial_dims: int, + in_channels_list: list[int], + out_channels: int, + extra_blocks: ExtraFPNBlock | None = None, + ): + super().__init__(spatial_dims, in_channels_list, out_channels, extra_blocks) + + self.inner_blocks = nn.ModuleList() + for in_channels in in_channels_list: + if in_channels == 0: + raise ValueError("in_channels=0 is currently not supported") + inner_block_module = Convolution( + spatial_dims, + in_channels, + out_channels, + kernel_size=1, + adn_ordering="NA", + act="PRELU", + norm=("group", {"num_groups": 32, "num_channels": 128}), + ) + self.inner_blocks.append(inner_block_module) + + def forward(self, x: dict[str, Tensor]) -> dict[str, Tensor]: + # unpack OrderedDict into two lists for easier handling + names = list(x.keys()) + x_values: list[Tensor] = list(x.values()) + + last_inner = self.get_result_from_inner_blocks(x_values[-1], -1) + results = [] + results.append(last_inner) + + for idx in range(len(x_values) - 2, -1, -1): + inner_lateral = self.get_result_from_inner_blocks(x_values[idx], idx) + feat_shape = inner_lateral.shape[2:] + inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="trilinear") + last_inner = inner_lateral + inner_top_down + results.insert(0, last_inner) + + if self.extra_blocks is not None: + results, names = self.extra_blocks(results, x_values, names) + + # bring all layers to same size + results = [results[0]] + [F.interpolate(l, size=x["feat1"].size()[2:], mode="trilinear") for l in results[1:]] + # make it back an OrderedDict + out = OrderedDict(list(zip(names, results))) + + return out + + +class Daf3dBackboneWithFPN(BackboneWithFPN): + """ + Same as BackboneWithFPN but uses custom Daf3DFPN as feature pyramid network + + Args: + backbone: backbone network + return_layers: a dict containing the names + of the modules for which the activations will be returned as + the key of the dict, and the value of the dict is the name + of the returned activation (which the user can specify). + in_channels_list: number of channels for each feature map + that is returned, in the order they are present in the OrderedDict + out_channels: number of channels in the FPN. + spatial_dims: 2D or 3D images + extra_blocks: if provided, extra operations will + be performed. It is expected to take the fpn features, the original + features and the names of the original features as input, and returns + a new list of feature maps and their corresponding names + """ + + def __init__( + self, + backbone: nn.Module, + return_layers: dict[str, str], + in_channels_list: list[int], + out_channels: int, + spatial_dims: int | None = None, + extra_blocks: ExtraFPNBlock | None = None, + ) -> None: + super().__init__(backbone, return_layers, in_channels_list, out_channels, spatial_dims, extra_blocks) + + if spatial_dims is None: + if hasattr(backbone, "spatial_dims") and isinstance(backbone.spatial_dims, int): + spatial_dims = backbone.spatial_dims + elif isinstance(backbone.conv1, nn.Conv2d): + spatial_dims = 2 + elif isinstance(backbone.conv1, nn.Conv3d): + spatial_dims = 3 + else: + raise ValueError( + "Could not determine value of `spatial_dims` from backbone, please provide explicit value." 
+ ) + + self.fpn = Daf3dFPN(spatial_dims, in_channels_list, out_channels, extra_blocks) + + +class DAF3D(nn.Module): + """ + DAF3D network based on 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound' + . + The network consists of a 3D Feature Pyramid Network which is applied on the feature maps of a 3D ResNet, + followed by a custom Attention Module and an ASPP module. + During training the supervised signal consists of the outputs of the FPN (four Single Layer Features, SLFs), + the outputs of the attention module (four Attentive Features) and the final prediction. + They are individually compared to the ground truth, the final loss consists of a weighted sum of all + individual losses (see DAF3D tutorial for details). + There is an additional possiblity to return all supervised signals as well as the Attentive Maps in validation + mode to visualize inner functionality of the network. + + Args: + in_channels: number of input channels. + out_channels: number of output channels. + visual_output: whether to return all SLFs, Attentive Maps, Refined SLFs in validation mode + can be used to visualize inner functionality of the network + """ + + def __init__(self, in_channels, out_channels, visual_output=False): + super().__init__() + self.visual_output = visual_output + self.backbone_with_fpn = Daf3dBackboneWithFPN( + backbone=Daf3dBackbone(in_channels), + return_layers={"layer1": "feat1", "layer2": "feat2", "layer3": "feat3", "layer4": "feat4"}, + in_channels_list=[256, 512, 1024, 2048], + out_channels=128, + spatial_dims=3, + ) + self.predict1 = nn.Conv3d(128, out_channels, kernel_size=1) + + group_norm = ("group", {"num_groups": 32, "num_channels": 64}) + act_prelu = ("prelu", {"num_parameters": 1, "init": 0.25}) + self.fuse = nn.Sequential( + Convolution( + spatial_dims=3, + in_channels=512, + out_channels=64, + kernel_size=1, + adn_ordering="NA", + norm=group_norm, + act=act_prelu, + ), + Convolution( + spatial_dims=3, + in_channels=64, + out_channels=64, + kernel_size=3, + adn_ordering="NA", + padding=1, + norm=group_norm, + act=act_prelu, + ), + Convolution( + spatial_dims=3, + in_channels=64, + out_channels=64, + kernel_size=3, + adn_ordering="NA", + padding=1, + norm=group_norm, + act=act_prelu, + ), + ) + self.attention = AttentionModule( + spatial_dims=3, in_channels=192, out_channels=64, norm=group_norm, act=act_prelu + ) + + self.refine = Convolution(3, 256, 64, kernel_size=1, adn_ordering="NA", norm=group_norm, act=act_prelu) + self.predict2 = nn.Conv3d(64, out_channels, kernel_size=1) + self.aspp = Daf3dASPP( + spatial_dims=3, + in_channels=64, + conv_out_channels=64, + out_channels=64, + kernel_sizes=(3, 3, 3, 3), + dilations=((1, 1, 1), (1, 6, 6), (1, 12, 12), (1, 18, 18)), # type: ignore + norm_type=group_norm, + acti_type=None, + bias=True, + ) + + def forward(self, x): + # layers from 1 - 4 + single_layer_features = list(self.backbone_with_fpn(x).values()) + + # first 4 supervised signals (SLFs 1 - 4) + supervised1 = [self.predict1(slf) for slf in single_layer_features] + + mlf = self.fuse(torch.cat(single_layer_features, 1)) + + attentive_features_maps = [self.attention(slf, mlf) for slf in single_layer_features] + att_features, att_maps = tuple(zip(*attentive_features_maps)) + + # second 4 supervised signals (af 1 - 4) + supervised2 = [self.predict2(af) for af in att_features] + + # attentive maps as optional additional output + supervised3 = [self.predict2(am) for am in att_maps] + + attentive_mlf = self.refine(torch.cat(att_features, 1)) + + 
aspp = self.aspp(attentive_mlf) + + supervised_final = self.predict2(aspp) + + if self.training: + output = supervised1 + supervised2 + [supervised_final] + output = [F.interpolate(o, size=x.size()[2:], mode="trilinear") for o in output] + else: + if self.visual_output: + supervised_final = F.interpolate(supervised_final, size=x.size()[2:], mode="trilinear") + supervised_inner = [ + F.interpolate(o, size=x.size()[2:], mode="trilinear") + for o in supervised1 + supervised2 + supervised3 + ] + output = [supervised_final] + supervised_inner + else: + output = F.interpolate(supervised_final, size=x.size()[2:], mode="trilinear") + return output diff --git a/monai/networks/nets/quicknat.py b/monai/networks/nets/quicknat.py new file mode 100644 index 00000000000..cbcccf24d71 --- /dev/null +++ b/monai/networks/nets/quicknat.py @@ -0,0 +1,439 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import ConvDenseBlock, Convolution +from monai.networks.blocks import squeeze_and_excitation as se +from monai.networks.layers.factories import Act, Norm +from monai.networks.layers.simplelayers import SkipConnection +from monai.networks.layers.utils import get_dropout_layer, get_pool_layer +from monai.utils import optional_import + +# Lazy import to avoid dependency +se1, flag = optional_import("squeeze_and_excitation") + +__all__ = ["Quicknat"] + +# QuickNAT specific Blocks + + +class SkipConnectionWithIdx(SkipConnection): + """ + Combine the forward pass input with the result from the given submodule:: + --+--submodule--o-- + |_____________| + The available modes are ``"cat"``, ``"add"``, ``"mul"``. + Defaults to "cat" and dimension 1. + Inherits from SkipConnection but provides the indizes with each forward pass. + """ + + def forward(self, input, indices): + return super().forward(input), indices + + +class SequentialWithIdx(nn.Sequential): + """ + A sequential container. + Modules will be added to it in the order they are passed in the + constructor. + Own implementation to work with the new indices in the forward pass. + """ + + def __init__(self, *args): + super().__init__(*args) + + def forward(self, input, indices): + for module in self: + input, indices = module(input, indices) + return input, indices + + +class ClassifierBlock(Convolution): + """ + Returns a classifier block without an activation function at the top. + It consists of a 1 * 1 convolutional layer which maps the input to a num_class channel feature map. + The output is a probability map for each of the classes. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of classes to map to. + strides: convolution stride. Defaults to 1. + kernel_size: convolution kernel size. Defaults to 3. 
+ adn_ordering: a string representing the ordering of activation, normalization, and dropout. + Defaults to "NDA". + act: activation type and arguments. Defaults to PReLU. + + """ + + def __init__(self, spatial_dims, in_channels, out_channels, strides, kernel_size, act=None, adn_ordering="A"): + super().__init__(spatial_dims, in_channels, out_channels, strides, kernel_size, adn_ordering, act) + + def forward(self, input: torch.Tensor, weights=None, indices=None): + _, channel, *dims = input.size() + if weights is not None: + weights, _ = torch.max(weights, dim=0) + weights = weights.view(1, channel, 1, 1) + # use weights to adapt how the classes are weighted. + if len(dims) == 2: + out_conv = F.conv2d(input, weights) + else: + raise ValueError("Quicknat is a 2D architecture, please check your dimension.") + else: + out_conv = super().forward(input) + # no indices to return + return out_conv, None + + +# Quicknat specific blocks. All blocks inherit from MONAI blocks but have adaptions to their structure +class ConvConcatDenseBlock(ConvDenseBlock): + """ + This dense block is defined as a sequence of 'Convolution' blocks. It overwrite the '_get_layer' methodto change the ordering of + Every convolutional layer is preceded by a batch-normalization layer and a Rectifier Linear Unit (ReLU) layer. + The first two convolutional layers are followed by a concatenation layer that concatenates + the input feature map with outputs of the current and previous convolutional blocks. + Kernel size of two convolutional layers kept small to limit number of paramters. + Appropriate padding is provided so that the size of feature maps before and after convolution remains constant. + The output channels for each convolution layer is set to 64, which acts as a bottle- neck for feature map selectivity. + The input channel size is variable, depending on the number of dense connections. + The third convolutional layer is also preceded by a batch normalization and ReLU, + but has a 1 * 1 kernel size to compress the feature map size to 64. + Args: + in_channles: variable depending on depth of the network + seLayer: Squeeze and Excite block to be included, defaults to None, valid options are {'NONE', 'CSE', 'SSE', 'CSSE'}, + dropout_layer: Dropout block to be included, defaults to None. + :return: forward passed tensor + """ + + def __init__( + self, + in_channels: int, + se_layer: Optional[nn.Module] = None, + dropout_layer: Optional[nn.Dropout2d] = None, + kernel_size: Sequence[int] | int = 5, + num_filters: int = 64, + ): + self.count = 0 + super().__init__( + in_channels=in_channels, + spatial_dims=2, + # number of channels stay constant throughout the convolution layers + channels=[num_filters, num_filters, num_filters], + norm=("instance", {"num_features": in_channels}), + kernel_size=kernel_size, + ) + self.se_layer = se_layer if se_layer is not None else nn.Identity() + self.dropout_layer = dropout_layer if dropout_layer is not None else nn.Identity() + + def _get_layer(self, in_channels, out_channels, dilation): + """ + After ever convolutional layer the output is concatenated with the input and the layer before. + The concatenated output is used as input to the next convolutional layer. + + Args: + in_channels: number of input channels. + out_channels: number of output channels. + strides: convolution stride. + is_top: True if this is the top block. 
+ """ + kernelsize = self.kernel_size if self.count < 2 else (1, 1) + # padding = None if self.count < 2 else (0, 0) + self.count += 1 + conv = Convolution( + spatial_dims=self.spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + strides=1, + kernel_size=kernelsize, + act=self.act, + norm=("instance", {"num_features": in_channels}), + ) + return nn.Sequential(conv.get_submodule("adn"), conv.get_submodule("conv")) + + def forward(self, input, _): + i = 0 + result = input + for l in self.children(): + # ignoring the max (un-)pool and droupout already added in the initial initialization step + if isinstance(l, (nn.MaxPool2d, nn.MaxUnpool2d, nn.Dropout2d)): + continue + # first convolutional forward + result = l(result) + if i == 0: + result1 = result + # concatenation with the input feature map + result = torch.cat((input, result), dim=1) + + if i == 1: + # concatenation with input feature map and feature map from first convolution + result = torch.cat((result1, result, input), dim=1) + i = i + 1 + + # if SELayer or Dropout layer defined put output through layer before returning, + # else it just goes through nn.Identity and the output does not change + result = self.se_layer(result) + result = self.dropout_layer(result) + + return result, None + + +class Encoder(ConvConcatDenseBlock): + """ + Returns a convolution dense block for the encoding (down) part of a layer of the network. + This Encoder block downpools the data with max_pool. + Its output is used as input to the next layer down. + New feature: it returns the indices of the max_pool to the decoder (up) path + at the same layer to upsample the input. + + Args: + in_channels: number of input channels. + max_pool: predefined max_pool layer to downsample the data. + se_layer: Squeeze and Excite block to be included, defaults to None. + dropout: Dropout block to be included, defaults to None. + kernel_size : kernel size of the convolutional layers. Defaults to 5*5 + num_filters : number of input channels to each convolution block. Defaults to 64 + """ + + def __init__(self, in_channels: int, max_pool, se_layer, dropout, kernel_size, num_filters): + super().__init__(in_channels, se_layer, dropout, kernel_size, num_filters) + self.max_pool = max_pool + + def forward(self, input, indices=None): + input, indices = self.max_pool(input) + + out_block, _ = super().forward(input, None) + # safe the indices for unpool on decoder side + return out_block, indices + + +class Decoder(ConvConcatDenseBlock): + """ + Returns a convolution dense block for the decoding (up) part of a layer of the network. + This will upsample data with an unpool block before the forward. + It uses the indices from corresponding encoder on it's level. + Its output is used as input to the next layer up. + + Args: + in_channels: number of input channels. + un_pool: predefined unpool block. + se_layer: predefined SELayer. Defaults to None. + dropout: predefined dropout block. Defaults to None. + kernel_size: Kernel size of convolution layers. Defaults to 5*5. + num_filters: number of input channels to each convolution layer. Defaults to 64. 
+ """ + + def __init__(self, in_channels: int, un_pool, se_layer, dropout, kernel_size, num_filters): + super().__init__(in_channels, se_layer, dropout, kernel_size, num_filters) + self.un_pool = un_pool + + def forward(self, input, indices): + out_block, _ = super().forward(input, None) + out_block = self.un_pool(out_block, indices) + return out_block, None + + +class Bottleneck(ConvConcatDenseBlock): + """ + Returns the bottom or bottleneck layer at the bottom of a network linking encoder to decoder halves. + It consists of a 5 * 5 convolutional layer and a batch normalization layer to separate + the encoder and decoder part of the network, restricting information flow between the encoder and decoder. + + Args: + in_channels: number of input channels. + se_layer: predefined SELayer. Defaults to None. + dropout: predefined dropout block. Defaults to None. + un_pool: predefined unpool block. + max_pool: predefined maxpool block. + kernel_size: Kernel size of convolution layers. Defaults to 5*5. + num_filters: number of input channels to each convolution layer. Defaults to 64. + """ + + def __init__(self, in_channels: int, se_layer, dropout, max_pool, un_pool, kernel_size, num_filters): + super().__init__(in_channels, se_layer, dropout, kernel_size, num_filters) + self.max_pool = max_pool + self.un_pool = un_pool + + def forward(self, input, indices): + out_block, indices = self.max_pool(input) + out_block, _ = super().forward(out_block, None) + out_block = self.un_pool(out_block, indices) + return out_block, None + + +class Quicknat(nn.Module): + """ + Model for "Quick segmentation of NeuroAnaTomy (QuickNAT) based on a deep fully convolutional neural network. + Refer to: "QuickNAT: A Fully Convolutional Network for Quick and Accurate Segmentation of Neuroanatomy by + Abhijit Guha Roya, Sailesh Conjetib, Nassir Navabb, Christian Wachingera" + + QuickNAT has an encoder/decoder like 2D F-CNN architecture with 4 encoders and 4 decoders separated by a bottleneck layer. + The final layer is a classifier block with softmax. + The architecture includes skip connections between all encoder and decoder blocks of the same spatial resolution, + similar to the U-Net architecture. + All Encoder and Decoder consist of three convolutional layers all with a Batch Normalization and ReLU. + The first two convolutional layers are followed by a concatenation layer that concatenates + the input feature map with outputs of the current and previous convolutional blocks. + The kernel size of the first two convolutional layers is 5*5, the third convolutional layer has a kernel size of 1*1. + + Data in the encode path is downsampled using max pooling layers instead of upsamling like UNet and in the decode path + upsampled using max un-pooling layers instead of transpose convolutions. + The pooling is done at the beginning of the block and the unpool afterwards. + The indices of the max pooling in the Encoder are forwarded through the layer to be available to the corresponding Decoder. + + The bottleneck block consists of a 5 * 5 convolutional layer and a batch normalization layer + to separate the encoder and decoder part of the network, + restricting information flow between the encoder and decoder. + + The output feature map from the last decoder block is passed to the classifier block, + which is a convolutional layer with 1 * 1 kernel size that maps the input to an N channel feature map, + where N is the number of segmentation classes. + + To further explain this consider the first example network given below. 
This network has 3 layers with strides + of 2 for each of the middle layers (the last layer is the bottom connection which does not down/up sample). Input + data to this network is immediately reduced in the spatial dimensions by a factor of 2 by the first convolution of + the residual unit defining the first layer of the encode part. The last layer of the decode part will upsample its + input (data from the previous layer concatenated with data from the skip connection) in the first convolution. this + ensures the final output of the network has the same shape as the input. + + The original QuickNAT implementation included a `enable_test_dropout()` mechanism for uncertainty estimation during + testing. As the dropout layers are the only stochastic components of this network calling the train() method instead + of eval() in testing or inference has the same effect. + + Args: + num_classes: number of classes to segmentate (output channels). + num_channels: number of input channels. + num_filters: number of output channels for each convolutional layer in a Dense Block. + kernel_size: size of the kernel of each convolutional layer in a Dense Block. + kernel_c: convolution kernel size of classifier block kernel. + stride_convolution: convolution stride. Defaults to 1. + pool: kernel size of the pooling layer, + stride_pool: stride for the pooling layer. + se_block: Squeeze and Excite block type to be included, defaults to None. Valid options : NONE, CSE, SSE, CSSE, + droup_out: dropout ratio. Defaults to no dropout. + act: activation type and arguments. Defaults to PReLU. + norm: feature normalization type and arguments. Defaults to instance norm. + adn_ordering: a string representing the ordering of activation (A), normalization (N), and dropout (D). + Defaults to "NA". See also: :py:class:`monai.networks.blocks.ADN`. + + Examples:: + + from monai.networks.nets import QuickNAT + + # network with max pooling by a factor of 2 at each layer with no se_block. + net = QuickNAT( + num_classes=3, + num_channels=1, + num_filters=64, + pool = 2, + se_block = "None" + ) + + """ + + def __init__( + self, + num_classes: int = 33, + num_channels: int = 1, + num_filters: int = 64, + kernel_size: Sequence[int] | int = 5, + kernel_c: int = 1, + stride_conv: int = 1, + pool: int = 2, + stride_pool: int = 2, + # Valid options : NONE, CSE, SSE, CSSE + se_block: str = "None", + drop_out: float = 0, + act: Union[Tuple, str] = Act.PRELU, + norm: Union[Tuple, str] = Norm.INSTANCE, + adn_ordering: str = "NA", + ) -> None: + self.act = act + self.norm = norm + self.adn_ordering = adn_ordering + super().__init__() + se_layer = self.get_selayer(num_filters, se_block) + dropout_layer = get_dropout_layer(name=("dropout", {"p": drop_out}), dropout_dim=2) + max_pool = get_pool_layer( + name=("max", {"kernel_size": pool, "stride": stride_pool, "return_indices": True, "ceil_mode": True}), + spatial_dims=2, + ) + # for the unpooling layer there is currently no Monai implementation available, return to torch implementation + un_pool = nn.MaxUnpool2d(kernel_size=pool, stride=stride_pool) + + # sequence of convolutional strides (like in UNet) not needed as they are always stride_conv. This defaults to 1. + def _create_model(layer: int) -> nn.Module: + """ + Builds the QuickNAT structure from the bottom up by recursing down to the bottelneck layer, then creating sequential + blocks containing the decoder, a skip connection around the previous block, and the encoder. 
+ At the last layer a classifier block is added to the Sequential. + + Args: + layer = inversproportional to the layers left to create + """ + subblock: nn.Module + if layer < 4: + subblock = _create_model(layer + 1) + + else: + subblock = Bottleneck(num_filters, se_layer, dropout_layer, max_pool, un_pool, kernel_size, num_filters) + + if layer == 1: + down = ConvConcatDenseBlock(num_channels, se_layer, dropout_layer, kernel_size, num_filters) + up = ConvConcatDenseBlock(num_filters * 2, se_layer, dropout_layer, kernel_size, num_filters) + classifier = ClassifierBlock(2, num_filters, num_classes, stride_conv, kernel_c) + return SequentialWithIdx(down, SkipConnectionWithIdx(subblock), up, classifier) + else: + up = Decoder(num_filters * 2, un_pool, se_layer, dropout_layer, kernel_size, num_filters) + down = Encoder(num_filters, max_pool, se_layer, dropout_layer, kernel_size, num_filters) + return SequentialWithIdx(down, SkipConnectionWithIdx(subblock), up) + + self.model = _create_model(1) + + def get_selayer(self, n_filters, se_block_type="None"): + """ + Returns the SEBlock defined in the initialization of the QuickNAT model. + + Args: + n_filters: encoding half of the layer + se_block_type: defaults to None. Valid options are None, CSE, SSE, CSSE + Returns: Appropriate SEBlock. SSE and CSSE not implemented in Monai yet. + """ + if se_block_type == "CSE": + return se.ChannelSELayer(2, n_filters) + # not implemented in squeeze_and_excitation in monai use other squeeze_and_excitation import: + elif se_block_type == "SSE" or se_block_type == "CSSE": + # Throw error if squeeze_and_excitation is not installed + if not flag: + raise ImportError("Please install squeeze_and_excitation locally to use SpatialSELayer") + if se_block_type == "SSE": + return se1.SpatialSELayer(n_filters) + else: + return se1.ChannelSpatialSELayer(n_filters) + else: + return None + + @property + def is_cuda(self): + """ + Check if model parameters are allocated on the GPU. + """ + return next(self.parameters()).is_cuda + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input, _ = self.model(input, None) + return input diff --git a/tests/test_daf3d.py b/tests/test_daf3d.py new file mode 100644 index 00000000000..34e25cc6bef --- /dev/null +++ b/tests/test_daf3d.py @@ -0,0 +1,62 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import DAF3D +from monai.utils import optional_import +from tests.utils import test_script_save + +_, has_tv = optional_import("torchvision") + +TEST_CASES = [ + [{"in_channels": 1, "out_channels": 1}, (1, 1, 32, 32, 64), (1, 1, 32, 32, 64)], # single channel 3D, batch 1 + [{"in_channels": 2, "out_channels": 1}, (3, 2, 32, 64, 128), (3, 1, 32, 64, 128)], # two channel 3D, batch 3 + [ + {"in_channels": 2, "out_channels": 2}, + (3, 2, 32, 64, 128), + (3, 2, 32, 64, 128), + ], # two channel 3D, same in & out channels + [{"in_channels": 4, "out_channels": 1}, (5, 4, 35, 35, 35), (5, 1, 35, 35, 35)], # four channel 3D, batch 5 + [ + {"in_channels": 4, "out_channels": 4}, + (5, 4, 35, 35, 35), + (5, 4, 35, 35, 35), + ], # four channel 3D, same in & out channels +] + + +@unittest.skipUnless(has_tv, "torchvision not installed") +class TestDAF3D(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + print(input_param) + net = DAF3D(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + @unittest.skip("daf3d: torchscript not currently supported") + def test_script(self): + net = DAF3D(in_channels=1, out_channels=1) + test_data = torch.randn(16, 1, 32, 32) + test_script_save(net, test_data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_quicknat.py b/tests/test_quicknat.py new file mode 100644 index 00000000000..b4b89b7d624 --- /dev/null +++ b/tests/test_quicknat.py @@ -0,0 +1,57 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import Quicknat +from monai.utils import optional_import +from tests.utils import test_script_save + +_, has_se = optional_import("squeeze_and_excitation") + +TEST_CASES = [ + # params, input_shape, expected_shape + [{"num_classes": 1, "num_channels": 1, "num_filters": 1, "se_block": None}, (1, 1, 32, 32), (1, 1, 32, 32)], + [{"num_classes": 1, "num_channels": 1, "num_filters": 4, "se_block": None}, (1, 1, 64, 64), (1, 1, 64, 64)], + [{"num_classes": 1, "num_channels": 1, "num_filters": 64, "se_block": None}, (1, 1, 128, 128), (1, 1, 128, 128)], + [{"num_classes": 4, "num_channels": 1, "num_filters": 64, "se_block": None}, (1, 1, 32, 32), (1, 4, 32, 32)], + [{"num_classes": 33, "num_channels": 1, "num_filters": 64, "se_block": None}, (1, 1, 32, 32), (1, 33, 32, 32)], + [{"num_classes": 1, "num_channels": 1, "num_filters": 64, "se_block": "CSE"}, (1, 1, 32, 32), (1, 1, 32, 32)], + [{"num_classes": 1, "num_channels": 1, "num_filters": 64, "se_block": "SSE"}, (1, 1, 32, 32), (1, 1, 32, 32)], + [{"num_classes": 1, "num_channels": 1, "num_filters": 64, "se_block": "CSSE"}, (1, 1, 32, 32), (1, 1, 32, 32)], +] + + +@unittest.skipUnless(has_se, "squeeze_and_excitation not installed") +class TestQuicknat(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + print(input_param) + net = Quicknat(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_script(self): + net = Quicknat(num_classes=1, num_channels=1) + test_data = torch.randn(16, 1, 32, 32) + test_script_save(net, test_data) + + +if __name__ == "__main__": + unittest.main()