From 4cd46deb49370c9b7f11dd3448e2b61aaf1d19ef Mon Sep 17 00:00:00 2001
From: yang-ze-kang <603822317@qq.com>
Date: Thu, 25 Sep 2025 12:07:37 +0800
Subject: [PATCH 1/2] fix transformerblock

Signed-off-by: yang-ze-kang <603822317@qq.com>
---
 monai/networks/blocks/transformerblock.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 6f0da73e7b..9562125272 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -80,15 +80,16 @@ def __init__(
         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention
 
-        self.norm_cross_attn = nn.LayerNorm(hidden_size)
-        self.cross_attn = CrossAttentionBlock(
-            hidden_size=hidden_size,
-            num_heads=num_heads,
-            dropout_rate=dropout_rate,
-            qkv_bias=qkv_bias,
-            causal=False,
-            use_flash_attention=use_flash_attention,
-        )
+        if with_cross_attention:
+            self.norm_cross_attn = nn.LayerNorm(hidden_size)
+            self.cross_attn = CrossAttentionBlock(
+                hidden_size=hidden_size,
+                num_heads=num_heads,
+                dropout_rate=dropout_rate,
+                qkv_bias=qkv_bias,
+                causal=False,
+                use_flash_attention=use_flash_attention,
+            )
 
     def forward(
         self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None

From 2b1a37c05b0504b603e07193108ab325d7e8b52b Mon Sep 17 00:00:00 2001
From: Yang Zekang
Date: Thu, 25 Sep 2025 12:39:55 +0800
Subject: [PATCH 2/2] Update monai/networks/blocks/transformerblock.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Yang Zekang
---
 monai/networks/blocks/transformerblock.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 9562125272..2cdb693ae2 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -90,6 +90,12 @@ def __init__(
                 causal=False,
                 use_flash_attention=use_flash_attention,
             )
+        else:
+            def _drop_cross_attn_keys(state_dict, prefix, *_args):
+                for key in list(state_dict.keys()):
+                    if key.startswith(prefix + "cross_attn.") or key.startswith(prefix + "norm_cross_attn."):
+                        state_dict.pop(key)
+            self._register_load_state_dict_pre_hook(_drop_cross_attn_keys)
 
     def forward(
         self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
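
Note (not part of the patch): the second commit registers a load_state_dict pre-hook so that checkpoints saved when cross-attention was always constructed still load into the new conditional layout. A minimal standalone sketch of that pattern is below; the class and layer names here are illustrative stand-ins, not MONAI's actual TransformerBlock or CrossAttentionBlock.

    import torch
    import torch.nn as nn

    class OptionalCrossAttnBlock(nn.Module):
        def __init__(self, hidden_size: int = 8, with_cross_attention: bool = False) -> None:
            super().__init__()
            self.norm1 = nn.LayerNorm(hidden_size)
            self.with_cross_attention = with_cross_attention
            if with_cross_attention:
                self.norm_cross_attn = nn.LayerNorm(hidden_size)
                self.cross_attn = nn.Linear(hidden_size, hidden_size)  # stand-in for a cross-attention block
            else:
                # Drop stale cross-attention keys from older checkpoints before strict key matching.
                def _drop_cross_attn_keys(state_dict, prefix, *_args):
                    for key in list(state_dict.keys()):
                        if key.startswith(prefix + "cross_attn.") or key.startswith(prefix + "norm_cross_attn."):
                            state_dict.pop(key)

                self._register_load_state_dict_pre_hook(_drop_cross_attn_keys)

    # A checkpoint from the old layout (cross-attention always present) ...
    old = OptionalCrossAttnBlock(with_cross_attention=True)
    ckpt = old.state_dict()

    # ... still loads strictly into the new conditional layout: the pre-hook removes
    # "cross_attn.*" / "norm_cross_attn.*", so there are no unexpected-key errors.
    new = OptionalCrossAttnBlock(with_cross_attention=False)
    new.load_state_dict(ckpt)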