Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

7994-enhance-mlpblock #7995

Merged
merged 8 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions monai/networks/blocks/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
import torch.nn as nn

from monai.networks.layers import get_act_layer
from monai.networks.layers.factories import split_args
from monai.utils import look_up_option

SUPPORTED_DROPOUT_MODE = {"vit", "swin"}
SUPPORTED_DROPOUT_MODE = {"vit", "swin", "vista3d"}


class MLPBlock(nn.Module):
Expand All @@ -39,7 +40,7 @@ def __init__(
https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
"swin" corresponds to one instance as implemented in
https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23

"vista3d" mode does not use dropout.

"""

Expand All @@ -48,15 +49,20 @@ def __init__(
if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")
mlp_dim = mlp_dim or hidden_size
self.linear1 = nn.Linear(hidden_size, mlp_dim) if act != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
act_name, _ = split_args(act)
self.linear1 = nn.Linear(hidden_size, mlp_dim) if act_name != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
self.linear2 = nn.Linear(mlp_dim, hidden_size)
self.fn = get_act_layer(act)
self.drop1 = nn.Dropout(dropout_rate)
dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)
if dropout_opt == "vit":
self.drop1 = nn.Dropout(dropout_rate)
self.drop2 = nn.Dropout(dropout_rate)
elif dropout_opt == "swin":
self.drop1 = nn.Dropout(dropout_rate)
self.drop2 = self.drop1
elif dropout_opt == "vista3d":
self.drop1 = nn.Identity()
self.drop2 = nn.Identity()
else:
raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}")

Expand Down
28 changes: 28 additions & 0 deletions tests/test_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

import numpy as np
import torch
import torch.nn as nn
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.blocks.mlp import MLPBlock
from monai.networks.layers.factories import split_args

TEST_CASE_MLP = []
for dropout_rate in np.linspace(0, 1, 4):
Expand All @@ -31,6 +33,14 @@
]
TEST_CASE_MLP.append(test_case)

# test different activation layers
TEST_CASE_ACT = []
for act in ["GELU", "GEGLU", ("GELU", {"approximate": "tanh"}), ("GEGLU", {})]:
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
TEST_CASE_ACT.append([{"hidden_size": 128, "mlp_dim": 0, "act": act}, (2, 512, 128), (2, 512, 128)])

# test different dropout modes
TEST_CASE_DROP = [["vit", nn.Dropout], ["swin", nn.Dropout], ["vista3d", nn.Identity]]


class TestMLPBlock(unittest.TestCase):

Expand All @@ -45,6 +55,24 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=5.0)

@parameterized.expand(TEST_CASE_ACT)
def test_act(self, input_param, input_shape, expected_shape):
net = MLPBlock(**input_param)
with eval_mode(net):
result = net(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)
act_name, _ = split_args(input_param["act"])
if act_name == "GEGLU":
self.assertEqual(net.linear1.in_features, net.linear1.out_features // 2)
else:
self.assertEqual(net.linear1.in_features, net.linear1.out_features)

@parameterized.expand(TEST_CASE_DROP)
def test_dropout_mode(self, dropout_mode, dropout_layer):
net = MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=0.1, dropout_mode=dropout_mode)
self.assertTrue(isinstance(net.drop1, dropout_layer))
self.assertTrue(isinstance(net.drop2, dropout_layer))


if __name__ == "__main__":
unittest.main()
Loading