diff --git a/monai/networks/nets/daf3d.py b/monai/networks/nets/daf3d.py
index 31c2bf4c31..02e5bb022a 100644
--- a/monai/networks/nets/daf3d.py
+++ b/monai/networks/nets/daf3d.py
@@ -13,6 +13,7 @@
 
 from collections import OrderedDict
 from collections.abc import Callable, Sequence
+from functools import partial
 
 import torch
 import torch.nn as nn
@@ -25,6 +26,7 @@
 from monai.networks.blocks.convolutions import Convolution
 from monai.networks.blocks.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork
 from monai.networks.layers.factories import Conv, Norm
+from monai.networks.layers.utils import get_norm_layer
 from monai.networks.nets.resnet import ResNet, ResNetBottleneck
 
 __all__ = [
@@ -170,27 +172,31 @@ class Daf3dResNetBottleneck(ResNetBottleneck):
         spatial_dims: number of spatial dimensions of the input image.
         stride: stride to use for second conv layer.
         downsample: which downsample layer to use.
+        norm: which normalization layer to use. Defaults to group.
     """
 
     expansion = 2
 
-    def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None):
-        norm_type: Callable = Norm[Norm.GROUP, spatial_dims]
+    def __init__(
+        self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=("group", {"num_groups": 32})
+    ):
         conv_type: Callable = Conv[Conv.CONV, spatial_dims]
 
+        norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)
+
         # in case downsample uses batch norm, change to group norm
         if isinstance(downsample, nn.Sequential):
             downsample = nn.Sequential(
                 conv_type(in_planes, planes * self.expansion, kernel_size=1, stride=stride, bias=False),
-                norm_type(num_groups=32, num_channels=planes * self.expansion),
+                norm_layer(channels=planes * self.expansion),
             )
 
         super().__init__(in_planes, planes, spatial_dims, stride, downsample)
 
         # change norm from batch to group norm
-        self.bn1 = norm_type(num_groups=32, num_channels=planes)
-        self.bn2 = norm_type(num_groups=32, num_channels=planes)
-        self.bn3 = norm_type(num_groups=32, num_channels=planes * self.expansion)
+        self.bn1 = norm_layer(channels=planes)
+        self.bn2 = norm_layer(channels=planes)
+        self.bn3 = norm_layer(channels=planes * self.expansion)
 
         # adapt second convolution to work with groups
         self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, stride=stride, groups=32, bias=False)
@@ -212,8 +218,11 @@ class Daf3dResNetDilatedBottleneck(Daf3dResNetBottleneck):
         downsample: which downsample layer to use.
+        norm: which normalization layer to use. Defaults to group.
     """
 
-    def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None):
-        super().__init__(in_planes, planes, spatial_dims, stride, downsample)
+    def __init__(
+        self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=("group", {"num_groups": 32})
+    ):
+        super().__init__(in_planes, planes, spatial_dims, stride, downsample, norm)
 
         # add dilation in second convolution
         conv_type: Callable = Conv[Conv.CONV, spatial_dims]
diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py
index 93a77d7b2a..2cd7c8102a 100644
--- a/monai/networks/nets/resnet.py
+++ b/monai/networks/nets/resnet.py
@@ -22,8 +22,8 @@
 import torch.nn as nn
 
 from monai.networks.blocks.encoder import BaseEncoder
-from monai.networks.layers.factories import Conv, Norm, Pool
-from monai.networks.layers.utils import get_act_layer, get_pool_layer
+from monai.networks.layers.factories import Conv, Pool
+from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_pool_layer
 from monai.utils import ensure_tuple_rep
 from monai.utils.module import look_up_option, optional_import
 
@@ -79,6 +79,7 @@ def __init__(
         stride: int = 1,
         downsample: nn.Module | partial | None = None,
         act: str | tuple = ("relu", {"inplace": True}),
+        norm: str | tuple = "batch",
     ) -> None:
         """
         Args:
@@ -88,17 +89,18 @@ def __init__(
             stride: stride to use for first conv layer.
             downsample: which downsample layer to use.
             act: activation type and arguments. Defaults to relu.
+            norm: feature normalization type and arguments. Defaults to batch norm.
         """
         super().__init__()
 
         conv_type: Callable = Conv[Conv.CONV, spatial_dims]
-        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
+        norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)
 
         self.conv1 = conv_type(in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
-        self.bn1 = norm_type(planes)
+        self.bn1 = norm_layer(channels=planes)
         self.act = get_act_layer(name=act)
         self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False)
-        self.bn2 = norm_type(planes)
+        self.bn2 = norm_layer(channels=planes)
         self.downsample = downsample
         self.stride = stride
 
@@ -132,6 +134,7 @@ def __init__(
         stride: int = 1,
         downsample: nn.Module | partial | None = None,
         act: str | tuple = ("relu", {"inplace": True}),
+        norm: str | tuple = "batch",
     ) -> None:
         """
         Args:
@@ -141,19 +144,20 @@ def __init__(
             stride: stride to use for second conv layer.
             downsample: which downsample layer to use.
             act: activation type and arguments. Defaults to relu.
+            norm: feature normalization type and arguments. Defaults to batch norm.
         """
 
         super().__init__()
 
         conv_type: Callable = Conv[Conv.CONV, spatial_dims]
-        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
+        norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)
 
         self.conv1 = conv_type(in_planes, planes, kernel_size=1, bias=False)
-        self.bn1 = norm_type(planes)
+        self.bn1 = norm_layer(channels=planes)
         self.conv2 = conv_type(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
-        self.bn2 = norm_type(planes)
+        self.bn2 = norm_layer(channels=planes)
         self.conv3 = conv_type(planes, planes * self.expansion, kernel_size=1, bias=False)
-        self.bn3 = norm_type(planes * self.expansion)
+        self.bn3 = norm_layer(channels=planes * self.expansion)
         self.act = get_act_layer(name=act)
         self.downsample = downsample
         self.stride = stride
@@ -207,6 +211,7 @@ class ResNet(nn.Module):
         feed_forward: whether to add the FC layer for the output, default to `True`.
         bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.
         act: activation type and arguments. Defaults to relu.
+        norm: feature normalization type and arguments. Defaults to batch norm.
 
     """
 
@@ -226,6 +231,7 @@ def __init__(
         feed_forward: bool = True,
         bias_downsample: bool = True,  # for backwards compatibility (also see PR #5477)
         act: str | tuple = ("relu", {"inplace": True}),
+        norm: str | tuple = "batch",
     ) -> None:
         super().__init__()
 
@@ -238,7 +244,6 @@ def __init__(
                 raise ValueError("Unknown block '%s', use basic or bottleneck" % block)
 
         conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]
-        norm_type: type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]
         pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]
         avgp_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[
             Pool.ADAPTIVEAVG, spatial_dims
@@ -262,7 +267,9 @@ def __init__(
             padding=tuple(k // 2 for k in conv1_kernel_size),
             bias=False,
         )
-        self.bn1 = norm_type(self.in_planes)
+
+        norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=self.in_planes)
+        self.bn1 = norm_layer
         self.act = get_act_layer(name=act)
         self.maxpool = pool_type(kernel_size=3, stride=2, padding=1)
         self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], spatial_dims, shortcut_type)
@@ -275,7 +282,7 @@ def __init__(
         for m in self.modules():
             if isinstance(m, conv_type):
                 nn.init.kaiming_normal_(torch.as_tensor(m.weight), mode="fan_out", nonlinearity="relu")
-            elif isinstance(m, norm_type):
+            elif isinstance(m, type(norm_layer)):
                 nn.init.constant_(torch.as_tensor(m.weight), 1)
                 nn.init.constant_(torch.as_tensor(m.bias), 0)
             elif isinstance(m, nn.Linear):
@@ -295,9 +302,9 @@ def _make_layer(
         spatial_dims: int,
         shortcut_type: str,
         stride: int = 1,
+        norm: str | tuple = "batch",
     ) -> nn.Sequential:
         conv_type: Callable = Conv[Conv.CONV, spatial_dims]
-        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
 
         downsample: nn.Module | partial | None = None
         if stride != 1 or self.in_planes != planes * block.expansion:
@@ -317,18 +324,23 @@ def _make_layer(
                         stride=stride,
                         bias=self.bias_downsample,
                     ),
-                    norm_type(planes * block.expansion),
+                    get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes * block.expansion),
                 )
 
         layers = [
             block(
-                in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample
+                in_planes=self.in_planes,
+                planes=planes,
+                spatial_dims=spatial_dims,
+                stride=stride,
+                downsample=downsample,
+                norm=norm,
             )
         ]
 
         self.in_planes = planes * block.expansion
         for _i in range(1, blocks):
-            layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims))
+            layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims, norm=norm))
 
         return nn.Sequential(*layers)
 
diff --git a/tests/test_resnet.py b/tests/test_resnet.py
index 3a58d1c955..e873f1238a 100644
--- a/tests/test_resnet.py
+++ b/tests/test_resnet.py
@@ -202,13 +202,30 @@
     (1, 3),
 ]
 
+TEST_CASE_9 = [  # Layer norm
+    {
+        "block": ResNetBlock,
+        "layers": [3, 4, 6, 3],
+        "block_inplanes": [64, 128, 256, 512],
+        "spatial_dims": 1,
+        "n_input_channels": 2,
+        "num_classes": 3,
+        "conv1_t_size": [3],
+        "conv1_t_stride": 1,
+        "act": ("relu", {"inplace": False}),
+        "norm": ("layer", {"normalized_shape": (64, 32)}),
+    },
+    (1, 2, 32),
+    (1, 3),
+]
+
 TEST_CASES = []
 PRETRAINED_TEST_CASES = []
 for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
     for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
         TEST_CASES.append([model, *case])
         PRETRAINED_TEST_CASES.append([model, *case])
-for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]:
+for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]:
     TEST_CASES.append([ResNet, *case])
 
 TEST_SCRIPT_CASES = [