3293 Remove extra deep supervision modules of DynUNet (#3427)
* enhance dynunet

Signed-off-by: Yiheng Wang <vennw@nvidia.com>

* fix black issue

Signed-off-by: Yiheng Wang <vennw@nvidia.com>

* use strict=False

Signed-off-by: Yiheng Wang <vennw@nvidia.com>

* fix black 21.12 error

Signed-off-by: Yiheng Wang <vennw@nvidia.com>

* enhance code and update docstring

Signed-off-by: Yiheng Wang <vennw@nvidia.com>
yiheng-wang-nv authored Dec 7, 2021
1 parent a17813b commit 98c1c43
Showing 2 changed files with 62 additions and 42 deletions.
98 changes: 59 additions & 39 deletions monai/networks/nets/dynunet.py
@@ -31,13 +31,13 @@ class DynUNetSkipLayer(nn.Module):
forward passes of the network.
"""

heads: List[torch.Tensor]
heads: Optional[List[torch.Tensor]]

def __init__(self, index, heads, downsample, upsample, super_head, next_layer):
def __init__(self, index, downsample, upsample, next_layer, heads=None, super_head=None):
super().__init__()
self.downsample = downsample
self.upsample = upsample
self.next_layer = next_layer
self.upsample = upsample
self.super_head = super_head
self.heads = heads
self.index = index
@@ -46,8 +46,8 @@ def forward(self, x):
downout = self.downsample(x)
nextout = self.next_layer(downout)
upout = self.upsample(nextout, downout)

self.heads[self.index] = self.super_head(upout)
if self.super_head is not None and self.heads is not None and self.index > 0:
self.heads[self.index - 1] = self.super_head(upout)

return upout

@@ -79,6 +79,8 @@ class DynUNet(nn.Module):
For example, if `strides=((1, 2, 4), 2, 1, 1)`, the minimal spatial size of the input is `(8, 16, 32)`, and
the spatial size of the output is `(8, 8, 8)`.
For backwards compatibility with old weights, please set `strict=False` when calling `load_state_dict`.
A usage example with the medical segmentation decathlon dataset is available at:
https://github.com/Project-MONAI/tutorials/tree/master/modules/dynunet_pipeline.
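
As a rough sketch (not part of this commit), loading a checkpoint trained before this change could look as follows; the constructor arguments and the checkpoint path are hypothetical and must match the ones used for training:

import torch
from monai.networks.nets import DynUNet

net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    kernel_size=[3, 3, 3, 3],
    strides=[1, 2, 2, 2],
    upsample_kernel_size=[2, 2, 2],
)
state_dict = torch.load("old_dynunet.pt")  # hypothetical checkpoint path
# strict=False ignores the extra deep supervision modules that old checkpoints may still contain
net.load_state_dict(state_dict, strict=False)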
@@ -100,18 +102,16 @@ class DynUNet(nn.Module):
norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``.
act_name: activation layer type and arguments. Defaults to ``leakyrelu``.
deep_supervision: whether to add deep supervision head before output. Defaults to ``False``.
If ``True``, in training mode, the forward function will output not only the last feature
map, but also the previous feature maps that come from the intermediate up sample layers.
If ``True``, in training mode, the forward function will output not only the final feature map
(from `output_block`), but also the feature maps that come from the intermediate up sample layers.
In order to unify the return type (the restriction of TorchScript), all intermediate
feature maps are interpolated into the same size as the last feature map and stacked together
feature maps are interpolated into the same size as the final feature map and stacked together
(with a new dimension in the first axis) into one single tensor.
For instance, if there are three feature maps with shapes: (1, 2, 32, 24), (1, 2, 16, 12) and
(1, 2, 8, 6). The last two will be interpolated into (1, 2, 32, 24), and the stacked tensor
will has the shape (1, 3, 2, 8, 6).
For instance, if there are two intermediate feature maps with shapes: (1, 2, 16, 12) and
(1, 2, 8, 6), and the final feature map has the shape (1, 2, 32, 24), then all intermediate feature maps
will be interpolated into (1, 2, 32, 24), and the stacked tensor will have the shape (1, 3, 2, 32, 24).
When calculating the loss, you can use torch.unbind to get all feature maps, compute the loss
one by one with the ground truth, then do a weighted average of all losses to achieve the final loss.
(To be added: a corresponding tutorial link)
deep_supr_num: number of feature maps that will be output from the deep supervision heads. The
value should be larger than 0 and less than the number of up sample layers.
Defaults to 1.
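
For illustration, a deep supervision loss could be computed from the stacked training output roughly as below; this is a sketch only, assuming a DiceCELoss criterion and halving weights, neither of which is mandated by this commit:

import torch
from monai.losses import DiceCELoss

criterion = DiceCELoss(to_onehot_y=True, softmax=True)

def deep_supervision_loss(stacked_output, target):
    # stacked_output: (N, 1 + deep_supr_num, C, ...) as returned by DynUNet in training mode
    outputs = torch.unbind(stacked_output, dim=1)  # recover the individual feature maps
    # weights halve for each deeper head and are normalized to sum to 1 (a common choice)
    weights = torch.tensor([0.5 ** i for i in range(len(outputs))])
    weights = weights / weights.sum()
    return sum(w * criterion(out, target) for w, out in zip(weights, outputs))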
@@ -160,16 +160,17 @@ def __init__(
self.upsamples = self.get_upsamples()
self.output_block = self.get_output_block(0)
self.deep_supervision = deep_supervision
self.deep_supervision_heads = self.get_deep_supervision_heads()
self.deep_supr_num = deep_supr_num
# initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on
self.heads: List[torch.Tensor] = [torch.rand(1)] * self.deep_supr_num
if self.deep_supervision:
self.deep_supervision_heads = self.get_deep_supervision_heads()
self.check_deep_supr_num()

self.apply(self.initialize_weights)
self.check_kernel_stride()
self.check_deep_supr_num()

# initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on
self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1)

def create_skips(index, downsamples, upsamples, superheads, bottleneck):
def create_skips(index, downsamples, upsamples, bottleneck, superheads=None):
"""
Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is
done recursively from the top down since a recursive nn.Module subclass is being used to be compatible
@@ -180,30 +181,50 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck):

if len(downsamples) != len(upsamples):
raise ValueError(f"{len(downsamples)} != {len(upsamples)}")
if (len(downsamples) - len(superheads)) not in (1, 0):
raise ValueError(f"{len(downsamples)}-(0,1) != {len(superheads)}")

if len(downsamples) == 0: # bottom of the network, pass the bottleneck block
return bottleneck

if superheads is None:
next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck)
return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer)

super_head_flag = False
if index == 0: # don't associate a supervision head with self.input_block
current_head, rest_heads = nn.Identity(), superheads
elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one
current_head, rest_heads = nn.Identity(), superheads[1:]
rest_heads = superheads
else:
current_head, rest_heads = superheads[0], superheads[1:]
if len(superheads) > 0:
super_head_flag = True
rest_heads = superheads[1:]
else:
rest_heads = nn.ModuleList()

# create the next layer down, this will stop at the bottleneck layer
next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck)

return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer)

self.skip_layers = create_skips(
0,
[self.input_block] + list(self.downsamples),
self.upsamples[::-1],
self.deep_supervision_heads,
self.bottleneck,
)
next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck, superheads=rest_heads)
if super_head_flag:
return DynUNetSkipLayer(
index,
downsample=downsamples[0],
upsample=upsamples[0],
next_layer=next_layer,
heads=self.heads,
super_head=superheads[0],
)

return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer)

if not self.deep_supervision:
self.skip_layers = create_skips(
0, [self.input_block] + list(self.downsamples), self.upsamples[::-1], self.bottleneck
)
else:
self.skip_layers = create_skips(
0,
[self.input_block] + list(self.downsamples),
self.upsamples[::-1],
self.bottleneck,
superheads=self.deep_supervision_heads,
)

def check_kernel_stride(self):
kernels, strides = self.kernel_size, self.strides
@@ -242,8 +263,7 @@ def forward(self, x):
out = self.output_block(out)
if self.training and self.deep_supervision:
out_all = [out]
feature_maps = self.heads[1 : self.deep_supr_num + 1]
for feature_map in feature_maps:
for feature_map in self.heads:
out_all.append(interpolate(feature_map, out.shape[2:]))
return torch.stack(out_all, dim=1)
return out
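
As a quick illustration of the resulting output shapes (a sketch with made-up sizes, not taken from this commit):

import torch
from monai.networks.nets import DynUNet

net = DynUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=2,
    kernel_size=[3, 3, 3, 3],
    strides=[1, 2, 2, 2],
    upsample_kernel_size=[2, 2, 2],
    deep_supervision=True,
    deep_supr_num=2,
)
x = torch.rand(1, 1, 64, 64)

net.train()
print(net(x).shape)  # torch.Size([1, 3, 2, 64, 64]): final map stacked with 2 interpolated supervision maps

net.eval()
print(net(x).shape)  # torch.Size([1, 2, 64, 64]): only the final feature map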
@@ -334,7 +354,7 @@ def get_module_list(
return nn.ModuleList(layers)

def get_deep_supervision_heads(self):
return nn.ModuleList([self.get_output_block(i + 1) for i in range(len(self.upsamples) - 1)])
return nn.ModuleList([self.get_output_block(i + 1) for i in range(self.deep_supr_num)])

@staticmethod
def initialize_weights(module):
6 changes: 3 additions & 3 deletions tests/test_network_consistency.py
@@ -22,7 +22,7 @@
import monai.networks.nets as nets
from monai.utils import set_determinism

extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA", None)
extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA")

TESTS = []
if extra_test_data_dir is not None:
@@ -60,8 +60,8 @@ def test_network_consistency(self, net_name, data_path, json_path):
json_file.close()

# Create model
model = nets.__dict__[net_name](**model_params)
model.load_state_dict(loaded_data["model"])
model = getattr(nets, net_name)(**model_params)
model.load_state_dict(loaded_data["model"], strict=False)
model.eval()

in_data = loaded_data["in_data"]
