From 132aa373d773f7a7a104eeb09da55e612036b190 Mon Sep 17 00:00:00 2001 From: Wenqi Li <wenqil@nvidia.com> Date: Mon, 13 Sep 2021 19:20:09 +0100 Subject: [PATCH] 2715 enhance resnet downsampling block (#2937) * con1_padding -> conv1_padding Signed-off-by: Wenqi Li <wenqil@nvidia.com> * simpler init. Signed-off-by: Wenqi Li <wenqil@nvidia.com> * fixes 2715 Signed-off-by: Wenqi Li <wenqil@nvidia.com> * adds 3d tests Signed-off-by: Wenqi Li <wenqil@nvidia.com> * fixes flake8 error Signed-off-by: Wenqi Li <wenqil@nvidia.com> --- monai/networks/nets/resnet.py | 39 ++++++++++++++++++++--------------- tests/test_resnet.py | 14 ++++++++++++- 2 files changed, 35 insertions(+), 18 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index a5e6b7ab811..3b86dc3d62f 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -14,9 +14,10 @@ import torch import torch.nn as nn -import torch.nn.functional as F from monai.networks.layers.factories import Conv, Norm, Pool +from monai.networks.layers.utils import get_pool_layer +from monai.utils.module import look_up_option __all__ = ["ResNet", "resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"] @@ -162,7 +163,9 @@ class ResNet(nn.Module): conv1_t_size: size of first convolution layer, determines kernel and padding. conv1_t_stride: stride of first convolution layer. no_max_pool: bool argument to determine if to use maxpool layer. - shortcut_type: which downsample block to use. + shortcut_type: which downsample block to use. Options are 'A', 'B', default to 'B'. + - 'A': using `self._downsample_basic_block`. + - 'B': kernel_size 1 conv + norm. widen_factor: widen output for each layer. 
num_classes: number of output (classifications) """ @@ -198,7 +201,7 @@ def __init__( ] block_avgpool = get_avgpool() - conv1_kernel, conv1_stride, con1_padding = get_conv1(conv1_t_size, conv1_t_stride) + conv1_kernel, conv1_stride, conv1_padding = get_conv1(conv1_t_size, conv1_t_stride) block_inplanes = [int(x * widen_factor) for x in block_inplanes] self.in_planes = block_inplanes[0] @@ -209,7 +212,7 @@ def __init__( self.in_planes, kernel_size=conv1_kernel[spatial_dims], stride=conv1_stride[spatial_dims], - padding=con1_padding[spatial_dims], + padding=conv1_padding[spatial_dims], bias=False, ) self.bn1 = norm_type(self.in_planes) @@ -234,14 +237,9 @@ def __init__( nn.init.constant_(torch.as_tensor(m.bias), 0) def _downsample_basic_block(self, x: torch.Tensor, planes: int, stride: int, spatial_dims: int = 3) -> torch.Tensor: - assert spatial_dims == 3 - out: torch.Tensor = F.avg_pool3d(x, kernel_size=1, stride=stride) - zero_pads = torch.zeros(out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4)) - if isinstance(out.data, torch.FloatTensor): - zero_pads = zero_pads.cuda() - + out: torch.Tensor = get_pool_layer(("avg", {"kernel_size": 1, "stride": stride}), spatial_dims=spatial_dims)(x) + zero_pads = torch.zeros(out.size(0), planes - out.size(1), *out.shape[2:], dtype=out.dtype, device=out.device) out = torch.cat([out.data, zero_pads], dim=1) - return out def _make_layer( @@ -259,9 +257,12 @@ def _make_layer( downsample: Union[nn.Module, partial, None] = None if stride != 1 or self.in_planes != planes * block.expansion: - if shortcut_type == "A": + if look_up_option(shortcut_type, {"A", "B"}) == "A": downsample = partial( - self._downsample_basic_block, planes=planes * block.expansion, kernel_size=1, stride=stride + self._downsample_basic_block, + planes=planes * block.expansion, + stride=stride, + spatial_dims=spatial_dims, ) else: downsample = nn.Sequential( @@ -269,12 +270,16 @@ def _make_layer( norm_type(planes * block.expansion), ) - 
layers = [] - layers.append( + layers = [ block( - in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample + in_planes=self.in_planes, + planes=planes, + spatial_dims=spatial_dims, + stride=stride, + downsample=downsample, ) - ) + ] + self.in_planes = planes * block.expansion for _i in range(1, blocks): layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims)) diff --git a/tests/test_resnet.py b/tests/test_resnet.py index c4ba5c2e167..16cd6f4865e 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -42,14 +42,26 @@ (2, 3), ] +TEST_CASE_2_A = [ # 2D, batch 2, 1 input channel, shortcut type A + {"pretrained": False, "spatial_dims": 2, "n_input_channels": 1, "num_classes": 3, "shortcut_type": "A"}, + (2, 1, 32, 64), + (2, 3), +] + TEST_CASE_3 = [ # 1D, batch 1, 2 input channels {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3}, (1, 2, 32), (1, 3), ] +TEST_CASE_3_A = [ # 1D, batch 1, 2 input channels, shortcut type A + {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3, "shortcut_type": "A"}, + (1, 2, 32), + (1, 3), +] + TEST_CASES = [] -for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]: +for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]: for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]: TEST_CASES.append([model, *case])