Add norm param to ResNet #7752

Merged · 8 commits · May 23, 2024
Changes from 7 commits
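Summary: this PR threads a user-selectable `norm` argument through MONAI's `ResNet` (and the `Daf3dResNet*` blocks that build on it), replacing the hard-coded `Norm[Norm.BATCH]` / `Norm[Norm.GROUP]` factories with `get_norm_layer`. A minimal usage sketch (not part of the PR itself), assuming only the `ResNet` signature and factory names visible in the diffs below:

```python
import torch

from monai.networks.nets.resnet import ResNet, ResNetBlock

# Previously the norm type was fixed to batch norm; with this change it can be
# overridden, e.g. with group norm (num_groups must divide every channel count).
net = ResNet(
    block=ResNetBlock,
    layers=[2, 2, 2, 2],
    block_inplanes=[64, 128, 256, 512],
    spatial_dims=3,
    n_input_channels=1,
    num_classes=2,
    norm=("group", {"num_groups": 8}),
)
out = net(torch.rand(1, 1, 32, 32, 32))  # -> shape (1, 2)
```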
24 changes: 16 additions & 8 deletions monai/networks/nets/daf3d.py
@@ -13,6 +13,7 @@

from collections import OrderedDict
from collections.abc import Callable, Sequence
from functools import partial

import torch
import torch.nn as nn
@@ -25,6 +26,7 @@
from monai.networks.blocks.convolutions import Convolution
from monai.networks.blocks.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork
from monai.networks.layers.factories import Conv, Norm
from monai.networks.layers.utils import get_norm_layer
from monai.networks.nets.resnet import ResNet, ResNetBottleneck

__all__ = [
@@ -170,27 +172,31 @@ class Daf3dResNetBottleneck(ResNetBottleneck):
spatial_dims: number of spatial dimensions of the input image.
stride: stride to use for second conv layer.
downsample: which downsample layer to use.
norm: which normalization layer to use. Defaults to group.
"""

expansion = 2

def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None):
norm_type: Callable = Norm[Norm.GROUP, spatial_dims]
def __init__(
self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=("group", {"num_groups": 32})
):
conv_type: Callable = Conv[Conv.CONV, spatial_dims]

norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)

# in case downsample uses batch norm, change to group norm
if isinstance(downsample, nn.Sequential):
downsample = nn.Sequential(
conv_type(in_planes, planes * self.expansion, kernel_size=1, stride=stride, bias=False),
norm_type(num_groups=32, num_channels=planes * self.expansion),
norm_layer(channels=planes * self.expansion),
)

super().__init__(in_planes, planes, spatial_dims, stride, downsample)

# change norm from batch to group norm
self.bn1 = norm_type(num_groups=32, num_channels=planes)
self.bn2 = norm_type(num_groups=32, num_channels=planes)
self.bn3 = norm_type(num_groups=32, num_channels=planes * self.expansion)
self.bn1 = norm_layer(channels=planes)
self.bn2 = norm_layer(channels=planes)
self.bn3 = norm_layer(channels=planes * self.expansion)

# adapt second convolution to work with groups
self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, stride=stride, groups=32, bias=False)
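A side note on the `groups=32` convolution above (a standalone sketch, not part of the diff): the grouped conv and the 32-group norm impose the same constraint, channel counts divisible by 32, which the DAF3D plane sizes satisfy.

```python
import torch.nn as nn

# Both constraints come from the same divisibility rule:
nn.GroupNorm(num_groups=32, num_channels=64)            # 64 % 32 == 0, OK
nn.Conv3d(64, 64, kernel_size=3, padding=1, groups=32)  # in/out channels % 32 == 0, OK
```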
@@ -212,8 +218,10 @@ class Daf3dResNetDilatedBottleneck(Daf3dResNetBottleneck):
downsample: which downsample layer to use.
norm: which normalization layer to use. Defaults to group.
"""

def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None):
super().__init__(in_planes, planes, spatial_dims, stride, downsample)
def __init__(
self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=("group", {"num_groups": 32})
):
super().__init__(in_planes, planes, spatial_dims, stride, downsample, norm)

# add dilation in second convolution
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
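The recurring pattern in this file, `partial(get_norm_layer, ...)`, binds the norm name and spatial dims once so each call only supplies the channel count — a standalone sketch, assuming MONAI's `get_norm_layer` as imported in the diff:

```python
from functools import partial

from monai.networks.layers.utils import get_norm_layer

norm_layer = partial(get_norm_layer, name=("group", {"num_groups": 32}), spatial_dims=3)
bn1 = norm_layer(channels=64)   # equivalent to nn.GroupNorm(num_groups=32, num_channels=64)
bn3 = norm_layer(channels=128)  # a fresh module per call, so no parameter sharing
```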
44 changes: 28 additions & 16 deletions monai/networks/nets/resnet.py
@@ -22,8 +22,8 @@
import torch.nn as nn

from monai.networks.blocks.encoder import BaseEncoder
from monai.networks.layers.factories import Conv, Norm, Pool
from monai.networks.layers.utils import get_act_layer, get_pool_layer
from monai.networks.layers.factories import Conv, Pool
from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_pool_layer
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option, optional_import

@@ -79,6 +79,7 @@ def __init__(
stride: int = 1,
downsample: nn.Module | partial | None = None,
act: str | tuple = ("relu", {"inplace": True}),
norm: str | tuple = "batch",
) -> None:
"""
Args:
@@ -88,17 +89,18 @@
stride: stride to use for first conv layer.
downsample: which downsample layer to use.
act: activation type and arguments. Defaults to relu.
norm: feature normalization type and arguments. Defaults to batch norm.
"""
super().__init__()

conv_type: Callable = Conv[Conv.CONV, spatial_dims]
norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)

self.conv1 = conv_type(in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
self.bn1 = norm_type(planes)
self.bn1 = norm_layer(channels=planes)
self.act = get_act_layer(name=act)
self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False)
self.bn2 = norm_type(planes)
self.bn2 = norm_layer(channels=planes)
self.downsample = downsample
self.stride = stride

@@ -132,6 +134,7 @@ def __init__(
stride: int = 1,
downsample: nn.Module | partial | None = None,
act: str | tuple = ("relu", {"inplace": True}),
norm: str | tuple = "batch",
) -> None:
"""
Args:
@@ -141,19 +144,20 @@
stride: stride to use for second conv layer.
downsample: which downsample layer to use.
act: activation type and arguments. Defaults to relu.
norm: feature normalization type and arguments. Defaults to batch norm.
"""

super().__init__()

conv_type: Callable = Conv[Conv.CONV, spatial_dims]
norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)

self.conv1 = conv_type(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = norm_type(planes)
self.bn1 = norm_layer(channels=planes)
self.conv2 = conv_type(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = norm_type(planes)
self.bn2 = norm_layer(channels=planes)
self.conv3 = conv_type(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = norm_type(planes * self.expansion)
self.bn3 = norm_layer(channels=planes * self.expansion)
self.act = get_act_layer(name=act)
self.downsample = downsample
self.stride = stride
@@ -207,6 +211,7 @@ class ResNet(nn.Module):
feed_forward: whether to add the FC layer for the output, default to `True`.
bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.
act: activation type and arguments. Defaults to relu.
norm: feature normalization type and arguments. Defaults to batch norm.

"""

@@ -226,6 +231,7 @@ def __init__(
feed_forward: bool = True,
bias_downsample: bool = True, # for backwards compatibility (also see PR #5477)
act: str | tuple = ("relu", {"inplace": True}),
norm: str | tuple = "batch",
) -> None:
super().__init__()

@@ -238,7 +244,6 @@ def __init__(
raise ValueError("Unknown block '%s', use basic or bottleneck" % block)

conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]
norm_type: type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]
pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]
avgp_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[
Pool.ADAPTIVEAVG, spatial_dims
@@ -262,7 +267,9 @@ def __init__(
padding=tuple(k // 2 for k in conv1_kernel_size),
bias=False,
)
self.bn1 = norm_type(self.in_planes)

norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=self.in_planes)
self.bn1 = norm_layer
self.act = get_act_layer(name=act)
self.maxpool = pool_type(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], spatial_dims, shortcut_type, norm=norm)
@@ -275,7 +282,7 @@ def __init__(
for m in self.modules():
if isinstance(m, conv_type):
nn.init.kaiming_normal_(torch.as_tensor(m.weight), mode="fan_out", nonlinearity="relu")
elif isinstance(m, norm_type):
elif isinstance(m, type(norm_layer)):
nn.init.constant_(torch.as_tensor(m.weight), 1)
nn.init.constant_(torch.as_tensor(m.bias), 0)
elif isinstance(m, nn.Linear):
@@ -295,9 +302,9 @@ def _make_layer(
spatial_dims: int,
shortcut_type: str,
stride: int = 1,
norm: str | tuple = "batch",
) -> nn.Sequential:
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
norm_type: Callable = Norm[Norm.BATCH, spatial_dims]

downsample: nn.Module | partial | None = None
if stride != 1 or self.in_planes != planes * block.expansion:
@@ -317,18 +324,23 @@ def _make_layer(
stride=stride,
bias=self.bias_downsample,
),
norm_type(planes * block.expansion),
get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes * block.expansion),
)

layers = [
block(
in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample
in_planes=self.in_planes,
planes=planes,
spatial_dims=spatial_dims,
stride=stride,
downsample=downsample,
norm=norm,
)
]

self.in_planes = planes * block.expansion
for _i in range(1, blocks):
layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims))
layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims, norm=norm))

return nn.Sequential(*layers)

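For reference, the `get_norm_layer` factory this file now relies on resolves a name (plus optional constructor args) and a channel count into a concrete module — a short sketch, assuming MONAI's documented factory behavior:

```python
import torch.nn as nn

from monai.networks.layers.utils import get_norm_layer

# The default, equivalent to the old hard-coded Norm[Norm.BATCH, spatial_dims]:
bn = get_norm_layer(name="batch", spatial_dims=3, channels=64)  # nn.BatchNorm3d(64)
# A tuple carries extra constructor arguments alongside the name:
gn = get_norm_layer(name=("group", {"num_groups": 8}), spatial_dims=2, channels=32)
assert isinstance(bn, nn.BatchNorm3d) and isinstance(gn, nn.GroupNorm)
```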
19 changes: 18 additions & 1 deletion tests/test_resnet.py
@@ -202,13 +202,30 @@
(1, 3),
]

TEST_CASE_9 = [ # Layer norm
{
"block": ResNetBlock,
"layers": [3, 4, 6, 3],
"block_inplanes": [64, 128, 256, 512],
"spatial_dims": 1,
"n_input_channels": 2,
"num_classes": 3,
"conv1_t_size": [3],
"conv1_t_stride": 1,
"act": ("relu", {"inplace": False}),
"norm": ("layer", {"normalized_shape": (64, 32)}),
},
(1, 2, 32),
(1, 3),
]

TEST_CASES = []
PRETRAINED_TEST_CASES = []
for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
TEST_CASES.append([model, *case])
PRETRAINED_TEST_CASES.append([model, *case])
for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]:
for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]:
TEST_CASES.append([ResNet, *case])

TEST_SCRIPT_CASES = [
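Why `normalized_shape=(64, 32)` fits in TEST_CASE_9 (a sketch of the shape arithmetic, assuming the `(1, 2, 32)` input given above): `conv1` with kernel size 3, stride 1 maps 2 input channels to 64 at an unchanged length of 32, and `nn.LayerNorm` must match the trailing `(C, L)` dims it normalizes over.

```python
import torch
import torch.nn as nn

x = torch.rand(1, 64, 32)  # conv1 output for a (1, 2, 32) input at stride 1
ln = nn.LayerNorm(normalized_shape=(64, 32))
assert ln(x).shape == x.shape  # normalizes over the last two dims, (64, 32)
```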