Commit 00108dc

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
KumoLiu committed Jul 22, 2024
1 parent d020fac commit 00108dc
Showing 4 changed files with 12 additions and 10 deletions.
1 change: 1 addition & 0 deletions monai/networks/blocks/crossattention.py
@@ -102,6 +102,7 @@ def __init__(
         self.causal = causal
         self.sequence_length = sequence_length

+        self.causal_mask = torch.Tensor()
         if causal and sequence_length is not None:
             # causal mask to ensure that attention is only applied to the left in the input sequence
             self.register_buffer(
1 change: 1 addition & 0 deletions monai/networks/blocks/selfattention.py
@@ -98,6 +98,7 @@ def __init__(
         self.causal = causal
         self.sequence_length = sequence_length

+        self.causal_mask = torch.Tensor()
         if causal and sequence_length is not None:
             # causal mask to ensure that attention is only applied to the left in the input sequence
             self.register_buffer(
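Both hunks above make the same fix: self.causal_mask is pre-declared as an empty tensor so the attribute exists on every instance of the cross-attention and self-attention blocks, not only when causal=True and sequence_length is given, presumably so that attribute access (and scripting) also works for non-causal configurations. Below is a minimal sketch of why an always-present attribute matters; it uses a hypothetical ToyCausalBlock rather than the MONAI classes and, as a slight variation on the diff, registers the (possibly empty) mask as a buffer in both branches.

# Minimal sketch (hypothetical ToyCausalBlock, not the MONAI classes).
# torch.jit.script compiles every branch of forward(), so self.causal_mask
# must be defined even for non-causal configurations.
from __future__ import annotations

import torch
import torch.nn as nn


class ToyCausalBlock(nn.Module):
    def __init__(self, sequence_length: int | None = None, causal: bool = False) -> None:
        super().__init__()
        self.causal = causal
        if causal and sequence_length is not None:
            # lower-triangular mask: position i may only attend to positions <= i
            mask = torch.tril(torch.ones(sequence_length, sequence_length)).view(
                1, 1, sequence_length, sequence_length
            )
        else:
            mask = torch.Tensor()  # placeholder so the attribute is always present
        self.register_buffer("causal_mask", mask)

    def forward(self, att_mat: torch.Tensor) -> torch.Tensor:
        # scripting type-checks this branch even when self.causal is False,
        # so the attribute lookup must succeed for every configuration
        if self.causal:
            t = att_mat.shape[-1]
            att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :t] == 0, float("-inf"))
        return att_mat


# both configurations script cleanly because causal_mask always exists
torch.jit.script(ToyCausalBlock())
torch.jit.script(ToyCausalBlock(sequence_length=16, causal=True))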
9 changes: 4 additions & 5 deletions monai/networks/blocks/transformerblock.py
@@ -68,11 +68,10 @@ def __init__(
         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention

-        if self.with_cross_attention:
-            self.norm_cross_attn = nn.LayerNorm(hidden_size)
-            self.cross_attn = CrossAttentionBlock(
-                hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
-            )
+        self.norm_cross_attn = nn.LayerNorm(hidden_size)
+        self.cross_attn = CrossAttentionBlock(
+            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
+        )

     def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
         x = x + self.attn(self.norm1(x))
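The TransformerBlock change applies the same idea to submodules: norm_cross_attn and cross_attn are now created unconditionally, so every instance exposes the same attributes and state_dict keys whether or not with_cross_attention is set, and the flag only gates their use at runtime (at the cost of allocating an unused cross-attention block when it is disabled). A rough sketch of the pattern, using a hypothetical ToyTransformerBlock built from stock nn.MultiheadAttention rather than the MONAI implementation:

# Rough sketch (hypothetical ToyTransformerBlock, not the MONAI class):
# submodules are created unconditionally and only their use is gated at runtime.
from __future__ import annotations

import torch
import torch.nn as nn


class ToyTransformerBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, with_cross_attention: bool = False) -> None:
        super().__init__()
        self.with_cross_attention = with_cross_attention
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        # always constructed, even when cross attention is never used
        self.norm_cross_attn = nn.LayerNorm(hidden_size)
        self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        if self.with_cross_attention and context is not None:
            c = self.norm_cross_attn(x)
            x = x + self.cross_attn(c, context, context, need_weights=False)[0]
        return x


# same attribute set (and state_dict keys) whether or not cross attention is enabled
blk = ToyTransformerBlock(hidden_size=64, num_heads=4, with_cross_attention=True)
out = blk(torch.randn(2, 10, 64), context=torch.randn(2, 7, 64))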
11 changes: 6 additions & 5 deletions tests/test_vit.py
@@ -30,7 +30,7 @@
                         for mlp_dim in [3072]:
                             for num_layers in [4]:
                                 for num_classes in [8]:
-                                    for pos_embed in ["conv", "perceptron"]:
+                                    for proj_type in ["conv", "perceptron"]:
                                         for classification in [False, True]:
                                             for nd in (2, 3):
                                                 test_case = [
@@ -42,7 +42,7 @@
"mlp_dim": mlp_dim,
"num_layers": num_layers,
"num_heads": num_heads,
"pos_embed": pos_embed,
"proj_type": proj_type,
"classification": classification,
"num_classes": num_classes,
"dropout_rate": dropout_rate,
@@ -87,7 +87,7 @@ def test_ill_arg(
         mlp_dim,
         num_layers,
         num_heads,
-        pos_embed,
+        proj_type,
         classification,
         dropout_rate,
     ):
@@ -100,14 +100,15 @@ def test_ill_arg(
                 mlp_dim=mlp_dim,
                 num_layers=num_layers,
                 num_heads=num_heads,
-                pos_embed=pos_embed,
+                proj_type=proj_type,
                 classification=classification,
                 dropout_rate=dropout_rate,
             )

-    @parameterized.expand(TEST_CASE_Vit)
+    @parameterized.expand(TEST_CASE_Vit[:1])
     @SkipIfBeforePyTorchVersion((1, 9))
     def test_script(self, input_param, input_shape, _):
+        print(input_param)
         net = ViT(**(input_param))
         net.eval()
         with torch.no_grad():
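For context, a hedged sketch of the path the updated test_script exercises: build a small ViT with the renamed proj_type argument (formerly pos_embed) and run it through torch.jit.script. The argument values below are illustrative assumptions, not the exact parameters from TEST_CASE_Vit.

# Hedged usage sketch; argument values are illustrative, not the test's own.
import torch
from monai.networks.nets import ViT

net = ViT(
    in_channels=1,
    img_size=(32, 32, 32),
    patch_size=(8, 8, 8),
    hidden_size=96,
    mlp_dim=192,
    num_layers=2,
    num_heads=8,
    proj_type="conv",  # renamed from the deprecated pos_embed argument
    classification=False,
    dropout_rate=0.0,
)
net.eval()
with torch.no_grad():
    scripted = torch.jit.script(net)
    out, hidden_states = scripted(torch.randn(1, 1, 32, 32, 32))
print(out.shape)  # expected: (batch, n_patches, hidden_size) == (1, 64, 96)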
