Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

7994-enhance-mlpblock #7995

Merged
merged 8 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pythonapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ jobs:
name: Install itk pre-release (Linux only)
run: |
python -m pip install --pre -U itk
find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
- name: Install the dependencies
run: |
python -m pip install --user --upgrade pip wheel
Expand Down
20 changes: 16 additions & 4 deletions monai/networks/blocks/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@

from __future__ import annotations

from typing import Union

import torch.nn as nn

from monai.networks.layers import get_act_layer
from monai.networks.layers.factories import split_args
from monai.utils import look_up_option

SUPPORTED_DROPOUT_MODE = {"vit", "swin"}
SUPPORTED_DROPOUT_MODE = {"vit", "swin", "vista3d"}


class MLPBlock(nn.Module):
Expand All @@ -39,7 +42,7 @@ def __init__(
https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
"swin" corresponds to one instance as implemented in
https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23

"vista3d" mode does not use dropout.

"""

Expand All @@ -48,15 +51,24 @@ def __init__(
if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")
mlp_dim = mlp_dim or hidden_size
self.linear1 = nn.Linear(hidden_size, mlp_dim) if act != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
act_name, _ = split_args(act)
self.linear1 = nn.Linear(hidden_size, mlp_dim) if act_name != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
self.linear2 = nn.Linear(mlp_dim, hidden_size)
self.fn = get_act_layer(act)
self.drop1 = nn.Dropout(dropout_rate)
# Use Union[nn.Dropout, nn.Identity] for type annotations
self.drop1: Union[nn.Dropout, nn.Identity]
self.drop2: Union[nn.Dropout, nn.Identity]

dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)
if dropout_opt == "vit":
self.drop1 = nn.Dropout(dropout_rate)
self.drop2 = nn.Dropout(dropout_rate)
elif dropout_opt == "swin":
self.drop1 = nn.Dropout(dropout_rate)
self.drop2 = self.drop1
elif dropout_opt == "vista3d":
self.drop1 = nn.Identity()
self.drop2 = nn.Identity()
else:
raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}")

Expand Down
28 changes: 28 additions & 0 deletions tests/test_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

import numpy as np
import torch
import torch.nn as nn
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.blocks.mlp import MLPBlock
from monai.networks.layers.factories import split_args

TEST_CASE_MLP = []
for dropout_rate in np.linspace(0, 1, 4):
Expand All @@ -31,6 +33,14 @@
]
TEST_CASE_MLP.append(test_case)

# test different activation layers
TEST_CASE_ACT = []
for act in ["GELU", "GEGLU", ("GEGLU", {})]: # type: ignore
TEST_CASE_ACT.append([{"hidden_size": 128, "mlp_dim": 0, "act": act}, (2, 512, 128), (2, 512, 128)])

# test different dropout modes
TEST_CASE_DROP = [["vit", nn.Dropout], ["swin", nn.Dropout], ["vista3d", nn.Identity]]


class TestMLPBlock(unittest.TestCase):

Expand All @@ -45,6 +55,24 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=5.0)

@parameterized.expand(TEST_CASE_ACT)
def test_act(self, input_param, input_shape, expected_shape):
net = MLPBlock(**input_param)
with eval_mode(net):
result = net(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)
act_name, _ = split_args(input_param["act"])
if act_name == "GEGLU":
self.assertEqual(net.linear1.in_features, net.linear1.out_features // 2)
else:
self.assertEqual(net.linear1.in_features, net.linear1.out_features)

@parameterized.expand(TEST_CASE_DROP)
def test_dropout_mode(self, dropout_mode, dropout_layer):
net = MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=0.1, dropout_mode=dropout_mode)
self.assertTrue(isinstance(net.drop1, dropout_layer))
self.assertTrue(isinstance(net.drop2, dropout_layer))


if __name__ == "__main__":
unittest.main()
Loading