3293 Remove extra deep supervision modules of DynUNet #3427

Merged: 8 commits, Dec 7, 2021
98 changes: 59 additions & 39 deletions monai/networks/nets/dynunet.py
@@ -31,13 +31,13 @@ class DynUNetSkipLayer(nn.Module):
     forward passes of the network.
     """

-    heads: List[torch.Tensor]
+    heads: Optional[List[torch.Tensor]]

-    def __init__(self, index, heads, downsample, upsample, super_head, next_layer):
+    def __init__(self, index, downsample, upsample, next_layer, heads=None, super_head=None):
         super().__init__()
         self.downsample = downsample
-        self.upsample = upsample
         self.next_layer = next_layer
+        self.upsample = upsample
         self.super_head = super_head
         self.heads = heads
         self.index = index
@@ -46,8 +46,8 @@ def forward(self, x):
         downout = self.downsample(x)
         nextout = self.next_layer(downout)
         upout = self.upsample(nextout, downout)
-
-        self.heads[self.index] = self.super_head(upout)
+        if self.super_head is not None and self.heads is not None and self.index > 0:
+            self.heads[self.index - 1] = self.super_head(upout)

         return upout

@@ -79,6 +79,8 @@ class DynUNet(nn.Module):
     For example, if `strides=((1, 2, 4), 2, 1, 1)`, the minimal spatial size of the input is `(8, 16, 32)`, and
     the spatial size of the output is `(8, 8, 8)`.

+    For backwards compatibility with old weights, please set `strict=False` when calling `load_state_dict`.
+
     Usage example with medical segmentation decathlon dataset is available at:
     https://github.com/Project-MONAI/tutorials/tree/master/modules/dynunet_pipeline.
@@ -100,18 +102,16 @@ class DynUNet(nn.Module):
         norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``.
         act_name: activation layer type and arguments. Defaults to ``leakyrelu``.
         deep_supervision: whether to add deep supervision head before output. Defaults to ``False``.
-            If ``True``, in training mode, the forward function will output not only the last feature
-            map, but also the previous feature maps that come from the intermediate up sample layers.
+            If ``True``, in training mode, the forward function will output not only the final feature map
+            (from `output_block`), but also the feature maps that come from the intermediate up sample layers.
             In order to unify the return type (the restriction of TorchScript), all intermediate
-            feature maps are interpolated into the same size as the last feature map and stacked together
+            feature maps are interpolated into the same size as the final feature map and stacked together
             (with a new dimension in the first axis)into one single tensor.
-            For instance, if there are three feature maps with shapes: (1, 2, 32, 24), (1, 2, 16, 12) and
-            (1, 2, 8, 6). The last two will be interpolated into (1, 2, 32, 24), and the stacked tensor
-            will has the shape (1, 3, 2, 8, 6).
+            For instance, if there are two intermediate feature maps with shapes: (1, 2, 16, 12) and
+            (1, 2, 8, 6), and the final feature map has the shape (1, 2, 32, 24), then all intermediate feature maps
+            will be interpolated into (1, 2, 32, 24), and the stacked tensor will has the shape (1, 3, 2, 32, 24).
             When calculating the loss, you can use torch.unbind to get all feature maps can compute the loss
             one by one with the ground truth, then do a weighted average for all losses to achieve the final loss.
-            (To be added: a corresponding tutorial link)
-
         deep_supr_num: number of feature maps that will output during deep supervision head. The
             value should be larger than 0 and less than the number of up sample layers.
             Defaults to 1.
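A minimal sketch of the loss computation the docstring describes, assuming `deep_supervision=True`, `deep_supr_num=2`, and illustrative loss weights (not part of this diff):

import torch
from monai.losses import DiceLoss
from monai.networks.nets import DynUNet

net = DynUNet(
    spatial_dims=2, in_channels=1, out_channels=2,
    kernel_size=[3, 3, 3, 3], strides=[1, 2, 2, 2], upsample_kernel_size=[2, 2, 2],
    deep_supervision=True, deep_supr_num=2,
).train()
x = torch.rand(1, 1, 64, 64)
y = torch.randint(0, 2, (1, 2, 64, 64)).float()

preds = net(x)  # stacked tensor of shape (1, 3, 2, 64, 64): final map + 2 supervision maps
maps = torch.unbind(preds, dim=1)  # split along the new first axis
loss_fn = DiceLoss(sigmoid=True)
weights = [1.0, 0.5, 0.25]  # illustrative: weight the final map highest
loss = sum(w * loss_fn(m, y) for w, m in zip(weights, maps)) / sum(weights)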
@@ -160,16 +160,17 @@ def __init__(
         self.upsamples = self.get_upsamples()
         self.output_block = self.get_output_block(0)
         self.deep_supervision = deep_supervision
-        self.deep_supervision_heads = self.get_deep_supervision_heads()
         self.deep_supr_num = deep_supr_num
+        # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on
+        self.heads: List[torch.Tensor] = [torch.rand(1)] * self.deep_supr_num
+        if self.deep_supervision:
+            self.deep_supervision_heads = self.get_deep_supervision_heads()
+            self.check_deep_supr_num()

         self.apply(self.initialize_weights)
         self.check_kernel_stride()
-        self.check_deep_supr_num()
-
-        # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on
-        self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1)

-        def create_skips(index, downsamples, upsamples, superheads, bottleneck):
+        def create_skips(index, downsamples, upsamples, bottleneck, superheads=None):
             """
             Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is
             done recursively from the top down since a recursive nn.Module subclass is being used to be compatible
@@ -180,30 +181,50 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck):

             if len(downsamples) != len(upsamples):
                 raise ValueError(f"{len(downsamples)} != {len(upsamples)}")
-            if (len(downsamples) - len(superheads)) not in (1, 0):
-                raise ValueError(f"{len(downsamples)}-(0,1) != {len(superheads)}")

             if len(downsamples) == 0:  # bottom of the network, pass the bottleneck block
                 return bottleneck

+            if superheads is None:
+                next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck)
+                return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer)
+
+            super_head_flag = False
             if index == 0:  # don't associate a supervision head with self.input_block
-                current_head, rest_heads = nn.Identity(), superheads
-            elif not self.deep_supervision:  # bypass supervision heads by passing nn.Identity in place of a real one
-                current_head, rest_heads = nn.Identity(), superheads[1:]
+                rest_heads = superheads
             else:
-                current_head, rest_heads = superheads[0], superheads[1:]
+                if len(superheads) > 0:
+                    super_head_flag = True
+                    rest_heads = superheads[1:]
+                else:
+                    rest_heads = nn.ModuleList()

             # create the next layer down, this will stop at the bottleneck layer
-            next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck)
-
-            return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer)
-
-        self.skip_layers = create_skips(
-            0,
-            [self.input_block] + list(self.downsamples),
-            self.upsamples[::-1],
-            self.deep_supervision_heads,
-            self.bottleneck,
-        )
+            next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck, superheads=rest_heads)
+            if super_head_flag:
+                return DynUNetSkipLayer(
+                    index,
+                    downsample=downsamples[0],
+                    upsample=upsamples[0],
+                    next_layer=next_layer,
+                    heads=self.heads,
+                    super_head=superheads[0],
+                )
+
+            return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer)
+
+        if not self.deep_supervision:
+            self.skip_layers = create_skips(
+                0, [self.input_block] + list(self.downsamples), self.upsamples[::-1], self.bottleneck
+            )
+        else:
+            self.skip_layers = create_skips(
+                0,
+                [self.input_block] + list(self.downsamples),
+                self.upsamples[::-1],
+                self.bottleneck,
+                superheads=self.deep_supervision_heads,
+            )

     def check_kernel_stride(self):
         kernels, strides = self.kernel_size, self.strides
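A quick check of the intent of this hunk (a sketch under the same illustrative constructor arguments as above, not part of the diff): with `deep_supervision=False` the skip pathway is built without any supervision heads, so no unused modules are registered.

from monai.networks.nets import DynUNet

net = DynUNet(
    spatial_dims=2, in_channels=1, out_channels=2,
    kernel_size=[3, 3, 3, 3], strides=[1, 2, 2, 2], upsample_kernel_size=[2, 2, 2],
    deep_supervision=False,
)
print(hasattr(net, "deep_supervision_heads"))  # False: heads are only created when requested
print(any("super_head" in name for name, _ in net.named_modules()))  # False: no extra modules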
@@ -242,8 +263,7 @@ def forward(self, x):
         out = self.output_block(out)
         if self.training and self.deep_supervision:
             out_all = [out]
-            feature_maps = self.heads[1 : self.deep_supr_num + 1]
-            for feature_map in feature_maps:
+            for feature_map in self.heads:
                 out_all.append(interpolate(feature_map, out.shape[2:]))
             return torch.stack(out_all, dim=1)
         return out
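After this change `self.heads` holds exactly `deep_supr_num` entries, so `forward` can iterate it directly instead of slicing. A sketch of the resulting output shapes (illustrative arguments, not part of the diff):

import torch
from monai.networks.nets import DynUNet

net = DynUNet(
    spatial_dims=2, in_channels=1, out_channels=2,
    kernel_size=[3, 3, 3, 3], strides=[1, 2, 2, 2], upsample_kernel_size=[2, 2, 2],
    deep_supervision=True, deep_supr_num=2,
)
x = torch.rand(1, 1, 64, 64)
net.train()
print(net(x).shape)  # torch.Size([1, 3, 2, 64, 64]): final map plus deep_supr_num heads, stacked
net.eval()
print(net(x).shape)  # torch.Size([1, 2, 64, 64]): only the final feature map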
@@ -334,7 +354,7 @@ def get_module_list(
         return nn.ModuleList(layers)

     def get_deep_supervision_heads(self):
-        return nn.ModuleList([self.get_output_block(i + 1) for i in range(len(self.upsamples) - 1)])
+        return nn.ModuleList([self.get_output_block(i + 1) for i in range(self.deep_supr_num)])

     @staticmethod
     def initialize_weights(module):
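`get_deep_supervision_heads` now builds only `deep_supr_num` output blocks instead of one per upsample layer. A small sketch of the effect (illustrative arguments, not part of the diff):

from monai.networks.nets import DynUNet

net = DynUNet(
    spatial_dims=2, in_channels=1, out_channels=2,
    kernel_size=[3, 3, 3, 3, 3], strides=[1, 2, 2, 2, 2], upsample_kernel_size=[2, 2, 2, 2],
    deep_supervision=True, deep_supr_num=1,
)
# four upsample layers, but only deep_supr_num supervision heads are registered
print(len(net.upsamples), len(net.deep_supervision_heads))  # 4 1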
6 changes: 3 additions & 3 deletions tests/test_network_consistency.py
@@ -22,7 +22,7 @@
 import monai.networks.nets as nets
 from monai.utils import set_determinism

-extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA", None)
+extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA")

 TESTS = []
 if extra_test_data_dir is not None:
@@ -60,8 +60,8 @@ def test_network_consistency(self, net_name, data_path, json_path):
         json_file.close()

         # Create model
-        model = nets.__dict__[net_name](**model_params)
-        model.load_state_dict(loaded_data["model"])
+        model = getattr(nets, net_name)(**model_params)
+        model.load_state_dict(loaded_data["model"], strict=False)
         model.eval()

         in_data = loaded_data["in_data"]
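`getattr(nets, net_name)` is the idiomatic equivalent of the `nets.__dict__[net_name]` lookup it replaces; for example (`net_name` value illustrative):

import monai.networks.nets as nets

net_name = "DynUNet"
assert getattr(nets, net_name) is nets.__dict__[net_name]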