Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ViT and Unetr to be torchscript comaptible #7937

Merged
merged 8 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion monai/networks/blocks/crossattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def __init__(
torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
)
self.causal_mask: torch.Tensor
else:
self.causal_mask = torch.Tensor()

self.att_mat = torch.Tensor()
self.rel_positional_embedding = (
Expand All @@ -118,7 +120,7 @@ def __init__(
)
self.input_size = input_size

def forward(self, x: torch.Tensor, context: torch.Tensor | None = None):
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
"""
Args:
x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
Expand Down
2 changes: 2 additions & 0 deletions monai/networks/blocks/selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def __init__(
torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
)
self.causal_mask: torch.Tensor
else:
self.causal_mask = torch.Tensor()

self.rel_positional_embedding = (
get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.dim_head, self.num_heads)
Expand Down
13 changes: 7 additions & 6 deletions monai/networks/blocks/transformerblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn

Expand Down Expand Up @@ -68,13 +70,12 @@ def __init__(
self.norm2 = nn.LayerNorm(hidden_size)
self.with_cross_attention = with_cross_attention

if self.with_cross_attention:
self.norm_cross_attn = nn.LayerNorm(hidden_size)
self.cross_attn = CrossAttentionBlock(
hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
)
self.norm_cross_attn = nn.LayerNorm(hidden_size)
self.cross_attn = CrossAttentionBlock(
hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
)

def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
x = x + self.attn(self.norm1(x))
if self.with_cross_attention:
x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
for mlp_dim in [3072]:
for num_layers in [4]:
for num_classes in [8]:
for pos_embed in ["conv", "perceptron"]:
for proj_type in ["conv", "perceptron"]:
for classification in [False, True]:
for nd in (2, 3):
test_case = [
Expand All @@ -42,7 +42,7 @@
"mlp_dim": mlp_dim,
"num_layers": num_layers,
"num_heads": num_heads,
"pos_embed": pos_embed,
"proj_type": proj_type,
"classification": classification,
"num_classes": num_classes,
"dropout_rate": dropout_rate,
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_ill_arg(
mlp_dim,
num_layers,
num_heads,
pos_embed,
proj_type,
classification,
dropout_rate,
):
Expand All @@ -100,12 +100,12 @@ def test_ill_arg(
mlp_dim=mlp_dim,
num_layers=num_layers,
num_heads=num_heads,
pos_embed=pos_embed,
proj_type=proj_type,
classification=classification,
dropout_rate=dropout_rate,
)

@parameterized.expand(TEST_CASE_Vit)
@parameterized.expand(TEST_CASE_Vit[:1])
@SkipIfBeforePyTorchVersion((1, 9))
def test_script(self, input_param, input_shape, _):
net = ViT(**(input_param))
Expand Down
Loading