From 00108dc6f1ace9f8064105dde318ef3225cad7e3 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Mon, 22 Jul 2024 16:52:42 +0800
Subject: [PATCH] fix #7936

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/networks/blocks/crossattention.py   |  1 +
 monai/networks/blocks/selfattention.py    |  1 +
 monai/networks/blocks/transformerblock.py |  9 ++++-----
 tests/test_vit.py                         | 11 ++++++-----
 4 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py
index dc1d5d388e..320cf8e692 100644
--- a/monai/networks/blocks/crossattention.py
+++ b/monai/networks/blocks/crossattention.py
@@ -102,6 +102,7 @@ def __init__(
         self.causal = causal
         self.sequence_length = sequence_length
 
+        self.causal_mask = torch.Tensor()
         if causal and sequence_length is not None:
             # causal mask to ensure that attention is only applied to the left in the input sequence
             self.register_buffer(
diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 9905e7d036..b15dc1bde5 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -98,6 +98,7 @@ def __init__(
         self.causal = causal
         self.sequence_length = sequence_length
 
+        self.causal_mask = torch.Tensor()
         if causal and sequence_length is not None:
             # causal mask to ensure that attention is only applied to the left in the input sequence
             self.register_buffer(
diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 2458902cba..ce08c5a76f 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -68,11 +68,10 @@ def __init__(
         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention
 
-        if self.with_cross_attention:
-            self.norm_cross_attn = nn.LayerNorm(hidden_size)
-            self.cross_attn = CrossAttentionBlock(
-                hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
-            )
+        self.norm_cross_attn = nn.LayerNorm(hidden_size)
+        self.cross_attn = CrossAttentionBlock(
+            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
+        )
 
     def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
         x = x + self.attn(self.norm1(x))
diff --git a/tests/test_vit.py b/tests/test_vit.py
index d27c10f95e..bf4e1715a6 100644
--- a/tests/test_vit.py
+++ b/tests/test_vit.py
@@ -30,7 +30,7 @@
                         for mlp_dim in [3072]:
                             for num_layers in [4]:
                                 for num_classes in [8]:
-                                    for pos_embed in ["conv", "perceptron"]:
+                                    for proj_type in ["conv", "perceptron"]:
                                         for classification in [False, True]:
                                             for nd in (2, 3):
                                                 test_case = [
@@ -42,7 +42,7 @@
                                                         "mlp_dim": mlp_dim,
                                                         "num_layers": num_layers,
                                                         "num_heads": num_heads,
-                                                        "pos_embed": pos_embed,
+                                                        "proj_type": proj_type,
                                                         "classification": classification,
                                                         "num_classes": num_classes,
                                                         "dropout_rate": dropout_rate,
@@ -87,7 +87,7 @@ def test_ill_arg(
         mlp_dim,
         num_layers,
         num_heads,
-        pos_embed,
+        proj_type,
         classification,
         dropout_rate,
     ):
@@ -100,14 +100,15 @@ def test_ill_arg(
             mlp_dim=mlp_dim,
             num_layers=num_layers,
             num_heads=num_heads,
-            pos_embed=pos_embed,
+            proj_type=proj_type,
             classification=classification,
             dropout_rate=dropout_rate,
         )
 
-    @parameterized.expand(TEST_CASE_Vit)
+    @parameterized.expand(TEST_CASE_Vit[:1])
     @SkipIfBeforePyTorchVersion((1, 9))
     def test_script(self, input_param, input_shape, _):
+        print(input_param)
         net = ViT(**(input_param))
         net.eval()
         with torch.no_grad():
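
Note: the hunks above all apply one pattern: an attribute (`causal_mask`) or submodule (`norm_cross_attn`/`cross_attn`) that `forward()` may reference is now defined on every instance, which is what `torch.jit.script` requires, since scripting resolves names in both branches of an `if` at compile time regardless of the flags a given instance was built with (the `test_script` changes suggest this is the scenario being fixed). The sketch below is a minimal, standalone illustration of that pattern under that assumption; the class names, the `nn.Linear` stand-in, and the toy `forward()` bodies are illustrative only, not MONAI code.

# Minimal sketch (assumed scenario: TorchScript compatibility). Names and
# forward() bodies are illustrative stand-ins, not the MONAI classes.
from typing import Optional

import torch
import torch.nn as nn


class ToyCausalAttention(nn.Module):
    """Same idea as the crossattention/selfattention hunks: `causal_mask` must
    exist on every instance because torch.jit.script compiles the causal
    branch of forward() even when `causal` is False."""

    def __init__(self, causal: bool = False, sequence_length: Optional[int] = None) -> None:
        super().__init__()
        self.causal = causal
        if causal and sequence_length is not None:
            # real lower-triangular mask, registered as a buffer
            self.register_buffer(
                "causal_mask",
                torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
            )
        else:
            # empty placeholder so the attribute is always resolvable
            self.causal_mask = torch.Tensor()

    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        # scores: (batch, seq_len, seq_len) attention logits
        if self.causal:
            t = scores.shape[-1]
            scores = scores.masked_fill(self.causal_mask[0, 0, :t, :t] == 0, float("-inf"))
        return scores


class ToyBlock(nn.Module):
    """Same idea as the transformerblock hunk: the optional submodule is always
    constructed, because scripting resolves `self.cross_proj` even when the
    cross-attention branch is never taken at runtime."""

    def __init__(self, hidden_size: int, with_cross_attention: bool = False) -> None:
        super().__init__()
        self.with_cross_attention = with_cross_attention
        self.cross_proj = nn.Linear(hidden_size, hidden_size)  # always created

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.with_cross_attention:
            if context is not None:
                x = x + self.cross_proj(context)
        return x


# All configurations script cleanly because every referenced name exists.
torch.jit.script(ToyCausalAttention(causal=False))
torch.jit.script(ToyCausalAttention(causal=True, sequence_length=16))
torch.jit.script(ToyBlock(hidden_size=8, with_cross_attention=False))

Registering the real mask via `register_buffer` keeps it in `state_dict` and moves it with the module across devices, while the empty placeholder only has to make the attribute resolvable; the same reasoning applies to always constructing the cross-attention submodule even when `with_cross_attention` is False.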