diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py
index dc1d5d388e..b888ea3942 100644
--- a/monai/networks/blocks/crossattention.py
+++ b/monai/networks/blocks/crossattention.py
@@ -109,6 +109,8 @@ def __init__(
                 torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
             )
             self.causal_mask: torch.Tensor
+        else:
+            self.causal_mask = torch.Tensor()

         self.att_mat = torch.Tensor()
         self.rel_positional_embedding = (
@@ -118,7 +120,7 @@ def __init__(
         )
         self.input_size = input_size

-    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None):
+    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
         """
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 9905e7d036..3ab1e1fd10 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -105,6 +105,8 @@ def __init__(
                 torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
             )
             self.causal_mask: torch.Tensor
+        else:
+            self.causal_mask = torch.Tensor()

         self.rel_positional_embedding = (
             get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.dim_head, self.num_heads)
diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 2458902cba..0aa1697479 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -11,6 +11,8 @@

 from __future__ import annotations

+from typing import Optional
+
 import torch
 import torch.nn as nn

@@ -68,13 +70,12 @@ def __init__(
         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention

-        if self.with_cross_attention:
-            self.norm_cross_attn = nn.LayerNorm(hidden_size)
-            self.cross_attn = CrossAttentionBlock(
-                hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
-            )
+        self.norm_cross_attn = nn.LayerNorm(hidden_size)
+        self.cross_attn = CrossAttentionBlock(
+            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
+        )

-    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
         x = x + self.attn(self.norm1(x))
         if self.with_cross_attention:
             x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
diff --git a/tests/test_vit.py b/tests/test_vit.py
index d27c10f95e..d638c0116a 100644
--- a/tests/test_vit.py
+++ b/tests/test_vit.py
@@ -30,7 +30,7 @@
                         for mlp_dim in [3072]:
                             for num_layers in [4]:
                                 for num_classes in [8]:
-                                    for pos_embed in ["conv", "perceptron"]:
+                                    for proj_type in ["conv", "perceptron"]:
                                         for classification in [False, True]:
                                             for nd in (2, 3):
                                                 test_case = [
@@ -42,7 +42,7 @@
                                                         "mlp_dim": mlp_dim,
                                                         "num_layers": num_layers,
                                                         "num_heads": num_heads,
-                                                        "pos_embed": pos_embed,
+                                                        "proj_type": proj_type,
                                                         "classification": classification,
                                                         "num_classes": num_classes,
                                                         "dropout_rate": dropout_rate,
@@ -87,7 +87,7 @@ def test_ill_arg(
         mlp_dim,
         num_layers,
         num_heads,
-        pos_embed,
+        proj_type,
         classification,
         dropout_rate,
     ):
@@ -100,12 +100,12 @@ def test_ill_arg(
                 mlp_dim=mlp_dim,
                 num_layers=num_layers,
                 num_heads=num_heads,
-                pos_embed=pos_embed,
+                proj_type=proj_type,
                 classification=classification,
                 dropout_rate=dropout_rate,
             )

-    @parameterized.expand(TEST_CASE_Vit)
+    @parameterized.expand(TEST_CASE_Vit[:1])
     @SkipIfBeforePyTorchVersion((1, 9))
     def test_script(self, input_param, input_shape, _):
         net = ViT(**(input_param))
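For context, a minimal usage sketch follows; it is an illustration of what the patch enables, not part of the patch, and the parameter values below are assumptions. TorchScript requires every attribute touched in forward to exist on the module at construction time, and the Optional[torch.Tensor] spelling keeps the annotations scriptable across the PyTorch versions the test targets, which is why causal_mask now gets an empty-tensor placeholder in the non-causal branch. The proj_type rename in the tests simply follows the current ViT argument name.

# Minimal sketch, assuming a MONAI build that includes this patch.
import torch

from monai.networks.blocks import TransformerBlock
from monai.networks.nets import ViT

# Before the patch, scripting could fail: forward annotated `context` with the
# `torch.Tensor | None` union and `causal_mask` only existed when causal=True.
block = TransformerBlock(hidden_size=256, mlp_dim=1024, num_heads=4, dropout_rate=0.0)
scripted_block = torch.jit.script(block)

x = torch.randn(2, 16, 256)  # B x sequence length x hidden_size
print(scripted_block(x).shape)  # torch.Size([2, 16, 256])

# The test rename mirrors the ViT API, where proj_type replaces the old pos_embed argument.
vit = ViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16), proj_type="conv")
scripted_vit = torch.jit.script(vit)  # what test_script exercises, now on TEST_CASE_Vit[:1] only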