Skip to content

Commit

Permalink
add layer norm to resnet
Browse files Browse the repository at this point in the history
  • Loading branch information
Peter Kaplinsky authored and Pkaps25 committed May 13, 2024
1 parent daf2e71 commit 335ab0f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 17 deletions.
41 changes: 25 additions & 16 deletions monai/networks/nets/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import torch.nn as nn

from monai.networks.blocks.encoder import BaseEncoder
from monai.networks.layers.factories import Conv, Norm, Pool
from monai.networks.layers.utils import get_act_layer, get_pool_layer
from monai.networks.layers.factories import Conv, Pool
from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_pool_layer
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option, optional_import

Expand Down Expand Up @@ -79,6 +79,7 @@ def __init__(
stride: int = 1,
downsample: nn.Module | partial | None = None,
act: str | tuple = ("relu", {"inplace": True}),
norm: str | tuple = "batch",
) -> None:
"""
Args:
Expand All @@ -92,13 +93,13 @@ def __init__(
super().__init__()

conv_type: Callable = Conv[Conv.CONV, spatial_dims]
norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes)

self.conv1 = conv_type(in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
self.bn1 = norm_type(planes)
self.bn1 = norm_layer
self.act = get_act_layer(name=act)
self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False)
self.bn2 = norm_type(planes)
self.bn2 = norm_layer
self.downsample = downsample
self.stride = stride

Expand Down Expand Up @@ -132,6 +133,7 @@ def __init__(
stride: int = 1,
downsample: nn.Module | partial | None = None,
act: str | tuple = ("relu", {"inplace": True}),
norm: str | tuple = "batch",
) -> None:
"""
Args:
Expand All @@ -146,14 +148,14 @@ def __init__(
super().__init__()

conv_type: Callable = Conv[Conv.CONV, spatial_dims]
norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)

self.conv1 = conv_type(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = norm_type(planes)
self.bn1 = norm_layer(channels=planes)
self.conv2 = conv_type(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = norm_type(planes)
self.bn2 = norm_layer(channels=planes)
self.conv3 = conv_type(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = norm_type(planes * self.expansion)
self.bn3 = norm_layer(channels=planes * self.expansion)
self.act = get_act_layer(name=act)
self.downsample = downsample
self.stride = stride
Expand Down Expand Up @@ -226,6 +228,7 @@ def __init__(
feed_forward: bool = True,
bias_downsample: bool = True, # for backwards compatibility (also see PR #5477)
act: str | tuple = ("relu", {"inplace": True}),
norm: str | tuple = "batch",
) -> None:
super().__init__()

Expand All @@ -238,7 +241,6 @@ def __init__(
raise ValueError("Unknown block '%s', use basic or bottleneck" % block)

conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]
norm_type: type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]
pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]
avgp_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[
Pool.ADAPTIVEAVG, spatial_dims
Expand All @@ -262,7 +264,9 @@ def __init__(
padding=tuple(k // 2 for k in conv1_kernel_size),
bias=False,
)
self.bn1 = norm_type(self.in_planes)

norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=self.in_planes)
self.bn1 = norm_layer
self.act = get_act_layer(name=act)
self.maxpool = pool_type(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], spatial_dims, shortcut_type)
Expand All @@ -275,7 +279,7 @@ def __init__(
for m in self.modules():
if isinstance(m, conv_type):
nn.init.kaiming_normal_(torch.as_tensor(m.weight), mode="fan_out", nonlinearity="relu")
elif isinstance(m, norm_type):
elif isinstance(m, type(norm_layer)):
nn.init.constant_(torch.as_tensor(m.weight), 1)
nn.init.constant_(torch.as_tensor(m.bias), 0)
elif isinstance(m, nn.Linear):
Expand All @@ -295,9 +299,9 @@ def _make_layer(
spatial_dims: int,
shortcut_type: str,
stride: int = 1,
norm: str | tuple = "batch",
) -> nn.Sequential:
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
norm_type: Callable = Norm[Norm.BATCH, spatial_dims]

downsample: nn.Module | partial | None = None
if stride != 1 or self.in_planes != planes * block.expansion:
Expand All @@ -317,18 +321,23 @@ def _make_layer(
stride=stride,
bias=self.bias_downsample,
),
norm_type(planes * block.expansion),
get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes * block.expansion),
)

layers = [
block(
in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample
in_planes=self.in_planes,
planes=planes,
spatial_dims=spatial_dims,
stride=stride,
downsample=downsample,
norm=norm,
)
]

self.in_planes = planes * block.expansion
for _i in range(1, blocks):
layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims))
layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims, norm=norm))

return nn.Sequential(*layers)

Expand Down
19 changes: 18 additions & 1 deletion tests/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,26 @@
(1, 3),
]

TEST_CASE_9 = [
{
"block": "bottleneck",
"layers": [3, 4, 6, 3],
"block_inplanes": [64, 128, 256, 512],
"spatial_dims": 1,
"n_input_channels": 2,
"num_classes": 3,
"conv1_t_size": [3],
"conv1_t_stride": 1,
"act": ("relu", {"inplace": False}),
"norm": ("layer", {"normalized_shape": (64, 32)}),
},
(1, 2, 32),
(1, 3),
]

TEST_CASES = []
PRETRAINED_TEST_CASES = []
for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A, TEST_CASE_3]:
for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
TEST_CASES.append([model, *case])
PRETRAINED_TEST_CASES.append([model, *case])
Expand Down

0 comments on commit 335ab0f

Please sign in to comment.