Fixed an error in DiNTS model implementation and enabled act and norm layer options (Project-MONAI#4157)

* fixed a bug

Signed-off-by: dongy <dongy@nvidia.com>

* autofix

Signed-off-by: dongy <dongy@nvidia.com>

* update test case

Signed-off-by: dongy <dongy@nvidia.com>

Co-authored-by: dongy <dongy@nvidia.com>
2 people authored and Can-Zhao committed May 10, 2022
1 parent df26fec commit dbe073a
Showing 4 changed files with 156 additions and 34 deletions.
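The substance of the change: every DiNTS block now accepts act_name/norm_name options, and the default normalization moves from plain "INSTANCE" to ("INSTANCE", {"affine": True}). A minimal sketch (not part of the diff) of what the two defaults resolve to through MONAI's layer factory, the same get_norm_layer used in dints.py below:

    # Sketch only: compare the old and new norm_name defaults via MONAI's factory.
    from monai.networks.layers.utils import get_norm_layer

    old_norm = get_norm_layer(name="INSTANCE", spatial_dims=3, channels=8)
    new_norm = get_norm_layer(name=("INSTANCE", {"affine": True}), spatial_dims=3, channels=8)
    print(old_norm.affine, new_norm.affine)  # False True: the new default learns scale/shift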
monai/networks/blocks/dints_block.py: 8 changes (4 additions, 4 deletions)
@@ -31,7 +31,7 @@ def __init__(
out_channel: int,
spatial_dims: int = 3,
act_name: Union[Tuple, str] = "RELU",
- norm_name: Union[Tuple, str] = "INSTANCE",
+ norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
):
"""
Args:
@@ -82,7 +82,7 @@ def __init__(
out_channel: int,
spatial_dims: int = 3,
act_name: Union[Tuple, str] = "RELU",
- norm_name: Union[Tuple, str] = "INSTANCE",
+ norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
):
"""
Args:
@@ -150,7 +150,7 @@ def __init__(
padding: int,
mode: int = 0,
act_name: Union[Tuple, str] = "RELU",
- norm_name: Union[Tuple, str] = "INSTANCE",
+ norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
):
"""
Args:
@@ -235,7 +235,7 @@ def __init__(
padding: int = 1,
spatial_dims: int = 3,
act_name: Union[Tuple, str] = "RELU",
- norm_name: Union[Tuple, str] = "INSTANCE",
+ norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
):
"""
Args:
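All four blocks in this file gain the same two keyword arguments. A hedged usage sketch (the argument values mirror the updated tests, not this diff):

    # Sketch: the blocks in dints_block.py now accept act_name/norm_name.
    import torch
    from monai.networks.blocks.dints_block import ActiConvNormBlock

    blk = ActiConvNormBlock(
        in_channel=8, out_channel=8, kernel_size=3, padding=1,
        spatial_dims=3, act_name="PRELU", norm_name=("BATCH", {"affine": False}),
    )
    y = blk(torch.randn(2, 8, 16, 16, 16))  # act -> conv -> norm, shape preserved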
monai/networks/nets/dints.py: 138 changes (115 additions, 23 deletions)
@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import warnings
from typing import List, Optional, Tuple, Union

@@ -96,31 +97,63 @@ class _ActiConvNormBlockWithRAMCost(ActiConvNormBlock):
ram_cost = total_ram/output_size = 2 * in_channel/out_channel + 1
"""

- def __init__(self, in_channel: int, out_channel: int, kernel_size: int, padding: int, spatial_dims: int = 3):
- super().__init__(in_channel, out_channel, kernel_size, padding, spatial_dims)
+ def __init__(
+ self,
+ in_channel: int,
+ out_channel: int,
+ kernel_size: int,
+ padding: int,
+ spatial_dims: int = 3,
+ act_name: Union[Tuple, str] = "RELU",
+ norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
+ ):
+ super().__init__(in_channel, out_channel, kernel_size, padding, spatial_dims, act_name, norm_name)
self.ram_cost = 1 + in_channel / out_channel * 2


class _P3DActiConvNormBlockWithRAMCost(P3DActiConvNormBlock):
- def __init__(self, in_channel: int, out_channel: int, kernel_size: int, padding: int, p3dmode: int = 0):
- super().__init__(in_channel, out_channel, kernel_size, padding, p3dmode)
+ def __init__(
+ self,
+ in_channel: int,
+ out_channel: int,
+ kernel_size: int,
+ padding: int,
+ p3dmode: int = 0,
+ act_name: Union[Tuple, str] = "RELU",
+ norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
+ ):
+ super().__init__(in_channel, out_channel, kernel_size, padding, p3dmode, act_name, norm_name)
# 1 in_channel (activation) + 1 in_channel (convolution) +
# 1 out_channel (convolution) + 1 out_channel (normalization)
self.ram_cost = 2 + 2 * in_channel / out_channel


class _FactorizedIncreaseBlockWithRAMCost(FactorizedIncreaseBlock):
- def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3):
- super().__init__(in_channel, out_channel, spatial_dims)
+ def __init__(
+ self,
+ in_channel: int,
+ out_channel: int,
+ spatial_dims: int = 3,
+ act_name: Union[Tuple, str] = "RELU",
+ norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
+ ):
+ super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name)
# s0 is upsampled 2x from s1, representing feature sizes at two resolutions.
# 2 * in_channel * s0 (upsample + activation) + 2 * out_channel * s0 (conv + normalization)
# s0 = output_size/out_channel
self.ram_cost = 2 * in_channel / out_channel + 2


class _FactorizedReduceBlockWithRAMCost(FactorizedReduceBlock):
- def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3):
- super().__init__(in_channel, out_channel, spatial_dims)
+ def __init__(
+ self,
+ in_channel: int,
+ out_channel: int,
+ spatial_dims: int = 3,
+ act_name: Union[Tuple, str] = "RELU",
+ norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
+ ):
+ super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name)
# s0 is upsampled 2x from s1, representing feature sizes at two resolutions.
# in_channel * s0 (activation) + 3 * out_channel * s1 (convolution, concatenation, normalization)
# s0 = s1 * 2^(spatial_dims) = output_size / out_channel * 2^(spatial_dims)
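Plugging assumed channel counts into the ram_cost formulas above, for illustration only:

    # Worked example of the ram_cost formulas (channel counts are assumed).
    in_channel, out_channel = 32, 16
    acn = 1 + in_channel / out_channel * 2   # _ActiConvNormBlockWithRAMCost -> 5.0
    p3d = 2 + 2 * in_channel / out_channel   # _P3DActiConvNormBlockWithRAMCost -> 6.0
    up = 2 * in_channel / out_channel + 2    # _FactorizedIncreaseBlockWithRAMCost -> 6.0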
@@ -182,14 +215,6 @@ class Cell(CellInterface):
# \
# - Downsample

- # Define connection operation set, parameterized by the number of channels
- ConnOPS = {
- "up": _FactorizedIncreaseBlockWithRAMCost,
- "down": _FactorizedReduceBlockWithRAMCost,
- "identity": _IdentityWithRAMCost,
- "align_channels": _ActiConvNormBlockWithRAMCost,
- }

# Define 2D operation set, parameterized by the number of channels
OPS2D = {
"skip_connect": lambda _c: _IdentityWithRAMCost(),
@@ -205,18 +230,69 @@ class Cell(CellInterface):
"conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=2),
}

- def __init__(self, c_prev: int, c: int, rate: int, arch_code_c=None, spatial_dims: int = 3):
+ # Define connection operation set, parameterized by the number of channels
+ ConnOPS = {
+ "up": _FactorizedIncreaseBlockWithRAMCost,
+ "down": _FactorizedReduceBlockWithRAMCost,
+ "identity": _IdentityWithRAMCost,
+ "align_channels": _ActiConvNormBlockWithRAMCost,
+ }
+
+ def __init__(
+ self,
+ c_prev: int,
+ c: int,
+ rate: int,
+ arch_code_c=None,
+ spatial_dims: int = 3,
+ act_name: Union[Tuple, str] = "RELU",
+ norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
+ ):
super().__init__()
self._spatial_dims = spatial_dims
+ self._act_name = act_name
+ self._norm_name = norm_name

if rate == -1: # downsample
- self.preprocess = self.ConnOPS["down"](c_prev, c, spatial_dims=self._spatial_dims)
+ self.preprocess = self.ConnOPS["down"](
+ c_prev, c, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name
+ )
elif rate == 1: # upsample
- self.preprocess = self.ConnOPS["up"](c_prev, c, spatial_dims=self._spatial_dims)
+ self.preprocess = self.ConnOPS["up"](
+ c_prev, c, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name
+ )
else:
if c_prev == c:
self.preprocess = self.ConnOPS["identity"]()
else:
- self.preprocess = self.ConnOPS["align_channels"](c_prev, c, 1, 0, spatial_dims=self._spatial_dims)
+ self.preprocess = self.ConnOPS["align_channels"](
+ c_prev, c, 1, 0, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name
+ )

+ # Define 2D operation set, parameterized by the number of channels
+ self.OPS2D = {
+ "skip_connect": lambda _c: _IdentityWithRAMCost(),
+ "conv_3x3": lambda c: _ActiConvNormBlockWithRAMCost(
+ c, c, 3, padding=1, spatial_dims=2, act_name=self._act_name, norm_name=self._norm_name
+ ),
+ }
+
+ # Define 3D operation set, parameterized by the number of channels
+ self.OPS3D = {
+ "skip_connect": lambda _c: _IdentityWithRAMCost(),
+ "conv_3x3x3": lambda c: _ActiConvNormBlockWithRAMCost(
+ c, c, 3, padding=1, spatial_dims=3, act_name=self._act_name, norm_name=self._norm_name
+ ),
+ "conv_3x3x1": lambda c: _P3DActiConvNormBlockWithRAMCost(
+ c, c, 3, padding=1, p3dmode=0, act_name=self._act_name, norm_name=self._norm_name
+ ),
+ "conv_3x1x3": lambda c: _P3DActiConvNormBlockWithRAMCost(
+ c, c, 3, padding=1, p3dmode=1, act_name=self._act_name, norm_name=self._norm_name
+ ),
+ "conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost(
+ c, c, 3, padding=1, p3dmode=2, act_name=self._act_name, norm_name=self._norm_name
+ ),
+ }

self.OPS = {}
if self._spatial_dims == 2:
@@ -283,7 +359,7 @@ def __init__(
in_channels: int,
num_classes: int,
act_name: Union[Tuple, str] = "RELU",
- norm_name: Union[Tuple, str] = "INSTANCE",
+ norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
spatial_dims: int = 3,
use_downsample: bool = True,
node_a=None,
@@ -398,7 +474,9 @@ def __init__(
bias=False,
dilation=1,
),
- get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx - 1]),
+ get_norm_layer(
+ name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[max(res_idx - 1, 0)]
+ ),
nn.Upsample(scale_factor=2 ** (res_idx != 0), mode=mode, align_corners=True),
)
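The max(res_idx - 1, 0) guard above is one of the two behavioral fixes in this commit: at the first resolution level the old code indexed filter_nums[-1], which in Python silently wraps to the last entry instead of failing. A toy illustration with assumed filter numbers:

    # Toy illustration of the index guard; filter_nums values are assumed.
    filter_nums = [32, 64, 128]
    res_idx = 0
    filter_nums[res_idx - 1]          # 128: negative index wraps to the end
    filter_nums[max(res_idx - 1, 0)]  # 32: the intended channels at the top level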

@@ -484,6 +562,8 @@ def __init__(
num_blocks: int = 6,
num_depths: int = 3,
spatial_dims: int = 3,
+ act_name: Union[Tuple, str] = "RELU",
+ norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
use_downsample: bool = True,
device: str = "cpu",
):
@@ -494,6 +574,8 @@ def __init__(
self.num_blocks = num_blocks
self.num_depths = num_depths
self._spatial_dims = spatial_dims
+ self._act_name = act_name
+ self._norm_name = norm_name
self.use_downsample = use_downsample
self.device = device
self.num_cell_ops = 0
@@ -535,6 +617,8 @@ def __init__(
self.arch_code2ops[res_idx],
self.arch_code_c[blk_idx, res_idx],
self._spatial_dims,
+ self._act_name,
+ self._norm_name,
)

def forward(self, x):
@@ -555,6 +639,8 @@ def __init__(
num_blocks: int = 6,
num_depths: int = 3,
spatial_dims: int = 3,
+ act_name: Union[Tuple, str] = "RELU",
+ norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
use_downsample: bool = True,
device: str = "cpu",
):
@@ -571,6 +657,8 @@ def __init__(
num_blocks=num_blocks,
num_depths=num_depths,
spatial_dims=spatial_dims,
+ act_name=act_name,
+ norm_name=norm_name,
use_downsample=use_downsample,
device=device,
)
@@ -591,7 +679,7 @@ def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
x=inputs[self.arch_code2in[res_idx]], weight=torch.ones_like(self.arch_code_c[blk_idx, res_idx])
)
outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out
- inputs = outputs  # old: inside the per-resolution loop
+ inputs = outputs  # new: dedented to run once per block

return inputs
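The one-line change above moves inputs = outputs out of the per-resolution loop. A simplified, self-contained sketch of the control flow (not the actual TopologyInstance.forward, which sums weighted cell outputs):

    # Simplified sketch of the loop fix.
    def forward_blocks(x, num_blocks, num_depths):
        inputs = x
        for _blk in range(num_blocks):
            outputs = [0.0] * num_depths
            for res_idx in range(num_depths):
                outputs[res_idx] = outputs[res_idx] + inputs[res_idx]
                # bug: assigning `inputs = outputs` here let later res_idx
                # iterations read the half-built outputs of the current block
            inputs = outputs  # fix: hand the finished outputs to the next block
        return inputs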

Expand Down Expand Up @@ -650,6 +738,8 @@ def __init__(
num_blocks: int = 6,
num_depths: int = 3,
spatial_dims: int = 3,
+ act_name: Union[Tuple, str] = "RELU",
+ norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
use_downsample: bool = True,
device: str = "cpu",
):
@@ -663,6 +753,8 @@ def __init__(
num_blocks=num_blocks,
num_depths=num_depths,
spatial_dims=spatial_dims,
+ act_name=act_name,
+ norm_name=norm_name,
use_downsample=use_downsample,
device=device,
)
tests/test_dints_cell.py: 40 changes (35 additions, 5 deletions)
@@ -32,21 +32,28 @@
(2, 4, 64, 32, 16),
],
[
- {"c_prev": 8, "c": 8, "rate": 0, "arch_code_c": None},
+ {"c_prev": 8, "c": 8, "rate": 0, "arch_code_c": None, "act_name": "SELU", "norm_name": "BATCH"},
torch.tensor([1, 1, 1, 1, 1]),
torch.tensor([0, 0, 0, 1, 0]),
(2, 8, 32, 16, 8),
(2, 8, 32, 16, 8),
],
[
- {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": None},
+ {
+ "c_prev": 8,
+ "c": 8,
+ "rate": -1,
+ "arch_code_c": None,
+ "act_name": "PRELU",
+ "norm_name": ("BATCH", {"affine": False}),
+ },
torch.tensor([1, 1, 1, 1, 1]),
torch.tensor([1, 1, 1, 1, 1]),
(2, 8, 32, 16, 8),
(2, 8, 16, 8, 4),
],
[
- {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1]},
+ {"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1], "act_name": "RELU", "norm_name": "INSTANCE"},
torch.tensor([1, 0, 0, 0, 1]),
torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]),
(2, 8, 32, 16, 8),
@@ -56,12 +63,35 @@

TEST_CASES_2D = [
[
- {"c_prev": 8, "c": 7, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1], "spatial_dims": 2},
+ {
+ "c_prev": 8,
+ "c": 7,
+ "rate": -1,
+ "arch_code_c": [1, 0, 0, 0, 1],
+ "spatial_dims": 2,
+ "act_name": "PRELU",
+ "norm_name": ("BATCH", {"affine": False}),
+ },
torch.tensor([1, 0]),
torch.tensor([0.2, 0.2]),
(2, 8, 16, 8),
(2, 7, 8, 4),
- ]
+ ],
+ [
+ {
+ "c_prev": 8,
+ "c": 8,
+ "rate": -1,
+ "arch_code_c": None,
+ "spatial_dims": 2,
+ "act_name": "SELU",
+ "norm_name": "INSTANCE",
+ },
+ torch.tensor([1, 0]),
+ torch.tensor([0.2, 0.2]),
+ (2, 8, 16, 8),
+ (2, 8, 8, 4),
+ ],
]


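A hedged sketch of how one of the new 3D rows exercises Cell (the expected shape is taken from the downsample test case above):

    # Sketch mirroring the 3D downsample (rate=-1) test case.
    import torch
    from monai.networks.nets.dints import Cell

    cell = Cell(c_prev=8, c=8, rate=-1, arch_code_c=None, spatial_dims=3,
                act_name="PRELU", norm_name=("BATCH", {"affine": False}))
    out = cell(torch.randn(2, 8, 32, 16, 8), weight=torch.ones(5))
    print(out.shape)  # expected torch.Size([2, 8, 16, 8, 4]), as in the test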
tests/test_dints_network.py: 4 changes (2 additions, 2 deletions)
@@ -33,7 +33,7 @@
"in_channels": 1,
"num_classes": 3,
"act_name": "RELU",
- "norm_name": "INSTANCE",
+ "norm_name": ("INSTANCE", {"affine": True}),
"use_downsample": False,
"spatial_dims": 3,
},
@@ -101,7 +101,7 @@
"in_channels": 1,
"num_classes": 4,
"act_name": "RELU",
- "norm_name": "INSTANCE",
+ "norm_name": ("INSTANCE", {"affine": True}),
"use_downsample": False,
"spatial_dims": 2,
},
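Putting it together, a hedged end-to-end sketch in the spirit of these test configurations (the input size is an assumption, not taken from the tests):

    # End-to-end sketch; constructor arguments mirror the updated tests.
    import torch
    from monai.networks.nets import DiNTS, TopologyInstance

    space = TopologyInstance(
        num_blocks=6, num_depths=3, spatial_dims=3,
        act_name="RELU", norm_name=("INSTANCE", {"affine": True}),
        use_downsample=False, device="cpu",
    )
    net = DiNTS(dints_space=space, in_channels=1, num_classes=3,
                act_name="RELU", norm_name=("INSTANCE", {"affine": True}),
                use_downsample=False, spatial_dims=3)
    logits = net(torch.randn(1, 1, 32, 32, 32))  # expected (1, 3, 32, 32, 32)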
