Skip to content

Commit

Permalink
[WIP]1442 add LocalNet initialization (#1460)
Browse files Browse the repository at this point in the history
* 1442 add initialization

Signed-off-by: kate-sann5100 <yiwen.li@st-annes.ox.ac.uk>

* 1442 fix typing and add test cases

Signed-off-by: kate-sann5100 <yiwen.li@st-annes.ox.ac.uk>
  • Loading branch information
kate-sann5100 authored Jan 19, 2021
1 parent 91cb8cd commit 012cf62
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 38 deletions.
16 changes: 14 additions & 2 deletions monai/networks/blocks/localnet_block.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Optional, Sequence, Tuple, Union
from typing import Optional, Sequence, Tuple, Type, Union

import torch
from torch import nn
from torch.nn import functional as F

from monai.networks.blocks import Convolution
from monai.networks.layers import same_padding
from monai.networks.layers.factories import Norm, Pool
from monai.networks.layers.factories import Conv, Norm, Pool


def get_conv_block(
Expand Down Expand Up @@ -285,6 +285,7 @@ def __init__(
in_channels: int,
out_channels: int,
act: Optional[Union[Tuple, str]] = "RELU",
initializer: str = "kaiming_uniform",
) -> None:
"""
Args:
Expand All @@ -298,6 +299,17 @@ def __init__(
self.conv_block = get_conv_block(
spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None
)
conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims]
for m in self.conv_block.modules():
if isinstance(m, conv_type):
if initializer == "kaiming_uniform":
nn.init.kaiming_normal_(torch.as_tensor(m.weight))
elif initializer == "zeros":
nn.init.zeros_(torch.as_tensor(m.weight))
else:
raise ValueError(
f"initializer {initializer} is not supported, " "currently supporting kaiming_uniform and zeros"
)

def forward(self, x) -> torch.Tensor:
"""
Expand Down
9 changes: 6 additions & 3 deletions monai/networks/nets/localnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,17 @@ def __init__(
num_channel_initial: int,
extract_levels: List[int],
out_activation: Optional[Union[Tuple, str]],
out_initializer: str = "kaiming_uniform",
) -> None:
"""
Args:
spatial_dims: number of spatial dimensions.
in_channels: number of input channels.
out_channels: number of output channels.
num_channel_initial: number of initial channels,
extract_levels: number of extraction levels,
out_activation: activation to use at end layer,
num_channel_initial: number of initial channels.
extract_levels: number of extraction levels.
out_activation: activation to use at end layer.
out_initializer: initializer for extraction layers.
"""
super(LocalNet, self).__init__()
self.extract_levels = extract_levels
Expand Down Expand Up @@ -85,6 +87,7 @@ def __init__(
in_channels=num_channels[level],
out_channels=out_channels,
act=out_activation,
initializer=out_initializer,
)
for level in self.extract_levels
]
Expand Down
41 changes: 17 additions & 24 deletions tests/test_localnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,6 @@
device = "cuda" if torch.cuda.is_available() else "cpu"


param_variations_2d = {
"spatial_dims": 2,
"in_channels": 2,
"out_channels": 2,
"num_channel_initial": 16,
"extract_levels": [0, 1, 2],
"out_activation": ["sigmoid", None],
}

TEST_CASE_LOCALNET_2D = [
[
{
Expand All @@ -41,23 +32,25 @@
for num_channel_initial in [4, 16, 32]:
for extract_levels in [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]:
for out_activation in ["sigmoid", None]:
TEST_CASE_LOCALNET_3D.append(
[
{
"spatial_dims": 3,
"in_channels": in_channels,
"out_channels": out_channels,
"num_channel_initial": num_channel_initial,
"extract_levels": extract_levels,
"out_activation": out_activation,
},
(1, in_channels, 16, 16, 16),
(1, out_channels, 16, 16, 16),
]
)
for out_initializer in ["kaiming_uniform", "zeros"]:
TEST_CASE_LOCALNET_3D.append(
[
{
"spatial_dims": 3,
"in_channels": in_channels,
"out_channels": out_channels,
"num_channel_initial": num_channel_initial,
"extract_levels": extract_levels,
"out_activation": out_activation,
"out_initializer": out_initializer,
},
(1, in_channels, 16, 16, 16),
(1, out_channels, 16, 16, 16),
]
)


class TestDynUNet(unittest.TestCase):
class TestLocalNet(unittest.TestCase):
@parameterized.expand(TEST_CASE_LOCALNET_2D + TEST_CASE_LOCALNET_3D)
def test_shape(self, input_param, input_shape, expected_shape):
net = LocalNet(**input_param).to(device)
Expand Down
15 changes: 6 additions & 9 deletions tests/test_localnet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,8 @@
TEST_CASE_UP_SAMPLE = [[{"spatial_dims": spatial_dims, "in_channels": 4, "out_channels": 2}] for spatial_dims in [2, 3]]

TEST_CASE_EXTRACT = [
[
{
"spatial_dims": spatial_dims,
"in_channels": 2,
"out_channels": 3,
"act": act,
}
]
for spatial_dims, act in zip([2, 3], ["sigmoid", None])
[{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 3, "act": act, "initializer": initializer}]
for spatial_dims, act, initializer in zip([2, 3], ["sigmoid", None], ["kaiming_uniform", "zeros"])
]

in_size = 4
Expand Down Expand Up @@ -93,6 +86,10 @@ def test_shape(self, input_param):
result = net(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)

def test_ill_arg(self):
with self.assertRaises(ValueError):
LocalNetFeatureExtractorBlock(spatial_dims=2, in_channels=2, out_channels=2, initializer="none")


if __name__ == "__main__":
unittest.main()

0 comments on commit 012cf62

Please sign in to comment.