From dbe073ab098331efa3c98bb4da6fa9d7f44b7630 Mon Sep 17 00:00:00 2001 From: dongyang0122 Date: Sat, 23 Apr 2022 10:37:46 -0600 Subject: [PATCH] Fixed an error in DiNTS model implementation and enabled act and norm layer options (#4157) * fixed a bug Signed-off-by: dongy * autofix Signed-off-by: dongy * update test case Signed-off-by: dongy Co-authored-by: dongy --- monai/networks/blocks/dints_block.py | 8 +- monai/networks/nets/dints.py | 138 ++++++++++++++++++++++----- tests/test_dints_cell.py | 40 +++++++- tests/test_dints_network.py | 4 +- 4 files changed, 156 insertions(+), 34 deletions(-) diff --git a/monai/networks/blocks/dints_block.py b/monai/networks/blocks/dints_block.py index f76e125fe0..b7365f50e3 100644 --- a/monai/networks/blocks/dints_block.py +++ b/monai/networks/blocks/dints_block.py @@ -31,7 +31,7 @@ def __init__( out_channel: int, spatial_dims: int = 3, act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = "INSTANCE", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), ): """ Args: @@ -82,7 +82,7 @@ def __init__( out_channel: int, spatial_dims: int = 3, act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = "INSTANCE", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), ): """ Args: @@ -150,7 +150,7 @@ def __init__( padding: int, mode: int = 0, act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = "INSTANCE", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), ): """ Args: @@ -235,7 +235,7 @@ def __init__( padding: int = 1, spatial_dims: int = 3, act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = "INSTANCE", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), ): """ Args: diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py index 978695c5d0..b7f3921a47 100644 --- a/monai/networks/nets/dints.py +++ b/monai/networks/nets/dints.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ import warnings from typing import List, Optional, Tuple, Union @@ -96,22 +97,47 @@ class _ActiConvNormBlockWithRAMCost(ActiConvNormBlock): ram_cost = total_ram/output_size = 2 * in_channel/out_channel + 1 """ - def __init__(self, in_channel: int, out_channel: int, kernel_size: int, padding: int, spatial_dims: int = 3): - super().__init__(in_channel, out_channel, kernel_size, padding, spatial_dims) + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: int, + padding: int, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + ): + super().__init__(in_channel, out_channel, kernel_size, padding, spatial_dims, act_name, norm_name) self.ram_cost = 1 + in_channel / out_channel * 2 class _P3DActiConvNormBlockWithRAMCost(P3DActiConvNormBlock): - def __init__(self, in_channel: int, out_channel: int, kernel_size: int, padding: int, p3dmode: int = 0): - super().__init__(in_channel, out_channel, kernel_size, padding, p3dmode) + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: int, + padding: int, + p3dmode: int = 0, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + ): + super().__init__(in_channel, out_channel, kernel_size, padding, p3dmode, act_name, norm_name) # 1 in_channel (activation) + 1 in_channel (convolution) + # 1 out_channel (convolution) + 1 out_channel (normalization) self.ram_cost = 2 + 2 * in_channel / out_channel class _FactorizedIncreaseBlockWithRAMCost(FactorizedIncreaseBlock): - def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3): - super().__init__(in_channel, out_channel, spatial_dims) + def __init__( + self, + in_channel: int, + out_channel: int, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + ): + super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name) # s0 is upsampled 2x from s1, representing feature sizes at two resolutions. # 2 * in_channel * s0 (upsample + activation) + 2 * out_channel * s0 (conv + normalization) # s0 = output_size/out_channel @@ -119,8 +145,15 @@ def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3): class _FactorizedReduceBlockWithRAMCost(FactorizedReduceBlock): - def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3): - super().__init__(in_channel, out_channel, spatial_dims) + def __init__( + self, + in_channel: int, + out_channel: int, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + ): + super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name) # s0 is upsampled 2x from s1, representing feature sizes at two resolutions. 
# in_channel * s0 (activation) + 3 * out_channel * s1 (convolution, concatenation, normalization) # s0 = s1 * 2^(spatial_dims) = output_size / out_channel * 2^(spatial_dims) @@ -182,14 +215,6 @@ class Cell(CellInterface): # \ # - Downsample - # Define connection operation set, parameterized by the number of channels - ConnOPS = { - "up": _FactorizedIncreaseBlockWithRAMCost, - "down": _FactorizedReduceBlockWithRAMCost, - "identity": _IdentityWithRAMCost, - "align_channels": _ActiConvNormBlockWithRAMCost, - } - # Define 2D operation set, parameterized by the number of channels OPS2D = { "skip_connect": lambda _c: _IdentityWithRAMCost(), @@ -205,18 +230,69 @@ class Cell(CellInterface): "conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=2), } - def __init__(self, c_prev: int, c: int, rate: int, arch_code_c=None, spatial_dims: int = 3): + # Define connection operation set, parameterized by the number of channels + ConnOPS = { + "up": _FactorizedIncreaseBlockWithRAMCost, + "down": _FactorizedReduceBlockWithRAMCost, + "identity": _IdentityWithRAMCost, + "align_channels": _ActiConvNormBlockWithRAMCost, + } + + def __init__( + self, + c_prev: int, + c: int, + rate: int, + arch_code_c=None, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), + ): super().__init__() self._spatial_dims = spatial_dims + self._act_name = act_name + self._norm_name = norm_name + if rate == -1: # downsample - self.preprocess = self.ConnOPS["down"](c_prev, c, spatial_dims=self._spatial_dims) + self.preprocess = self.ConnOPS["down"]( + c_prev, c, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name + ) elif rate == 1: # upsample - self.preprocess = self.ConnOPS["up"](c_prev, c, spatial_dims=self._spatial_dims) + self.preprocess = self.ConnOPS["up"]( + c_prev, c, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name + ) else: if c_prev == c: self.preprocess = self.ConnOPS["identity"]() else: - self.preprocess = self.ConnOPS["align_channels"](c_prev, c, 1, 0, spatial_dims=self._spatial_dims) + self.preprocess = self.ConnOPS["align_channels"]( + c_prev, c, 1, 0, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name + ) + + # Define 2D operation set, parameterized by the number of channels + self.OPS2D = { + "skip_connect": lambda _c: _IdentityWithRAMCost(), + "conv_3x3": lambda c: _ActiConvNormBlockWithRAMCost( + c, c, 3, padding=1, spatial_dims=2, act_name=self._act_name, norm_name=self._norm_name + ), + } + + # Define 3D operation set, parameterized by the number of channels + self.OPS3D = { + "skip_connect": lambda _c: _IdentityWithRAMCost(), + "conv_3x3x3": lambda c: _ActiConvNormBlockWithRAMCost( + c, c, 3, padding=1, spatial_dims=3, act_name=self._act_name, norm_name=self._norm_name + ), + "conv_3x3x1": lambda c: _P3DActiConvNormBlockWithRAMCost( + c, c, 3, padding=1, p3dmode=0, act_name=self._act_name, norm_name=self._norm_name + ), + "conv_3x1x3": lambda c: _P3DActiConvNormBlockWithRAMCost( + c, c, 3, padding=1, p3dmode=1, act_name=self._act_name, norm_name=self._norm_name + ), + "conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost( + c, c, 3, padding=1, p3dmode=2, act_name=self._act_name, norm_name=self._norm_name + ), + } self.OPS = {} if self._spatial_dims == 2: @@ -283,7 +359,7 @@ def __init__( in_channels: int, num_classes: int, act_name: Union[Tuple, str] = "RELU", - norm_name: Union[Tuple, str] = "INSTANCE", 
+ norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), spatial_dims: int = 3, use_downsample: bool = True, node_a=None, @@ -398,7 +474,9 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx - 1]), + get_norm_layer( + name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[max(res_idx - 1, 0)] + ), nn.Upsample(scale_factor=2 ** (res_idx != 0), mode=mode, align_corners=True), ) @@ -484,6 +562,8 @@ def __init__( num_blocks: int = 6, num_depths: int = 3, spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), use_downsample: bool = True, device: str = "cpu", ): @@ -494,6 +574,8 @@ def __init__( self.num_blocks = num_blocks self.num_depths = num_depths self._spatial_dims = spatial_dims + self._act_name = act_name + self._norm_name = norm_name self.use_downsample = use_downsample self.device = device self.num_cell_ops = 0 @@ -535,6 +617,8 @@ def __init__( self.arch_code2ops[res_idx], self.arch_code_c[blk_idx, res_idx], self._spatial_dims, + self._act_name, + self._norm_name, ) def forward(self, x): @@ -555,6 +639,8 @@ def __init__( num_blocks: int = 6, num_depths: int = 3, spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), use_downsample: bool = True, device: str = "cpu", ): @@ -571,6 +657,8 @@ def __init__( num_blocks=num_blocks, num_depths=num_depths, spatial_dims=spatial_dims, + act_name=act_name, + norm_name=norm_name, use_downsample=use_downsample, device=device, ) @@ -591,7 +679,7 @@ def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: x=inputs[self.arch_code2in[res_idx]], weight=torch.ones_like(self.arch_code_c[blk_idx, res_idx]) ) outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out - inputs = outputs + inputs = outputs return inputs @@ -650,6 +738,8 @@ def __init__( num_blocks: int = 6, num_depths: int = 3, spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}), use_downsample: bool = True, device: str = "cpu", ): @@ -663,6 +753,8 @@ def __init__( num_blocks=num_blocks, num_depths=num_depths, spatial_dims=spatial_dims, + act_name=act_name, + norm_name=norm_name, use_downsample=use_downsample, device=device, ) diff --git a/tests/test_dints_cell.py b/tests/test_dints_cell.py index d480235b70..a5da39bae9 100644 --- a/tests/test_dints_cell.py +++ b/tests/test_dints_cell.py @@ -32,21 +32,28 @@ (2, 4, 64, 32, 16), ], [ - {"c_prev": 8, "c": 8, "rate": 0, "arch_code_c": None}, + {"c_prev": 8, "c": 8, "rate": 0, "arch_code_c": None, "act_name": "SELU", "norm_name": "BATCH"}, torch.tensor([1, 1, 1, 1, 1]), torch.tensor([0, 0, 0, 1, 0]), (2, 8, 32, 16, 8), (2, 8, 32, 16, 8), ], [ - {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": None}, + { + "c_prev": 8, + "c": 8, + "rate": -1, + "arch_code_c": None, + "act_name": "PRELU", + "norm_name": ("BATCH", {"affine": False}), + }, torch.tensor([1, 1, 1, 1, 1]), torch.tensor([1, 1, 1, 1, 1]), (2, 8, 32, 16, 8), (2, 8, 16, 8, 4), ], [ - {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1]}, + {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1], "act_name": "RELU", "norm_name": "INSTANCE"}, torch.tensor([1, 0, 0, 0, 1]), torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]), (2, 8, 32, 16, 8), @@ -56,12 +63,35 @@ TEST_CASES_2D = [ [ - {"c_prev": 8, "c": 7, "rate": -1, "arch_code_c": 
[1, 0, 0, 0, 1], "spatial_dims": 2}, + { + "c_prev": 8, + "c": 7, + "rate": -1, + "arch_code_c": [1, 0, 0, 0, 1], + "spatial_dims": 2, + "act_name": "PRELU", + "norm_name": ("BATCH", {"affine": False}), + }, torch.tensor([1, 0]), torch.tensor([0.2, 0.2]), (2, 8, 16, 8), (2, 7, 8, 4), - ] + ], + [ + { + "c_prev": 8, + "c": 8, + "rate": -1, + "arch_code_c": None, + "spatial_dims": 2, + "act_name": "SELU", + "norm_name": "INSTANCE", + }, + torch.tensor([1, 0]), + torch.tensor([0.2, 0.2]), + (2, 8, 16, 8), + (2, 8, 8, 4), + ], ] diff --git a/tests/test_dints_network.py b/tests/test_dints_network.py index 8be5eb7ccd..08e75fab98 100644 --- a/tests/test_dints_network.py +++ b/tests/test_dints_network.py @@ -33,7 +33,7 @@ "in_channels": 1, "num_classes": 3, "act_name": "RELU", - "norm_name": "INSTANCE", + "norm_name": ("INSTANCE", {"affine": True}), "use_downsample": False, "spatial_dims": 3, }, @@ -101,7 +101,7 @@ "in_channels": 1, "num_classes": 4, "act_name": "RELU", - "norm_name": "INSTANCE", + "norm_name": ("INSTANCE", {"affine": True}), "use_downsample": False, "spatial_dims": 2, },