diff --git a/monai/networks/nets/vnet.py b/monai/networks/nets/vnet.py index 697547093a..d89eb8ae03 100644 --- a/monai/networks/nets/vnet.py +++ b/monai/networks/nets/vnet.py @@ -16,6 +16,7 @@ from monai.networks.blocks.convolutions import Convolution from monai.networks.layers.factories import Act, Conv, Dropout, Norm, split_args +from monai.utils import deprecated_arg __all__ = ["VNet"] @@ -133,7 +134,7 @@ def __init__( out_channels: int, nconvs: int, act: tuple[str, dict] | str, - dropout_prob: float | None = None, + dropout_prob: tuple[float | None, float] = (None, 0.5), dropout_dim: int = 3, ): super().__init__() @@ -144,8 +145,8 @@ def __init__( self.up_conv = conv_trans_type(in_channels, out_channels // 2, kernel_size=2, stride=2) self.bn1 = norm_type(out_channels // 2) - self.dropout = dropout_type(dropout_prob) if dropout_prob is not None else None - self.dropout2 = dropout_type(0.5) + self.dropout = dropout_type(dropout_prob[0]) if dropout_prob[0] is not None else None + self.dropout2 = dropout_type(dropout_prob[1]) self.act_function1 = get_acti_layer(act, out_channels // 2) self.act_function2 = get_acti_layer(act, out_channels) self.ops = _make_nconv(spatial_dims, out_channels, nconvs, act) @@ -206,8 +207,9 @@ class VNet(nn.Module): The value should meet the condition that ``16 % in_channels == 0``. out_channels: number of output channels for the network. Defaults to 1. act: activation type in the network. Defaults to ``("elu", {"inplace": True})``. - dropout_prob: dropout ratio. Defaults to 0.5. - dropout_dim: determine the dimensions of dropout. Defaults to 3. + dropout_prob_down: dropout ratio for DownTransition blocks. Defaults to 0.5. + dropout_prob_up: dropout ratio for UpTransition blocks. Defaults to (0.5, 0.5). + dropout_dim: determine the dimensions of dropout. Defaults to (0.5, 0.5). - ``dropout_dim = 1``, randomly zeroes some of the elements for each channel. - ``dropout_dim = 2``, Randomly zeroes out entire channels (a channel is a 2D feature map). @@ -216,15 +218,29 @@ class VNet(nn.Module): According to `Performance Tuning Guide `_, if a conv layer is directly followed by a batch norm layer, bias should be False. + .. deprecated:: 1.2 + ``dropout_prob`` is deprecated in favor of ``dropout_prob_down`` and ``dropout_prob_up``. + """ + @deprecated_arg( + name="dropout_prob", + since="1.2", + new_name="dropout_prob_down", + msg_suffix="please use `dropout_prob_down` instead.", + ) + @deprecated_arg( + name="dropout_prob", since="1.2", new_name="dropout_prob_up", msg_suffix="please use `dropout_prob_up` instead." + ) def __init__( self, spatial_dims: int = 3, in_channels: int = 1, out_channels: int = 1, act: tuple[str, dict] | str = ("elu", {"inplace": True}), - dropout_prob: float = 0.5, + dropout_prob: float | None = 0.5, # deprecated + dropout_prob_down: float | None = 0.5, + dropout_prob_up: tuple[float | None, float] = (0.5, 0.5), dropout_dim: int = 3, bias: bool = False, ): @@ -236,10 +252,10 @@ def __init__( self.in_tr = InputTransition(spatial_dims, in_channels, 16, act, bias=bias) self.down_tr32 = DownTransition(spatial_dims, 16, 1, act, bias=bias) self.down_tr64 = DownTransition(spatial_dims, 32, 2, act, bias=bias) - self.down_tr128 = DownTransition(spatial_dims, 64, 3, act, dropout_prob=dropout_prob, bias=bias) - self.down_tr256 = DownTransition(spatial_dims, 128, 2, act, dropout_prob=dropout_prob, bias=bias) - self.up_tr256 = UpTransition(spatial_dims, 256, 256, 2, act, dropout_prob=dropout_prob) - self.up_tr128 = UpTransition(spatial_dims, 256, 128, 2, act, dropout_prob=dropout_prob) + self.down_tr128 = DownTransition(spatial_dims, 64, 3, act, dropout_prob=dropout_prob_down, bias=bias) + self.down_tr256 = DownTransition(spatial_dims, 128, 2, act, dropout_prob=dropout_prob_down, bias=bias) + self.up_tr256 = UpTransition(spatial_dims, 256, 256, 2, act, dropout_prob=dropout_prob_up) + self.up_tr128 = UpTransition(spatial_dims, 256, 128, 2, act, dropout_prob=dropout_prob_up) self.up_tr64 = UpTransition(spatial_dims, 128, 64, 1, act) self.up_tr32 = UpTransition(spatial_dims, 64, 32, 1, act) self.out_tr = OutputTransition(spatial_dims, 32, out_channels, act, bias=bias)