Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed an error in DiNTS model implementation and enabled act and norm layer options #4157

Merged
merged 5 commits into from
Apr 23, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions monai/networks/blocks/dints_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
out_channel: int,
spatial_dims: int = 3,
act_name: Union[Tuple, str] = "RELU",
norm_name: Union[Tuple, str] = "INSTANCE",
norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
):
"""
Args:
Expand Down Expand Up @@ -82,7 +82,7 @@ def __init__(
out_channel: int,
spatial_dims: int = 3,
act_name: Union[Tuple, str] = "RELU",
norm_name: Union[Tuple, str] = "INSTANCE",
norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
):
"""
Args:
Expand Down Expand Up @@ -150,7 +150,7 @@ def __init__(
padding: int,
mode: int = 0,
act_name: Union[Tuple, str] = "RELU",
norm_name: Union[Tuple, str] = "INSTANCE",
norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
):
"""
Args:
Expand Down Expand Up @@ -235,7 +235,7 @@ def __init__(
padding: int = 1,
spatial_dims: int = 3,
act_name: Union[Tuple, str] = "RELU",
norm_name: Union[Tuple, str] = "INSTANCE",
norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
):
"""
Args:
Expand Down
138 changes: 115 additions & 23 deletions monai/networks/nets/dints.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import warnings
from typing import List, Optional, Tuple, Union

Expand Down Expand Up @@ -96,31 +97,63 @@ class _ActiConvNormBlockWithRAMCost(ActiConvNormBlock):
ram_cost = total_ram/output_size = 2 * in_channel/out_channel + 1
"""

def __init__(
    self,
    in_channel: int,
    out_channel: int,
    kernel_size: int,
    padding: int,
    spatial_dims: int = 3,
    act_name: Union[Tuple, str] = "RELU",
    norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
):
    """
    Build the parent block and attach its relative RAM cost.

    All arguments are forwarded positionally, in this order, to the parent
    constructor; only ``in_channel`` and ``out_channel`` are used locally.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: forwarded to the parent constructor.
        padding: forwarded to the parent constructor.
        spatial_dims: forwarded to the parent constructor.
        act_name: activation specification forwarded to the parent constructor.
        norm_name: normalization specification forwarded to the parent constructor.
    """
    super().__init__(in_channel, out_channel, kernel_size, padding, spatial_dims, act_name, norm_name)
    # ram_cost = total_ram / output_size = 2 * in_channel / out_channel + 1
    self.ram_cost = 1 + in_channel / out_channel * 2


class _P3DActiConvNormBlockWithRAMCost(P3DActiConvNormBlock):
    """
    ``P3DActiConvNormBlock`` extended with a ``ram_cost`` attribute.

    ``ram_cost`` is a relative memory estimate derived only from the ratio of
    input to output channels (see the comment in ``__init__``).
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: int,
        padding: int,
        p3dmode: int = 0,
        act_name: Union[Tuple, str] = "RELU",
        norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
    ):
        """
        Args:
            in_channel: number of input channels.
            out_channel: number of output channels.
            kernel_size: forwarded to the parent constructor.
            padding: forwarded to the parent constructor.
            p3dmode: forwarded positionally as the parent's fifth argument.
            act_name: activation specification forwarded to the parent constructor.
            norm_name: normalization specification forwarded to the parent constructor.
        """
        super().__init__(in_channel, out_channel, kernel_size, padding, p3dmode, act_name, norm_name)
        # 1 in_channel (activation) + 1 in_channel (convolution) +
        # 1 out_channel (convolution) + 1 out_channel (normalization)
        self.ram_cost = 2 + 2 * in_channel / out_channel


class _FactorizedIncreaseBlockWithRAMCost(FactorizedIncreaseBlock):
    """
    ``FactorizedIncreaseBlock`` extended with a ``ram_cost`` attribute.

    ``ram_cost`` is a relative memory estimate derived only from the ratio of
    input to output channels (see the comment in ``__init__``).
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        spatial_dims: int = 3,
        act_name: Union[Tuple, str] = "RELU",
        norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
    ):
        """
        Args:
            in_channel: number of input channels.
            out_channel: number of output channels.
            spatial_dims: forwarded to the parent constructor.
            act_name: activation specification forwarded to the parent constructor.
            norm_name: normalization specification forwarded to the parent constructor.
        """
        super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name)
        # s0 is upsampled 2x from s1, representing feature sizes at two resolutions.
        # 2 * in_channel * s0 (upsample + activation) + 2 * out_channel * s0 (conv + normalization)
        # s0 = output_size / out_channel
        self.ram_cost = 2 * in_channel / out_channel + 2


class _FactorizedReduceBlockWithRAMCost(FactorizedReduceBlock):
def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3):
super().__init__(in_channel, out_channel, spatial_dims)
def __init__(
self,
in_channel: int,
out_channel: int,
spatial_dims: int = 3,
act_name: Union[Tuple, str] = "RELU",
norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
):
super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name)
# s0 is upsampled 2x from s1, representing feature sizes at two resolutions.
# in_channel * s0 (activation) + 3 * out_channel * s1 (convolution, concatenation, normalization)
# s0 = s1 * 2^(spatial_dims) = output_size / out_channel * 2^(spatial_dims)
Expand Down Expand Up @@ -182,14 +215,6 @@ class Cell(CellInterface):
# \
# - Downsample

# Define connection operation set, parameterized by the number of channels
ConnOPS = {
"up": _FactorizedIncreaseBlockWithRAMCost,
"down": _FactorizedReduceBlockWithRAMCost,
"identity": _IdentityWithRAMCost,
"align_channels": _ActiConvNormBlockWithRAMCost,
}

# Define 2D operation set, parameterized by the number of channels
OPS2D = {
"skip_connect": lambda _c: _IdentityWithRAMCost(),
Expand All @@ -205,18 +230,69 @@ class Cell(CellInterface):
"conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=2),
}

def __init__(self, c_prev: int, c: int, rate: int, arch_code_c=None, spatial_dims: int = 3):
# Define connection operation set, parameterized by the number of channels
ConnOPS = {
"up": _FactorizedIncreaseBlockWithRAMCost,
"down": _FactorizedReduceBlockWithRAMCost,
"identity": _IdentityWithRAMCost,
"align_channels": _ActiConvNormBlockWithRAMCost,
}

def __init__(
self,
c_prev: int,
c: int,
rate: int,
arch_code_c=None,
spatial_dims: int = 3,
act_name: Union[Tuple, str] = "RELU",
norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
):
super().__init__()
self._spatial_dims = spatial_dims
self._act_name = act_name
self._norm_name = norm_name

if rate == -1: # downsample
self.preprocess = self.ConnOPS["down"](c_prev, c, spatial_dims=self._spatial_dims)
self.preprocess = self.ConnOPS["down"](
c_prev, c, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name
)
elif rate == 1: # upsample
self.preprocess = self.ConnOPS["up"](c_prev, c, spatial_dims=self._spatial_dims)
self.preprocess = self.ConnOPS["up"](
c_prev, c, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name
)
else:
if c_prev == c:
self.preprocess = self.ConnOPS["identity"]()
else:
self.preprocess = self.ConnOPS["align_channels"](c_prev, c, 1, 0, spatial_dims=self._spatial_dims)
self.preprocess = self.ConnOPS["align_channels"](
c_prev, c, 1, 0, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name
)

# Define 2D operation set, parameterized by the number of channels
self.OPS2D = {
"skip_connect": lambda _c: _IdentityWithRAMCost(),
"conv_3x3": lambda c: _ActiConvNormBlockWithRAMCost(
c, c, 3, padding=1, spatial_dims=2, act_name=self._act_name, norm_name=self._norm_name
),
}

# Define 3D operation set, parameterized by the number of channels
self.OPS3D = {
"skip_connect": lambda _c: _IdentityWithRAMCost(),
"conv_3x3x3": lambda c: _ActiConvNormBlockWithRAMCost(
c, c, 3, padding=1, spatial_dims=3, act_name=self._act_name, norm_name=self._norm_name
),
"conv_3x3x1": lambda c: _P3DActiConvNormBlockWithRAMCost(
c, c, 3, padding=1, p3dmode=0, act_name=self._act_name, norm_name=self._norm_name
),
"conv_3x1x3": lambda c: _P3DActiConvNormBlockWithRAMCost(
c, c, 3, padding=1, p3dmode=1, act_name=self._act_name, norm_name=self._norm_name
),
"conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost(
c, c, 3, padding=1, p3dmode=2, act_name=self._act_name, norm_name=self._norm_name
),
}

self.OPS = {}
if self._spatial_dims == 2:
Expand Down Expand Up @@ -283,7 +359,7 @@ def __init__(
in_channels: int,
num_classes: int,
act_name: Union[Tuple, str] = "RELU",
norm_name: Union[Tuple, str] = "INSTANCE",
norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
spatial_dims: int = 3,
use_downsample: bool = True,
node_a=None,
Expand Down Expand Up @@ -398,7 +474,9 @@ def __init__(
bias=False,
dilation=1,
),
get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx - 1]),
get_norm_layer(
name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[max(res_idx - 1, 0)]
),
nn.Upsample(scale_factor=2 ** (res_idx != 0), mode=mode, align_corners=True),
)

Expand Down Expand Up @@ -484,6 +562,8 @@ def __init__(
num_blocks: int = 6,
num_depths: int = 3,
spatial_dims: int = 3,
act_name: Union[Tuple, str] = "RELU",
norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
use_downsample: bool = True,
device: str = "cpu",
):
Expand All @@ -494,6 +574,8 @@ def __init__(
self.num_blocks = num_blocks
self.num_depths = num_depths
self._spatial_dims = spatial_dims
self._act_name = act_name
self._norm_name = norm_name
self.use_downsample = use_downsample
self.device = device
self.num_cell_ops = 0
Expand Down Expand Up @@ -535,6 +617,8 @@ def __init__(
self.arch_code2ops[res_idx],
self.arch_code_c[blk_idx, res_idx],
self._spatial_dims,
self._act_name,
self._norm_name,
)

def forward(self, x):
Expand All @@ -555,6 +639,8 @@ def __init__(
num_blocks: int = 6,
num_depths: int = 3,
spatial_dims: int = 3,
act_name: Union[Tuple, str] = "RELU",
norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
use_downsample: bool = True,
device: str = "cpu",
):
Expand All @@ -571,6 +657,8 @@ def __init__(
num_blocks=num_blocks,
num_depths=num_depths,
spatial_dims=spatial_dims,
act_name=act_name,
norm_name=norm_name,
use_downsample=use_downsample,
device=device,
)
Expand All @@ -591,7 +679,7 @@ def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
x=inputs[self.arch_code2in[res_idx]], weight=torch.ones_like(self.arch_code_c[blk_idx, res_idx])
)
outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out
inputs = outputs
inputs = outputs

return inputs

Expand Down Expand Up @@ -650,6 +738,8 @@ def __init__(
num_blocks: int = 6,
num_depths: int = 3,
spatial_dims: int = 3,
act_name: Union[Tuple, str] = "RELU",
norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
use_downsample: bool = True,
device: str = "cpu",
):
Expand All @@ -663,6 +753,8 @@ def __init__(
num_blocks=num_blocks,
num_depths=num_depths,
spatial_dims=spatial_dims,
act_name=act_name,
norm_name=norm_name,
use_downsample=use_downsample,
device=device,
)
Expand Down
40 changes: 35 additions & 5 deletions tests/test_dints_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,28 @@
(2, 4, 64, 32, 16),
],
[
{"c_prev": 8, "c": 8, "rate": 0, "arch_code_c": None},
{"c_prev": 8, "c": 8, "rate": 0, "arch_code_c": None, "act_name": "SELU", "norm_name": "BATCH"},
torch.tensor([1, 1, 1, 1, 1]),
torch.tensor([0, 0, 0, 1, 0]),
(2, 8, 32, 16, 8),
(2, 8, 32, 16, 8),
],
[
{"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": None},
{
"c_prev": 8,
"c": 8,
"rate": -1,
"arch_code_c": None,
"act_name": "PRELU",
"norm_name": ("BATCH", {"affine": False}),
},
torch.tensor([1, 1, 1, 1, 1]),
torch.tensor([1, 1, 1, 1, 1]),
(2, 8, 32, 16, 8),
(2, 8, 16, 8, 4),
],
[
{"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1]},
{"c_prev": 8, "c": 8, "rate": -1, "arch_code_c": [1, 0, 0, 0, 1], "act_name": "RELU", "norm_name": "INSTANCE"},
torch.tensor([1, 0, 0, 0, 1]),
torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]),
(2, 8, 32, 16, 8),
Expand All @@ -56,12 +63,35 @@

# 2D test cases for the DiNTS cell.
# NOTE(review): each case looks like
#   [constructor kwargs, tensor, tensor, input shape, expected output shape]
# -- confirm the meaning of the two tensors against the test body.
TEST_CASES_2D = [
    # Downsampling cell (rate=-1) that also changes the channel count (8 -> 7),
    # using a non-default activation and a non-affine batch norm.
    [
        {
            "c_prev": 8,
            "c": 7,
            "rate": -1,
            "arch_code_c": [1, 0, 0, 0, 1],
            "spatial_dims": 2,
            "act_name": "PRELU",
            "norm_name": ("BATCH", {"affine": False}),
        },
        torch.tensor([1, 0]),
        torch.tensor([0.2, 0.2]),
        (2, 8, 16, 8),
        (2, 7, 8, 4),
    ],
    # Downsampling cell (rate=-1) with unchanged channel count and no explicit
    # arch code, using a plain string norm name.
    [
        {
            "c_prev": 8,
            "c": 8,
            "rate": -1,
            "arch_code_c": None,
            "spatial_dims": 2,
            "act_name": "SELU",
            "norm_name": "INSTANCE",
        },
        torch.tensor([1, 0]),
        torch.tensor([0.2, 0.2]),
        (2, 8, 16, 8),
        (2, 8, 8, 4),
    ],
]


Expand Down
4 changes: 2 additions & 2 deletions tests/test_dints_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"in_channels": 1,
"num_classes": 3,
"act_name": "RELU",
"norm_name": "INSTANCE",
"norm_name": ("INSTANCE", {"affine": True}),
"use_downsample": False,
"spatial_dims": 3,
},
Expand Down Expand Up @@ -101,7 +101,7 @@
"in_channels": 1,
"num_classes": 4,
"act_name": "RELU",
"norm_name": "INSTANCE",
"norm_name": ("INSTANCE", {"affine": True}),
"use_downsample": False,
"spatial_dims": 2,
},
Expand Down