Skip to content

Commit

Permalink
7994-enhance-mlpblock (#7995)
Browse files Browse the repository at this point in the history
Fixes #7994  .

### Description
The current implementation does not support tuple input for "GEGLU", since
it only changes the output features of the first linear layer when the
input is the string "GEGLU".

This PR enhances it, and also enables a "vista3d" mode to support #7987.
Tests are added to cover the changes.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Yiheng Wang <vennw@nvidia.com>
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
  • Loading branch information
yiheng-wang-nv and KumoLiu authored Aug 7, 2024
1 parent 6c23fd0 commit 49a1e34
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 4 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pythonapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ jobs:
name: Install itk pre-release (Linux only)
run: |
python -m pip install --pre -U itk
find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
- name: Install the dependencies
run: |
python -m pip install --user --upgrade pip wheel
Expand Down
20 changes: 16 additions & 4 deletions monai/networks/blocks/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@

from __future__ import annotations

from typing import Union

import torch.nn as nn

from monai.networks.layers import get_act_layer
from monai.networks.layers.factories import split_args
from monai.utils import look_up_option

SUPPORTED_DROPOUT_MODE = {"vit", "swin"}
SUPPORTED_DROPOUT_MODE = {"vit", "swin", "vista3d"}


class MLPBlock(nn.Module):
Expand All @@ -39,7 +42,7 @@ def __init__(
https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
"swin" corresponds to one instance as implemented in
https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23
"vista3d" mode does not use dropout.
"""

Expand All @@ -48,15 +51,24 @@ def __init__(
if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")
mlp_dim = mlp_dim or hidden_size
self.linear1 = nn.Linear(hidden_size, mlp_dim) if act != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
act_name, _ = split_args(act)
self.linear1 = nn.Linear(hidden_size, mlp_dim) if act_name != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
self.linear2 = nn.Linear(mlp_dim, hidden_size)
self.fn = get_act_layer(act)
self.drop1 = nn.Dropout(dropout_rate)
# Use Union[nn.Dropout, nn.Identity] for type annotations
self.drop1: Union[nn.Dropout, nn.Identity]
self.drop2: Union[nn.Dropout, nn.Identity]

dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)
if dropout_opt == "vit":
self.drop1 = nn.Dropout(dropout_rate)
self.drop2 = nn.Dropout(dropout_rate)
elif dropout_opt == "swin":
self.drop1 = nn.Dropout(dropout_rate)
self.drop2 = self.drop1
elif dropout_opt == "vista3d":
self.drop1 = nn.Identity()
self.drop2 = nn.Identity()
else:
raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}")

Expand Down
28 changes: 28 additions & 0 deletions tests/test_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

import numpy as np
import torch
import torch.nn as nn
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.blocks.mlp import MLPBlock
from monai.networks.layers.factories import split_args

TEST_CASE_MLP = []
for dropout_rate in np.linspace(0, 1, 4):
Expand All @@ -31,6 +33,14 @@
]
TEST_CASE_MLP.append(test_case)

# test different activation layers
TEST_CASE_ACT = []
for act in ["GELU", "GEGLU", ("GEGLU", {})]: # type: ignore
TEST_CASE_ACT.append([{"hidden_size": 128, "mlp_dim": 0, "act": act}, (2, 512, 128), (2, 512, 128)])

# test different dropout modes
TEST_CASE_DROP = [["vit", nn.Dropout], ["swin", nn.Dropout], ["vista3d", nn.Identity]]


class TestMLPBlock(unittest.TestCase):

Expand All @@ -45,6 +55,24 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=5.0)

@parameterized.expand(TEST_CASE_ACT)
def test_act(self, input_param, input_shape, expected_shape):
    """Check the output shape and that GEGLU doubles linear1's output width."""
    block = MLPBlock(**input_param)
    with eval_mode(block):
        out = block(torch.randn(input_shape))
    self.assertEqual(out.shape, expected_shape)
    name, _ = split_args(input_param["act"])
    # GEGLU splits linear1's output into a gate half and a value half,
    # so out_features must be twice in_features; every other activation keeps them equal.
    width_factor = 2 if name == "GEGLU" else 1
    self.assertEqual(block.linear1.out_features, block.linear1.in_features * width_factor)

@parameterized.expand(TEST_CASE_DROP)
def test_dropout_mode(self, dropout_mode, dropout_layer):
    """Verify each dropout_mode produces the expected layer type for both drop1 and drop2."""
    net = MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=0.1, dropout_mode=dropout_mode)
    # assertIsInstance reports the actual type on failure,
    # unlike assertTrue(isinstance(...)) which only says "False is not true".
    self.assertIsInstance(net.drop1, dropout_layer)
    self.assertIsInstance(net.drop2, dropout_layer)


if __name__ == "__main__":
unittest.main()

0 comments on commit 49a1e34

Please sign in to comment.