Skip to content

Commit

Permalink
daf3d compatibility with resnet
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Kaplinsky <peterkaplinsky@gmail.com>
  • Loading branch information
Peter Kaplinsky authored and Pkaps25 committed May 13, 2024
1 parent f4275d3 commit 1c13d22
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
20 changes: 12 additions & 8 deletions monai/networks/nets/daf3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from collections import OrderedDict
from collections.abc import Callable, Sequence
from functools import partial

import torch
import torch.nn as nn
Expand All @@ -25,6 +26,7 @@
from monai.networks.blocks.convolutions import Convolution
from monai.networks.blocks.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork
from monai.networks.layers.factories import Conv, Norm
from monai.networks.layers.utils import get_norm_layer
from monai.networks.nets.resnet import ResNet, ResNetBottleneck

__all__ = [
Expand Down Expand Up @@ -170,27 +172,29 @@ class Daf3dResNetBottleneck(ResNetBottleneck):
spatial_dims: number of spatial dimensions of the input image.
stride: stride to use for second conv layer.
downsample: which downsample layer to use.
norm: which normalization layer to use. Defaults to group.
"""

expansion = 2

def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None):
norm_type: Callable = Norm[Norm.GROUP, spatial_dims]
def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=("group", {"num_groups": 32})):
conv_type: Callable = Conv[Conv.CONV, spatial_dims]

norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)

# in case downsample uses batch norm, change to group norm
if isinstance(downsample, nn.Sequential):
downsample = nn.Sequential(
conv_type(in_planes, planes * self.expansion, kernel_size=1, stride=stride, bias=False),
norm_type(num_groups=32, num_channels=planes * self.expansion),
norm_layer(channels=planes * self.expansion),
)

super().__init__(in_planes, planes, spatial_dims, stride, downsample)

# change norm from batch to group norm
self.bn1 = norm_type(num_groups=32, num_channels=planes)
self.bn2 = norm_type(num_groups=32, num_channels=planes)
self.bn3 = norm_type(num_groups=32, num_channels=planes * self.expansion)
self.bn1 = norm_layer(channels=planes)
self.bn2 = norm_layer(channels=planes)
self.bn3 = norm_layer(channels=planes * self.expansion)

# adapt second convolution to work with groups
self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, stride=stride, groups=32, bias=False)
Expand All @@ -212,8 +216,8 @@ class Daf3dResNetDilatedBottleneck(Daf3dResNetBottleneck):
downsample: which downsample layer to use.
"""

def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None):
super().__init__(in_planes, planes, spatial_dims, stride, downsample)
def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=("group", {"num_groups": 32})):
super().__init__(in_planes, planes, spatial_dims, stride, downsample, norm)

# add dilation in second convolution
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@
(1, 3),
]

TEST_CASE_9 = [
TEST_CASE_9 = [ # Layer norm
{
"block": ResNetBlock,
"layers": [3, 4, 6, 3],
Expand Down

0 comments on commit 1c13d22

Please sign in to comment.