Skip to content

Commit

Permalink
Add norm param to ResNet (#7752)
Browse files Browse the repository at this point in the history
Fixes #7294  .

### Description
Adds a `norm` param to ResNet

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Peter Kaplinsky <peterkaplinsky@gmail.com>
Co-authored-by: Peter Kaplinsky <peterkaplinsky@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
  • Loading branch information
3 people authored May 23, 2024
1 parent 66a2fae commit 373c003
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 25 deletions.
24 changes: 16 additions & 8 deletions monai/networks/nets/daf3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from collections import OrderedDict
from collections.abc import Callable, Sequence
from functools import partial

import torch
import torch.nn as nn
Expand All @@ -25,6 +26,7 @@
from monai.networks.blocks.convolutions import Convolution
from monai.networks.blocks.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork
from monai.networks.layers.factories import Conv, Norm
from monai.networks.layers.utils import get_norm_layer
from monai.networks.nets.resnet import ResNet, ResNetBottleneck

__all__ = [
Expand Down Expand Up @@ -170,27 +172,31 @@ class Daf3dResNetBottleneck(ResNetBottleneck):
spatial_dims: number of spatial dimensions of the input image.
stride: stride to use for second conv layer.
downsample: which downsample layer to use.
norm: which normalization layer to use. Defaults to group.
"""

expansion = 2

def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None):
norm_type: Callable = Norm[Norm.GROUP, spatial_dims]
def __init__(
self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=("group", {"num_groups": 32})
):
conv_type: Callable = Conv[Conv.CONV, spatial_dims]

norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)

# in case downsample uses batch norm, change to group norm
if isinstance(downsample, nn.Sequential):
downsample = nn.Sequential(
conv_type(in_planes, planes * self.expansion, kernel_size=1, stride=stride, bias=False),
norm_type(num_groups=32, num_channels=planes * self.expansion),
norm_layer(channels=planes * self.expansion),
)

super().__init__(in_planes, planes, spatial_dims, stride, downsample)

# change norm from batch to group norm
self.bn1 = norm_type(num_groups=32, num_channels=planes)
self.bn2 = norm_type(num_groups=32, num_channels=planes)
self.bn3 = norm_type(num_groups=32, num_channels=planes * self.expansion)
self.bn1 = norm_layer(channels=planes)
self.bn2 = norm_layer(channels=planes)
self.bn3 = norm_layer(channels=planes * self.expansion)

# adapt second convolution to work with groups
self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, stride=stride, groups=32, bias=False)
Expand All @@ -212,8 +218,10 @@ class Daf3dResNetDilatedBottleneck(Daf3dResNetBottleneck):
downsample: which downsample layer to use.
"""

def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None):
super().__init__(in_planes, planes, spatial_dims, stride, downsample)
def __init__(
self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=("group", {"num_groups": 32})
):
super().__init__(in_planes, planes, spatial_dims, stride, downsample, norm)

# add dilation in second convolution
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
Expand Down
44 changes: 28 additions & 16 deletions monai/networks/nets/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import torch.nn as nn

from monai.networks.blocks.encoder import BaseEncoder
from monai.networks.layers.factories import Conv, Norm, Pool
from monai.networks.layers.utils import get_act_layer, get_pool_layer
from monai.networks.layers.factories import Conv, Pool
from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_pool_layer
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option, optional_import

Expand Down Expand Up @@ -79,6 +79,7 @@ def __init__(
stride: int = 1,
downsample: nn.Module | partial | None = None,
act: str | tuple = ("relu", {"inplace": True}),
norm: str | tuple = "batch",
) -> None:
"""
Args:
Expand All @@ -88,17 +89,18 @@ def __init__(
stride: stride to use for first conv layer.
downsample: which downsample layer to use.
act: activation type and arguments. Defaults to relu.
norm: feature normalization type and arguments. Defaults to batch norm.
"""
super().__init__()

conv_type: Callable = Conv[Conv.CONV, spatial_dims]
norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes)

self.conv1 = conv_type(in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
self.bn1 = norm_type(planes)
self.bn1 = norm_layer
self.act = get_act_layer(name=act)
self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False)
self.bn2 = norm_type(planes)
self.bn2 = norm_layer
self.downsample = downsample
self.stride = stride

Expand Down Expand Up @@ -132,6 +134,7 @@ def __init__(
stride: int = 1,
downsample: nn.Module | partial | None = None,
act: str | tuple = ("relu", {"inplace": True}),
norm: str | tuple = "batch",
) -> None:
"""
Args:
Expand All @@ -141,19 +144,20 @@ def __init__(
stride: stride to use for second conv layer.
downsample: which downsample layer to use.
act: activation type and arguments. Defaults to relu.
norm: feature normalization type and arguments. Defaults to batch norm.
"""

super().__init__()

conv_type: Callable = Conv[Conv.CONV, spatial_dims]
norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)

self.conv1 = conv_type(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = norm_type(planes)
self.bn1 = norm_layer(channels=planes)
self.conv2 = conv_type(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = norm_type(planes)
self.bn2 = norm_layer(channels=planes)
self.conv3 = conv_type(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = norm_type(planes * self.expansion)
self.bn3 = norm_layer(channels=planes * self.expansion)
self.act = get_act_layer(name=act)
self.downsample = downsample
self.stride = stride
Expand Down Expand Up @@ -207,6 +211,7 @@ class ResNet(nn.Module):
feed_forward: whether to add the FC layer for the output, default to `True`.
bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.
act: activation type and arguments. Defaults to relu.
norm: feature normalization type and arguments. Defaults to batch norm.
"""

Expand All @@ -226,6 +231,7 @@ def __init__(
feed_forward: bool = True,
bias_downsample: bool = True, # for backwards compatibility (also see PR #5477)
act: str | tuple = ("relu", {"inplace": True}),
norm: str | tuple = "batch",
) -> None:
super().__init__()

Expand All @@ -238,7 +244,6 @@ def __init__(
raise ValueError("Unknown block '%s', use basic or bottleneck" % block)

conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]
norm_type: type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]
pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]
avgp_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[
Pool.ADAPTIVEAVG, spatial_dims
Expand All @@ -262,7 +267,9 @@ def __init__(
padding=tuple(k // 2 for k in conv1_kernel_size),
bias=False,
)
self.bn1 = norm_type(self.in_planes)

norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=self.in_planes)
self.bn1 = norm_layer
self.act = get_act_layer(name=act)
self.maxpool = pool_type(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], spatial_dims, shortcut_type)
Expand All @@ -275,7 +282,7 @@ def __init__(
for m in self.modules():
if isinstance(m, conv_type):
nn.init.kaiming_normal_(torch.as_tensor(m.weight), mode="fan_out", nonlinearity="relu")
elif isinstance(m, norm_type):
elif isinstance(m, type(norm_layer)):
nn.init.constant_(torch.as_tensor(m.weight), 1)
nn.init.constant_(torch.as_tensor(m.bias), 0)
elif isinstance(m, nn.Linear):
Expand All @@ -295,9 +302,9 @@ def _make_layer(
spatial_dims: int,
shortcut_type: str,
stride: int = 1,
norm: str | tuple = "batch",
) -> nn.Sequential:
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
norm_type: Callable = Norm[Norm.BATCH, spatial_dims]

downsample: nn.Module | partial | None = None
if stride != 1 or self.in_planes != planes * block.expansion:
Expand All @@ -317,18 +324,23 @@ def _make_layer(
stride=stride,
bias=self.bias_downsample,
),
norm_type(planes * block.expansion),
get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes * block.expansion),
)

layers = [
block(
in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample
in_planes=self.in_planes,
planes=planes,
spatial_dims=spatial_dims,
stride=stride,
downsample=downsample,
norm=norm,
)
]

self.in_planes = planes * block.expansion
for _i in range(1, blocks):
layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims))
layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims, norm=norm))

return nn.Sequential(*layers)

Expand Down
19 changes: 18 additions & 1 deletion tests/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,30 @@
(1, 3),
]

TEST_CASE_9 = [ # Layer norm
{
"block": ResNetBlock,
"layers": [3, 4, 6, 3],
"block_inplanes": [64, 128, 256, 512],
"spatial_dims": 1,
"n_input_channels": 2,
"num_classes": 3,
"conv1_t_size": [3],
"conv1_t_stride": 1,
"act": ("relu", {"inplace": False}),
"norm": ("layer", {"normalized_shape": (64, 32)}),
},
(1, 2, 32),
(1, 3),
]

TEST_CASES = []
PRETRAINED_TEST_CASES = []
for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
TEST_CASES.append([model, *case])
PRETRAINED_TEST_CASES.append([model, *case])
for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]:
for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]:
TEST_CASES.append([ResNet, *case])

TEST_SCRIPT_CASES = [
Expand Down

0 comments on commit 373c003

Please sign in to comment.