feat(vnet): convert dropout_prob to a tuple (#6768)

Fixes #6116 

### Description

Converts the dropout probability from an optional float to an optional tuple of floats.
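
A minimal before/after sketch of the constructor call, using only names that appear in the diff below (the old single-float keyword is kept but deprecated):

```python
from monai.networks.nets import VNet

# Before this change (single float; UpTransition's second dropout was fixed at 0.5):
# net = VNet(spatial_dims=3, in_channels=1, out_channels=1, dropout_prob=0.5)

# After this change: DownTransition blocks keep an optional float, while
# UpTransition blocks take a (first_rate_or_None, second_rate) tuple.
net = VNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    dropout_prob_down=0.5,
    dropout_prob_up=(0.5, 0.5),
)
```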

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Saurav Maheshkar <sauravvmaheshkar@gmail.com>
SauravMaheshkar authored Jul 27, 2023
1 parent b02c37c commit 87d0ede
Showing 1 changed file with 26 additions and 10 deletions: monai/networks/nets/vnet.py
@@ -16,6 +16,7 @@
 
 from monai.networks.blocks.convolutions import Convolution
 from monai.networks.layers.factories import Act, Conv, Dropout, Norm, split_args
+from monai.utils import deprecated_arg
 
 __all__ = ["VNet"]
 
@@ -133,7 +134,7 @@ def __init__(
         out_channels: int,
         nconvs: int,
         act: tuple[str, dict] | str,
-        dropout_prob: float | None = None,
+        dropout_prob: tuple[float | None, float] = (None, 0.5),
         dropout_dim: int = 3,
     ):
         super().__init__()
@@ -144,8 +145,8 @@ def __init__(
 
         self.up_conv = conv_trans_type(in_channels, out_channels // 2, kernel_size=2, stride=2)
         self.bn1 = norm_type(out_channels // 2)
-        self.dropout = dropout_type(dropout_prob) if dropout_prob is not None else None
-        self.dropout2 = dropout_type(0.5)
+        self.dropout = dropout_type(dropout_prob[0]) if dropout_prob[0] is not None else None
+        self.dropout2 = dropout_type(dropout_prob[1])
         self.act_function1 = get_acti_layer(act, out_channels // 2)
         self.act_function2 = get_acti_layer(act, out_channels)
         self.ops = _make_nconv(spatial_dims, out_channels, nconvs, act)
@@ -206,8 +207,9 @@ class VNet(nn.Module):
             The value should meet the condition that ``16 % in_channels == 0``.
         out_channels: number of output channels for the network. Defaults to 1.
         act: activation type in the network. Defaults to ``("elu", {"inplace": True})``.
-        dropout_prob: dropout ratio. Defaults to 0.5.
-        dropout_dim: determine the dimensions of dropout. Defaults to 3.
+        dropout_prob_down: dropout ratio for DownTransition blocks. Defaults to 0.5.
+        dropout_prob_up: dropout ratio for UpTransition blocks. Defaults to (0.5, 0.5).
+        dropout_dim: determine the dimensions of dropout. Defaults to 3.
             - ``dropout_dim = 1``, randomly zeroes some of the elements for each channel.
             - ``dropout_dim = 2``, Randomly zeroes out entire channels (a channel is a 2D feature map).
@@ -216,15 +218,29 @@ class VNet(nn.Module):
             According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
             if a conv layer is directly followed by a batch norm layer, bias should be False.
 
+    .. deprecated:: 1.2
+        ``dropout_prob`` is deprecated in favor of ``dropout_prob_down`` and ``dropout_prob_up``.
+
     """
 
+    @deprecated_arg(
+        name="dropout_prob",
+        since="1.2",
+        new_name="dropout_prob_down",
+        msg_suffix="please use `dropout_prob_down` instead.",
+    )
+    @deprecated_arg(
+        name="dropout_prob", since="1.2", new_name="dropout_prob_up", msg_suffix="please use `dropout_prob_up` instead."
+    )
     def __init__(
         self,
         spatial_dims: int = 3,
         in_channels: int = 1,
         out_channels: int = 1,
         act: tuple[str, dict] | str = ("elu", {"inplace": True}),
-        dropout_prob: float = 0.5,
+        dropout_prob: float | None = 0.5,  # deprecated
+        dropout_prob_down: float | None = 0.5,
+        dropout_prob_up: tuple[float | None, float] = (0.5, 0.5),
         dropout_dim: int = 3,
         bias: bool = False,
     ):
@@ -236,10 +252,10 @@ def __init__(
         self.in_tr = InputTransition(spatial_dims, in_channels, 16, act, bias=bias)
         self.down_tr32 = DownTransition(spatial_dims, 16, 1, act, bias=bias)
         self.down_tr64 = DownTransition(spatial_dims, 32, 2, act, bias=bias)
-        self.down_tr128 = DownTransition(spatial_dims, 64, 3, act, dropout_prob=dropout_prob, bias=bias)
-        self.down_tr256 = DownTransition(spatial_dims, 128, 2, act, dropout_prob=dropout_prob, bias=bias)
-        self.up_tr256 = UpTransition(spatial_dims, 256, 256, 2, act, dropout_prob=dropout_prob)
-        self.up_tr128 = UpTransition(spatial_dims, 256, 128, 2, act, dropout_prob=dropout_prob)
+        self.down_tr128 = DownTransition(spatial_dims, 64, 3, act, dropout_prob=dropout_prob_down, bias=bias)
+        self.down_tr256 = DownTransition(spatial_dims, 128, 2, act, dropout_prob=dropout_prob_down, bias=bias)
+        self.up_tr256 = UpTransition(spatial_dims, 256, 256, 2, act, dropout_prob=dropout_prob_up)
+        self.up_tr128 = UpTransition(spatial_dims, 256, 128, 2, act, dropout_prob=dropout_prob_up)
         self.up_tr64 = UpTransition(spatial_dims, 128, 64, 1, act)
         self.up_tr32 = UpTransition(spatial_dims, 64, 32, 1, act)
         self.out_tr = OutputTransition(spatial_dims, 32, out_channels, act, bias=bias)
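
As a quick sanity check of the new signature, a minimal smoke test (assumptions: `VNet` is re-exported from `monai.networks.nets`, and input spatial sizes are divisible by 16 since the encoder downsamples by 2 four times):

```python
import torch

from monai.networks.nets import VNet

net = VNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    dropout_prob_down=0.5,
    dropout_prob_up=(0.5, 0.5),
)
net.eval()  # disable dropout for a deterministic forward pass

x = torch.randn(1, 1, 32, 32, 32)  # (batch, channel, D, H, W)
with torch.no_grad():
    y = net(x)
print(y.shape)  # expected: torch.Size([1, 1, 32, 32, 32])
```

Passing the old `dropout_prob` keyword still works but now routes through the stacked `@deprecated_arg` decorators, which emit a deprecation warning pointing at `dropout_prob_down` and `dropout_prob_up`.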
