diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 93a77d7b2a..cd07485a09 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -22,8 +22,8 @@ import torch.nn as nn from monai.networks.blocks.encoder import BaseEncoder -from monai.networks.layers.factories import Conv, Norm, Pool -from monai.networks.layers.utils import get_act_layer, get_pool_layer +from monai.networks.layers.factories import Conv, Pool +from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_pool_layer from monai.utils import ensure_tuple_rep from monai.utils.module import look_up_option, optional_import @@ -79,6 +79,7 @@ def __init__( stride: int = 1, downsample: nn.Module | partial | None = None, act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", ) -> None: """ Args: @@ -92,13 +93,13 @@ def __init__( super().__init__() conv_type: Callable = Conv[Conv.CONV, spatial_dims] - norm_type: Callable = Norm[Norm.BATCH, spatial_dims] + norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes) self.conv1 = conv_type(in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False) - self.bn1 = norm_type(planes) + self.bn1 = norm_layer self.act = get_act_layer(name=act) self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False) - self.bn2 = norm_type(planes) + self.bn2 = norm_layer self.downsample = downsample self.stride = stride @@ -132,6 +133,7 @@ def __init__( stride: int = 1, downsample: nn.Module | partial | None = None, act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", ) -> None: """ Args: @@ -146,14 +148,14 @@ def __init__( super().__init__() conv_type: Callable = Conv[Conv.CONV, spatial_dims] - norm_type: Callable = Norm[Norm.BATCH, spatial_dims] + norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims) self.conv1 = conv_type(in_planes, planes, kernel_size=1, bias=False) - self.bn1 = norm_type(planes) + self.bn1 = norm_layer(channels=planes) self.conv2 = conv_type(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn2 = norm_type(planes) + self.bn2 = norm_layer(channels=planes) self.conv3 = conv_type(planes, planes * self.expansion, kernel_size=1, bias=False) - self.bn3 = norm_type(planes * self.expansion) + self.bn3 = norm_layer(channels=planes * self.expansion) self.act = get_act_layer(name=act) self.downsample = downsample self.stride = stride @@ -226,6 +228,7 @@ def __init__( feed_forward: bool = True, bias_downsample: bool = True, # for backwards compatibility (also see PR #5477) act: str | tuple = ("relu", {"inplace": True}), + norm: str | tuple = "batch", ) -> None: super().__init__() @@ -238,7 +241,6 @@ def __init__( raise ValueError("Unknown block '%s', use basic or bottleneck" % block) conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims] - norm_type: type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims] pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims] avgp_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[ Pool.ADAPTIVEAVG, spatial_dims @@ -262,7 +264,9 @@ def __init__( padding=tuple(k // 2 for k in conv1_kernel_size), bias=False, ) - self.bn1 = norm_type(self.in_planes) + + norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=self.in_planes) + self.bn1 = norm_layer self.act = get_act_layer(name=act) self.maxpool = pool_type(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], spatial_dims, shortcut_type) @@ -275,7 +279,7 @@ def __init__( for m in self.modules(): if isinstance(m, conv_type): nn.init.kaiming_normal_(torch.as_tensor(m.weight), mode="fan_out", nonlinearity="relu") - elif isinstance(m, norm_type): + elif isinstance(m, type(norm_layer)): nn.init.constant_(torch.as_tensor(m.weight), 1) nn.init.constant_(torch.as_tensor(m.bias), 0) elif isinstance(m, nn.Linear): @@ -295,9 +299,9 @@ def _make_layer( spatial_dims: int, shortcut_type: str, stride: int = 1, + norm: str | tuple = "batch", ) -> nn.Sequential: conv_type: Callable = Conv[Conv.CONV, spatial_dims] - norm_type: Callable = Norm[Norm.BATCH, spatial_dims] downsample: nn.Module | partial | None = None if stride != 1 or self.in_planes != planes * block.expansion: @@ -317,18 +321,23 @@ def _make_layer( stride=stride, bias=self.bias_downsample, ), - norm_type(planes * block.expansion), + get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes * block.expansion), ) layers = [ block( - in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample + in_planes=self.in_planes, + planes=planes, + spatial_dims=spatial_dims, + stride=stride, + downsample=downsample, + norm=norm, ) ] self.in_planes = planes * block.expansion for _i in range(1, blocks): - layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims)) + layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims, norm=norm)) return nn.Sequential(*layers) diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 3a58d1c955..d20ad89de7 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -202,9 +202,26 @@ (1, 3), ] +TEST_CASE_9 = [ + { + "block": "bottleneck", + "layers": [3, 4, 6, 3], + "block_inplanes": [64, 128, 256, 512], + "spatial_dims": 1, + "n_input_channels": 2, + "num_classes": 3, + "conv1_t_size": [3], + "conv1_t_stride": 1, + "act": ("relu", {"inplace": False}), + "norm": ("layer", {"normalized_shape": (64, 32)}), + }, + (1, 2, 32), + (1, 3), +] + TEST_CASES = [] PRETRAINED_TEST_CASES = [] -for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]: +for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A, TEST_CASE_3]: for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]: TEST_CASES.append([model, *case]) PRETRAINED_TEST_CASES.append([model, *case])