Skip to content

Commit

Permalink
3432 make vit support torchscript (#3782)
Browse files Browse the repository at this point in the history
* make vit support torchscript

Signed-off-by: Yiheng Wang <vennw@nvidia.com>

* add torch version restriction

Signed-off-by: Yiheng Wang <vennw@nvidia.com>

* change skip decorator order

Signed-off-by: Yiheng Wang <vennw@nvidia.com>

* remove extra cls

Signed-off-by: Yiheng Wang <vennw@nvidia.com>
  • Loading branch information
yiheng-wang-nv authored and wyli committed Feb 10, 2022
1 parent fc3eaa1 commit d55e922
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 7 deletions.
2 changes: 1 addition & 1 deletion monai/networks/blocks/patchembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
if self.pos_embed == "perceptron" and m % p != 0:
raise ValueError("patch_size should be divisible by img_size for perceptron.")
self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)])
self.patch_dim = in_channels * np.prod(patch_size)
self.patch_dim = int(in_channels * np.prod(patch_size))

self.patch_embeddings: nn.Module
if self.pos_embed == "conv":
Expand Down
9 changes: 6 additions & 3 deletions monai/networks/blocks/selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from monai.utils import optional_import

einops, _ = optional_import("einops")
Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")


class SABlock(nn.Module):
Expand Down Expand Up @@ -43,17 +43,20 @@ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0)
self.num_heads = num_heads
self.out_proj = nn.Linear(hidden_size, hidden_size)
self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim**-0.5

def forward(self, x):
    """
    Apply multi-head self-attention to ``x``.

    The fused ``self.qkv`` projection is split into query, key and value by
    ``self.input_rearrange`` (pattern ``"b h (qkv l d) -> qkv b l h d"``), and
    the attended values are merged back by ``self.out_rearrange`` (pattern
    ``"b h l d -> b l (h d)"``). Both are ``Rearrange`` layers created in
    ``__init__``; they replace the former functional ``einops.rearrange``
    calls, so the attention path no longer depends on the ``einops`` module.

    Args:
        x: input passed to ``self.qkv``.

    Returns:
        The output of ``self.out_proj`` followed by ``self.drop_output``.
    """
    output = self.input_rearrange(self.qkv(x))
    # Index explicitly rather than tuple-unpacking the stacked tensor.
    q, k, v = output[0], output[1], output[2]
    # Scaled dot-product scores, normalised over the key axis.
    att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
    att_mat = self.drop_weights(att_mat)
    x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
    x = self.out_rearrange(x)
    x = self.out_proj(x)
    x = self.drop_output(x)
    return x
6 changes: 4 additions & 2 deletions monai/networks/nets/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class ViT(nn.Module):
"""
Vision Transformer (ViT), based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
ViT supports TorchScript, but only with PyTorch versions later than 1.8.
"""

def __init__(
Expand Down Expand Up @@ -99,14 +101,14 @@ def __init__(

def forward(self, x):
    """
    Embed the input into patches and run it through the transformer blocks.

    Args:
        x: input passed to ``self.patch_embedding``.

    Returns:
        A tuple ``(x, hidden_states_out)``. ``x`` is the output of ``self.norm``,
        further passed through ``self.classification_head`` (applied to the
        first token, ``x[:, 0]``) when that attribute exists.
        ``hidden_states_out`` is the list of outputs of every block in
        ``self.blocks``, in order.
    """
    x = self.patch_embedding(x)
    # Test for the attribute itself instead of the ``self.classification`` flag.
    # NOTE(review): presumably ``cls_token`` and ``classification_head`` are only
    # created in classification mode - confirm against ``__init__``.
    if hasattr(self, "cls_token"):
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
    hidden_states_out = []
    for blk in self.blocks:
        x = blk(x)
        hidden_states_out.append(x)
    x = self.norm(x)
    if hasattr(self, "classification_head"):
        x = self.classification_head(x[:, 0])
    return x, hidden_states_out
14 changes: 13 additions & 1 deletion tests/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from monai.networks import eval_mode
from monai.networks.nets.vit import ViT
from tests.utils import SkipIfBeforePyTorchVersion, test_script_save

TEST_CASE_Vit = []
for dropout_rate in [0.6]:
Expand All @@ -27,7 +28,7 @@
for mlp_dim in [3072]:
for num_layers in [4]:
for num_classes in [8]:
for pos_embed in ["conv"]:
for pos_embed in ["conv", "perceptron"]:
for classification in [False, True]:
for nd in (2, 3):
test_case = [
Expand Down Expand Up @@ -133,6 +134,17 @@ def test_ill_arg(self):
dropout_rate=0.3,
)

@parameterized.expand(TEST_CASE_Vit)
@SkipIfBeforePyTorchVersion((1, 9))
def test_script(self, input_param, input_shape, _):
    """
    Check that a ViT built from ``input_param`` can be scripted.

    The network is put in eval mode and compiled with ``torch.jit.script``
    under ``torch.no_grad()``; it is then handed to ``test_script_save``
    together with a random tensor of shape ``input_shape``. The third
    parameter of each test case is unused here.
    """
    model = ViT(**input_param)
    model.eval()
    # Only the fact that scripting succeeds matters; the result is discarded.
    with torch.no_grad():
        torch.jit.script(model)

    sample = torch.randn(input_shape)
    test_script_save(model, sample)


if __name__ == "__main__":
unittest.main()

0 comments on commit d55e922

Please sign in to comment.