Skip to content

Commit

Permalink
Passing the activation function of the SegResNet to the Residual Bloc…
Browse files Browse the repository at this point in the history
…ks (#3419)

* Update segresnet_block.py

making the activation function a parameter to the class of the ResBlock

* Update segresnet.py

* passing activation to the residual blocks

* little change

* signed

Signed-off-by: Patricio Astudillo <patricio.astudillo@robovision.ai>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* [DLMED] add args and update default (#3418)

Signed-off-by: Nic Ma <nma@nvidia.com>
Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* update docstrings

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

Co-authored-by: Patricio Astudillo <patricio.astudillo@robovision.ai>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Wenqi Li <wenqil@nvidia.com>
Co-authored-by: Nic Ma <nma@nvidia.com>
  • Loading branch information
5 people authored Dec 1, 2021
1 parent fc9abb9 commit 0b077da
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 16 deletions.
19 changes: 13 additions & 6 deletions monai/networks/blocks/segresnet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@

from monai.networks.blocks.convolutions import Convolution
from monai.networks.blocks.upsample import UpSample
from monai.networks.layers.factories import Act
from monai.networks.layers.utils import get_norm_layer
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils import InterpolateMode, UpsampleMode


Expand Down Expand Up @@ -50,13 +49,21 @@ class ResBlock(nn.Module):
<https://arxiv.org/pdf/1810.11654.pdf>`_.
"""

def __init__(self, spatial_dims: int, in_channels: int, norm: Union[Tuple, str], kernel_size: int = 3) -> None:
def __init__(
self,
spatial_dims: int,
in_channels: int,
norm: Union[Tuple, str],
kernel_size: int = 3,
act: Union[Tuple, str] = ("RELU", {"inplace": True}),
) -> None:
"""
Args:
spatial_dims: number of spatial dimensions, could be 1, 2 or 3.
in_channels: number of input channels.
norm: feature normalization type and arguments.
kernel_size: convolution kernel size, the value should be an odd number. Defaults to 3.
act: activation type and arguments. Defaults to ``RELU``.
"""

super().__init__()
Expand All @@ -66,7 +73,7 @@ def __init__(self, spatial_dims: int, in_channels: int, norm: Union[Tuple, str],

self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
self.relu = Act[Act.RELU](inplace=True)
self.act = get_act_layer(act)
self.conv1 = get_conv_layer(spatial_dims, in_channels=in_channels, out_channels=in_channels)
self.conv2 = get_conv_layer(spatial_dims, in_channels=in_channels, out_channels=in_channels)

Expand All @@ -75,11 +82,11 @@ def forward(self, x):
identity = x

x = self.norm1(x)
x = self.relu(x)
x = self.act(x)
x = self.conv1(x)

x = self.norm2(x)
x = self.relu(x)
x = self.act(x)
x = self.conv2(x)

x += identity
Expand Down
23 changes: 14 additions & 9 deletions monai/networks/nets/segresnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,16 @@ def __init__(
super().__init__()

if spatial_dims not in (2, 3):
raise AssertionError("spatial_dims can only be 2 or 3.")
raise ValueError("`spatial_dims` can only be 2 or 3.")

self.spatial_dims = spatial_dims
self.init_filters = init_filters
self.in_channels = in_channels
self.blocks_down = blocks_down
self.blocks_up = blocks_up
self.dropout_prob = dropout_prob
self.act = get_act_layer(act)
self.act = act # input options
self.act_mod = get_act_layer(act)
if norm_name:
if norm_name.lower() != "group":
raise ValueError(f"Deprecating option 'norm_name={norm_name}', please use 'norm' instead.")
Expand All @@ -108,7 +109,8 @@ def _make_down_layers(self):
else nn.Identity()
)
down_layer = nn.Sequential(
pre_conv, *[ResBlock(spatial_dims, layer_in_channels, norm=norm) for _ in range(blocks_down[i])]
pre_conv,
*[ResBlock(spatial_dims, layer_in_channels, norm=norm, act=self.act) for _ in range(blocks_down[i])],
)
down_layers.append(down_layer)
return down_layers
Expand All @@ -127,7 +129,10 @@ def _make_up_layers(self):
sample_in_channels = filters * 2 ** (n_up - i)
up_layers.append(
nn.Sequential(
*[ResBlock(spatial_dims, sample_in_channels // 2, norm=norm) for _ in range(blocks_up[i])]
*[
ResBlock(spatial_dims, sample_in_channels // 2, norm=norm, act=self.act)
for _ in range(blocks_up[i])
]
)
)
up_samples.append(
Expand All @@ -143,7 +148,7 @@ def _make_up_layers(self):
def _make_final_conv(self, out_channels: int):
return nn.Sequential(
get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=self.init_filters),
self.act,
self.act_mod,
get_conv_layer(self.spatial_dims, self.init_filters, out_channels, kernel_size=1, bias=True),
)

Expand Down Expand Up @@ -262,10 +267,10 @@ def _prepare_vae_modules(self):

self.vae_down = nn.Sequential(
get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=v_filters),
self.act,
self.act_mod,
get_conv_layer(self.spatial_dims, v_filters, self.smallest_filters, stride=2, bias=True),
get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=self.smallest_filters),
self.act,
self.act_mod,
)
self.vae_fc1 = nn.Linear(total_elements, self.vae_nz)
self.vae_fc2 = nn.Linear(total_elements, self.vae_nz)
Expand All @@ -275,7 +280,7 @@ def _prepare_vae_modules(self):
get_conv_layer(self.spatial_dims, self.smallest_filters, v_filters, kernel_size=1),
get_upsample_layer(self.spatial_dims, v_filters, upsample_mode=self.upsample_mode),
get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=v_filters),
self.act,
self.act_mod,
)

def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor):
Expand Down Expand Up @@ -304,7 +309,7 @@ def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor):
x_vae = z_mean + z_sigma * z_mean_rand

x_vae = self.vae_fc3(x_vae)
x_vae = self.act(x_vae)
x_vae = self.act_mod(x_vae)
x_vae = x_vae.view([-1, self.smallest_filters] + self.fc_insize)
x_vae = self.vae_fc_up_sample(x_vae)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_segresnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_shape(self, input_param, input_shape, expected_shape):
self.assertEqual(result.shape, expected_shape)

def test_ill_arg(self):
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
SegResNet(spatial_dims=4)

def test_script(self):
Expand Down

0 comments on commit 0b077da

Please sign in to comment.