diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 6f0da73e7b..2cdb693ae2 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -80,15 +80,22 @@ def __init__(
         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention
 
-        self.norm_cross_attn = nn.LayerNorm(hidden_size)
-        self.cross_attn = CrossAttentionBlock(
-            hidden_size=hidden_size,
-            num_heads=num_heads,
-            dropout_rate=dropout_rate,
-            qkv_bias=qkv_bias,
-            causal=False,
-            use_flash_attention=use_flash_attention,
-        )
+        if with_cross_attention:
+            self.norm_cross_attn = nn.LayerNorm(hidden_size)
+            self.cross_attn = CrossAttentionBlock(
+                hidden_size=hidden_size,
+                num_heads=num_heads,
+                dropout_rate=dropout_rate,
+                qkv_bias=qkv_bias,
+                causal=False,
+                use_flash_attention=use_flash_attention,
+            )
+        else:
+            def _drop_cross_attn_keys(state_dict, prefix, *_args):
+                for key in list(state_dict.keys()):
+                    if key.startswith(prefix + "cross_attn.") or key.startswith(prefix + "norm_cross_attn."):
+                        state_dict.pop(key)
+            self._register_load_state_dict_pre_hook(_drop_cross_attn_keys)
 
     def forward(
         self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
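
A minimal sketch of the backward-compatibility behaviour the pre-hook is meant to provide, not part of the patch. It assumes the public TransformerBlock constructor accepts hidden_size, mlp_dim, num_heads and with_cross_attention as in current MONAI, and it simulates an "old" checkpoint by saving a block that still owns the cross-attention submodules rather than loading one from disk.

# Sketch only: checkpoints saved before this change always contain
# cross_attn.* and norm_cross_attn.* keys; the pre-hook lets them load
# into a block built without cross-attention.
from monai.networks.blocks import TransformerBlock

# Simulate an old checkpoint: this block still creates the cross-attention
# submodules, so its state dict carries the extra keys.
donor = TransformerBlock(hidden_size=256, mlp_dim=1024, num_heads=8, with_cross_attention=True)
checkpoint = donor.state_dict()
assert any(k.startswith("cross_attn.") for k in checkpoint)

# A block built with with_cross_attention=False no longer has those
# parameters; the registered pre-hook drops the stale keys before
# load_state_dict validates, so strict loading succeeds instead of
# failing with "Unexpected key(s) in state_dict".
block = TransformerBlock(hidden_size=256, mlp_dim=1024, num_heads=8, with_cross_attention=False)
block.load_state_dict(checkpoint)

Because the hook filters on the prefix argument supplied by load_state_dict, the same keys are dropped when the block is nested inside a larger model (e.g. a prefix such as "blocks.0.").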