From 132aa373d773f7a7a104eeb09da55e612036b190 Mon Sep 17 00:00:00 2001
From: Wenqi Li <wenqil@nvidia.com>
Date: Mon, 13 Sep 2021 19:20:09 +0100
Subject: [PATCH] 2715 enhance resnet downsampling block (#2937)

* con1_padding -> conv1_padding

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* simpler init.

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* fixes 2715

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* adds 3d tests

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* fixes flake8 error

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
---
 monai/networks/nets/resnet.py | 39 ++++++++++++++++++++---------------
 tests/test_resnet.py          | 14 ++++++++++++-
 2 files changed, 35 insertions(+), 18 deletions(-)

diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py
index a5e6b7ab811..3b86dc3d62f 100644
--- a/monai/networks/nets/resnet.py
+++ b/monai/networks/nets/resnet.py
@@ -14,9 +14,10 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from monai.networks.layers.factories import Conv, Norm, Pool
+from monai.networks.layers.utils import get_pool_layer
+from monai.utils.module import look_up_option
 
 __all__ = ["ResNet", "resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"]
 
@@ -162,7 +163,9 @@ class ResNet(nn.Module):
         conv1_t_size: size of first convolution layer, determines kernel and padding.
         conv1_t_stride: stride of first convolution layer.
         no_max_pool: bool argument to determine if to use maxpool layer.
-        shortcut_type: which downsample block to use.
+        shortcut_type: which downsample block to use. Options are 'A', 'B', defaults to 'B'.
+            - 'A': using `self._downsample_basic_block`.
+            - 'B': kernel_size 1 conv + norm.
         widen_factor: widen output for each layer.
         num_classes: number of output (classifications)
     """
@@ -198,7 +201,7 @@ def __init__(
         ]
 
         block_avgpool = get_avgpool()
-        conv1_kernel, conv1_stride, con1_padding = get_conv1(conv1_t_size, conv1_t_stride)
+        conv1_kernel, conv1_stride, conv1_padding = get_conv1(conv1_t_size, conv1_t_stride)
         block_inplanes = [int(x * widen_factor) for x in block_inplanes]
 
         self.in_planes = block_inplanes[0]
@@ -209,7 +212,7 @@ def __init__(
             self.in_planes,
             kernel_size=conv1_kernel[spatial_dims],
             stride=conv1_stride[spatial_dims],
-            padding=con1_padding[spatial_dims],
+            padding=conv1_padding[spatial_dims],
             bias=False,
         )
         self.bn1 = norm_type(self.in_planes)
@@ -234,14 +237,9 @@ def __init__(
                 nn.init.constant_(torch.as_tensor(m.bias), 0)
 
     def _downsample_basic_block(self, x: torch.Tensor, planes: int, stride: int, spatial_dims: int = 3) -> torch.Tensor:
-        assert spatial_dims == 3
-        out: torch.Tensor = F.avg_pool3d(x, kernel_size=1, stride=stride)
-        zero_pads = torch.zeros(out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4))
-        if isinstance(out.data, torch.FloatTensor):
-            zero_pads = zero_pads.cuda()
-
+        out: torch.Tensor = get_pool_layer(("avg", {"kernel_size": 1, "stride": stride}), spatial_dims=spatial_dims)(x)
+        zero_pads = torch.zeros(out.size(0), planes - out.size(1), *out.shape[2:], dtype=out.dtype, device=out.device)
         out = torch.cat([out.data, zero_pads], dim=1)
-
         return out
 
     def _make_layer(
@@ -259,9 +257,12 @@ def _make_layer(
 
         downsample: Union[nn.Module, partial, None] = None
         if stride != 1 or self.in_planes != planes * block.expansion:
-            if shortcut_type == "A":
+            if look_up_option(shortcut_type, {"A", "B"}) == "A":
                 downsample = partial(
-                    self._downsample_basic_block, planes=planes * block.expansion, kernel_size=1, stride=stride
+                    self._downsample_basic_block,
+                    planes=planes * block.expansion,
+                    stride=stride,
+                    spatial_dims=spatial_dims,
                 )
             else:
                 downsample = nn.Sequential(
@@ -269,12 +270,16 @@ def _make_layer(
                     norm_type(planes * block.expansion),
                 )
 
-        layers = []
-        layers.append(
+        layers = [
             block(
-                in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample
+                in_planes=self.in_planes,
+                planes=planes,
+                spatial_dims=spatial_dims,
+                stride=stride,
+                downsample=downsample,
             )
-        )
+        ]
+
         self.in_planes = planes * block.expansion
         for _i in range(1, blocks):
             layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims))
diff --git a/tests/test_resnet.py b/tests/test_resnet.py
index c4ba5c2e167..16cd6f4865e 100644
--- a/tests/test_resnet.py
+++ b/tests/test_resnet.py
@@ -42,14 +42,26 @@
     (2, 3),
 ]
 
+TEST_CASE_2_A = [  # 2D, batch 2, 1 input channel, shortcut type A
+    {"pretrained": False, "spatial_dims": 2, "n_input_channels": 1, "num_classes": 3, "shortcut_type": "A"},
+    (2, 1, 32, 64),
+    (2, 3),
+]
+
 TEST_CASE_3 = [  # 1D, batch 1, 2 input channels
     {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3},
     (1, 2, 32),
     (1, 3),
 ]
 
+TEST_CASE_3_A = [  # 1D, batch 1, 2 input channels, shortcut type A
+    {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3, "shortcut_type": "A"},
+    (1, 2, 32),
+    (1, 3),
+]
+
 TEST_CASES = []
-for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]:
+for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
     for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
         TEST_CASES.append([model, *case])