diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 4cd3046261..08938bb3bd 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -31,13 +31,13 @@ class DynUNetSkipLayer(nn.Module): forward passes of the network. """ - heads: List[torch.Tensor] + heads: Optional[List[torch.Tensor]] - def __init__(self, index, heads, downsample, upsample, super_head, next_layer): + def __init__(self, index, downsample, upsample, next_layer, heads=None, super_head=None): super().__init__() self.downsample = downsample - self.upsample = upsample self.next_layer = next_layer + self.upsample = upsample self.super_head = super_head self.heads = heads self.index = index @@ -46,8 +46,8 @@ def forward(self, x): downout = self.downsample(x) nextout = self.next_layer(downout) upout = self.upsample(nextout, downout) - - self.heads[self.index] = self.super_head(upout) + if self.super_head is not None and self.heads is not None and self.index > 0: + self.heads[self.index - 1] = self.super_head(upout) return upout @@ -79,6 +79,8 @@ class DynUNet(nn.Module): For example, if `strides=((1, 2, 4), 2, 1, 1)`, the minimal spatial size of the input is `(8, 16, 32)`, and the spatial size of the output is `(8, 8, 8)`. + For backwards compatibility with old weights, please set `strict=False` when calling `load_state_dict`. + Usage example with medical segmentation decathlon dataset is available at: https://github.com/Project-MONAI/tutorials/tree/master/modules/dynunet_pipeline. @@ -100,18 +102,16 @@ class DynUNet(nn.Module): norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``. act_name: activation layer type and arguments. Defaults to ``leakyrelu``. deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. - If ``True``, in training mode, the forward function will output not only the last feature - map, but also the previous feature maps that come from the intermediate up sample layers. + If ``True``, in training mode, the forward function will output not only the final feature map + (from `output_block`), but also the feature maps that come from the intermediate up sample layers. In order to unify the return type (the restriction of TorchScript), all intermediate - feature maps are interpolated into the same size as the last feature map and stacked together + feature maps are interpolated into the same size as the final feature map and stacked together (with a new dimension in the first axis)into one single tensor. - For instance, if there are three feature maps with shapes: (1, 2, 32, 24), (1, 2, 16, 12) and - (1, 2, 8, 6). The last two will be interpolated into (1, 2, 32, 24), and the stacked tensor - will has the shape (1, 3, 2, 8, 6). + For instance, if there are two intermediate feature maps with shapes: (1, 2, 16, 12) and + (1, 2, 8, 6), and the final feature map has the shape (1, 2, 32, 24), then all intermediate feature maps + will be interpolated into (1, 2, 32, 24), and the stacked tensor will has the shape (1, 3, 2, 32, 24). When calculating the loss, you can use torch.unbind to get all feature maps can compute the loss one by one with the ground truth, then do a weighted average for all losses to achieve the final loss. - (To be added: a corresponding tutorial link) - deep_supr_num: number of feature maps that will output during deep supervision head. The value should be larger than 0 and less than the number of up sample layers. Defaults to 1. @@ -160,16 +160,17 @@ def __init__( self.upsamples = self.get_upsamples() self.output_block = self.get_output_block(0) self.deep_supervision = deep_supervision - self.deep_supervision_heads = self.get_deep_supervision_heads() self.deep_supr_num = deep_supr_num + # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on + self.heads: List[torch.Tensor] = [torch.rand(1)] * self.deep_supr_num + if self.deep_supervision: + self.deep_supervision_heads = self.get_deep_supervision_heads() + self.check_deep_supr_num() + self.apply(self.initialize_weights) self.check_kernel_stride() - self.check_deep_supr_num() - # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on - self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1) - - def create_skips(index, downsamples, upsamples, superheads, bottleneck): + def create_skips(index, downsamples, upsamples, bottleneck, superheads=None): """ Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is done recursively from the top down since a recursive nn.Module subclass is being used to be compatible @@ -180,30 +181,50 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck): if len(downsamples) != len(upsamples): raise ValueError(f"{len(downsamples)} != {len(upsamples)}") - if (len(downsamples) - len(superheads)) not in (1, 0): - raise ValueError(f"{len(downsamples)}-(0,1) != {len(superheads)}") if len(downsamples) == 0: # bottom of the network, pass the bottleneck block return bottleneck + + if superheads is None: + next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck) + return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer) + + super_head_flag = False if index == 0: # don't associate a supervision head with self.input_block - current_head, rest_heads = nn.Identity(), superheads - elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one - current_head, rest_heads = nn.Identity(), superheads[1:] + rest_heads = superheads else: - current_head, rest_heads = superheads[0], superheads[1:] + if len(superheads) > 0: + super_head_flag = True + rest_heads = superheads[1:] + else: + rest_heads = nn.ModuleList() # create the next layer down, this will stop at the bottleneck layer - next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck) - - return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer) - - self.skip_layers = create_skips( - 0, - [self.input_block] + list(self.downsamples), - self.upsamples[::-1], - self.deep_supervision_heads, - self.bottleneck, - ) + next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck, superheads=rest_heads) + if super_head_flag: + return DynUNetSkipLayer( + index, + downsample=downsamples[0], + upsample=upsamples[0], + next_layer=next_layer, + heads=self.heads, + super_head=superheads[0], + ) + + return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer) + + if not self.deep_supervision: + self.skip_layers = create_skips( + 0, [self.input_block] + list(self.downsamples), self.upsamples[::-1], self.bottleneck + ) + else: + self.skip_layers = create_skips( + 0, + [self.input_block] + list(self.downsamples), + self.upsamples[::-1], + self.bottleneck, + superheads=self.deep_supervision_heads, + ) def check_kernel_stride(self): kernels, strides = self.kernel_size, self.strides @@ -242,8 +263,7 @@ def forward(self, x): out = self.output_block(out) if self.training and self.deep_supervision: out_all = [out] - feature_maps = self.heads[1 : self.deep_supr_num + 1] - for feature_map in feature_maps: + for feature_map in self.heads: out_all.append(interpolate(feature_map, out.shape[2:])) return torch.stack(out_all, dim=1) return out @@ -334,7 +354,7 @@ def get_module_list( return nn.ModuleList(layers) def get_deep_supervision_heads(self): - return nn.ModuleList([self.get_output_block(i + 1) for i in range(len(self.upsamples) - 1)]) + return nn.ModuleList([self.get_output_block(i + 1) for i in range(self.deep_supr_num)]) @staticmethod def initialize_weights(module): diff --git a/tests/test_network_consistency.py b/tests/test_network_consistency.py index ccccd9e7f0..92cb7a0595 100644 --- a/tests/test_network_consistency.py +++ b/tests/test_network_consistency.py @@ -22,7 +22,7 @@ import monai.networks.nets as nets from monai.utils import set_determinism -extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA", None) +extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA") TESTS = [] if extra_test_data_dir is not None: @@ -60,8 +60,8 @@ def test_network_consistency(self, net_name, data_path, json_path): json_file.close() # Create model - model = nets.__dict__[net_name](**model_params) - model.load_state_dict(loaded_data["model"]) + model = getattr(nets, net_name)(**model_params) + model.load_state_dict(loaded_data["model"], strict=False) model.eval() in_data = loaded_data["in_data"]