
Commit

Merge branch 'dev' into torch_ThresholdIntensity
wyli authored Sep 13, 2021
2 parents 7274199 + 132aa37 commit cc84392
Showing 2 changed files with 35 additions and 18 deletions.
39 changes: 22 additions & 17 deletions monai/networks/nets/resnet.py
@@ -14,9 +14,10 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from monai.networks.layers.factories import Conv, Norm, Pool
+from monai.networks.layers.utils import get_pool_layer
+from monai.utils.module import look_up_option
 
 __all__ = ["ResNet", "resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"]

@@ -162,7 +163,9 @@ class ResNet(nn.Module):
         conv1_t_size: size of first convolution layer, determines kernel and padding.
         conv1_t_stride: stride of first convolution layer.
         no_max_pool: bool argument to determine if to use maxpool layer.
-        shortcut_type: which downsample block to use.
+        shortcut_type: which downsample block to use. Options are 'A', 'B', default to 'B'.
+            - 'A': using `self._downsample_basic_block`.
+            - 'B': kernel_size 1 conv + norm.
         widen_factor: widen output for each layer.
         num_classes: number of output (classifications)
     """
@@ -198,7 +201,7 @@ def __init__(
         ]
 
         block_avgpool = get_avgpool()
-        conv1_kernel, conv1_stride, con1_padding = get_conv1(conv1_t_size, conv1_t_stride)
+        conv1_kernel, conv1_stride, conv1_padding = get_conv1(conv1_t_size, conv1_t_stride)
         block_inplanes = [int(x * widen_factor) for x in block_inplanes]
 
         self.in_planes = block_inplanes[0]
@@ -209,7 +212,7 @@ def __init__(
             self.in_planes,
             kernel_size=conv1_kernel[spatial_dims],
             stride=conv1_stride[spatial_dims],
-            padding=con1_padding[spatial_dims],
+            padding=conv1_padding[spatial_dims],
             bias=False,
         )
         self.bn1 = norm_type(self.in_planes)
@@ -234,14 +237,9 @@ def __init__(
                 nn.init.constant_(torch.as_tensor(m.bias), 0)
 
     def _downsample_basic_block(self, x: torch.Tensor, planes: int, stride: int, spatial_dims: int = 3) -> torch.Tensor:
-        assert spatial_dims == 3
-        out: torch.Tensor = F.avg_pool3d(x, kernel_size=1, stride=stride)
-        zero_pads = torch.zeros(out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4))
-        if isinstance(out.data, torch.FloatTensor):
-            zero_pads = zero_pads.cuda()
-
+        out: torch.Tensor = get_pool_layer(("avg", {"kernel_size": 1, "stride": stride}), spatial_dims=spatial_dims)(x)
+        zero_pads = torch.zeros(out.size(0), planes - out.size(1), *out.shape[2:], dtype=out.dtype, device=out.device)
         out = torch.cat([out.data, zero_pads], dim=1)
-
         return out
 
     def _make_layer(
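
The rewrite makes the helper dimension-agnostic: get_pool_layer builds a pooling layer matching spatial_dims, and the zero padding now inherits the input's dtype and device instead of assuming a CUDA float tensor. A minimal sketch of the factory call, assuming MONAI is installed:

import torch
from monai.networks.layers.utils import get_pool_layer

# Resolve an average-pooling layer appropriate for 2D inputs from a (name, args) spec.
pool = get_pool_layer(("avg", {"kernel_size": 1, "stride": 2}), spatial_dims=2)
x = torch.randn(1, 8, 32, 32)
print(pool(x).shape)  # torch.Size([1, 8, 16, 16])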
@@ -259,22 +257,29 @@ def _make_layer(
 
         downsample: Union[nn.Module, partial, None] = None
         if stride != 1 or self.in_planes != planes * block.expansion:
-            if shortcut_type == "A":
+            if look_up_option(shortcut_type, {"A", "B"}) == "A":
                 downsample = partial(
-                    self._downsample_basic_block, planes=planes * block.expansion, kernel_size=1, stride=stride
+                    self._downsample_basic_block,
+                    planes=planes * block.expansion,
+                    stride=stride,
+                    spatial_dims=spatial_dims,
                 )
             else:
                 downsample = nn.Sequential(
                     conv_type(self.in_planes, planes * block.expansion, kernel_size=1, stride=stride),
                     norm_type(planes * block.expansion),
                 )
 
-        layers = []
-        layers.append(
+        layers = [
             block(
-                in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample
+                in_planes=self.in_planes,
+                planes=planes,
+                spatial_dims=spatial_dims,
+                stride=stride,
+                downsample=downsample,
             )
-        )
+        ]
 
         self.in_planes = planes * block.expansion
         for _i in range(1, blocks):
             layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims))
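
Replacing the plain string comparison with look_up_option means an unsupported shortcut_type now fails fast with an informative error instead of silently falling through to the 'B' branch, roughly:

from monai.utils.module import look_up_option

look_up_option("A", {"A", "B"})  # returns "A"
look_up_option("C", {"A", "B"})  # raises ValueError naming the supported options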
14 changes: 13 additions & 1 deletion tests/test_resnet.py
@@ -42,14 +42,26 @@
     (2, 3),
 ]
 
+TEST_CASE_2_A = [  # 2D, batch 2, 1 input channel, shortcut type A
+    {"pretrained": False, "spatial_dims": 2, "n_input_channels": 1, "num_classes": 3, "shortcut_type": "A"},
+    (2, 1, 32, 64),
+    (2, 3),
+]
+
 TEST_CASE_3 = [  # 1D, batch 1, 2 input channels
     {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3},
     (1, 2, 32),
     (1, 3),
 ]
 
+TEST_CASE_3_A = [  # 1D, batch 1, 2 input channels
+    {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3, "shortcut_type": "A"},
+    (1, 2, 32),
+    (1, 3),
+]
+
 TEST_CASES = []
-for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]:
+for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
     for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
         TEST_CASES.append([model, *case])
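
The added cases exercise shortcut type 'A' outside the original 3D-only path. Mirroring TEST_CASE_2_A, a network from this commit can be driven like so (a sketch, assuming MONAI with this change is installed):

import torch
from monai.networks.nets import resnet18

# 2D resnet18 with the parameter-free 'A' shortcut, as in the new test case.
net = resnet18(spatial_dims=2, n_input_channels=1, num_classes=3, shortcut_type="A")
with torch.no_grad():
    out = net(torch.randn(2, 1, 32, 64))
print(out.shape)  # torch.Size([2, 3])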

