Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add include_fc and use_combined_linear argument in the SABlock #7996

Merged
merged 36 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
37cd5cd
fix #7991
KumoLiu Aug 6, 2024
63ba16d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
7255a90
add docstring
KumoLiu Aug 6, 2024
ddbd336
Merge branch 'proj-atten' of https://github.com/KumoLiu/MONAI into pr…
KumoLiu Aug 6, 2024
0337d45
fix #7992
KumoLiu Aug 6, 2024
f198e2c
Merge branch 'linear' into proj-atten
KumoLiu Aug 6, 2024
814e61a
add tests
KumoLiu Aug 6, 2024
7dd22e0
Merge remote-tracking branch 'origin/dev' into proj-atten
KumoLiu Aug 6, 2024
de9eef0
remove transpose in sablock
KumoLiu Aug 7, 2024
2333351
fix docstring
KumoLiu Aug 7, 2024
f9eb6d8
use rearange
KumoLiu Aug 7, 2024
5aeccbe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2024
3154c7c
minor fix
KumoLiu Aug 7, 2024
81d3605
add in SpatialAttentionBlock
KumoLiu Aug 7, 2024
f47c2c6
Merge remote-tracking branch 'origin/dev' into proj-atten
KumoLiu Aug 7, 2024
754e7f2
fix format
KumoLiu Aug 7, 2024
3cf2124
add tests
KumoLiu Aug 7, 2024
8de91eb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2024
0b556a5
minor fix
KumoLiu Aug 7, 2024
9a59a15
minor fix
KumoLiu Aug 7, 2024
05e42ce
format fix
KumoLiu Aug 7, 2024
7dc0933
Merge branch 'proj-atten' of https://github.com/KumoLiu/MONAI into pr…
KumoLiu Aug 7, 2024
aae275d
minor fix
KumoLiu Aug 7, 2024
531a831
fix mypy
KumoLiu Aug 7, 2024
48319c0
fix ci
KumoLiu Aug 7, 2024
b854d7a
minor fix
KumoLiu Aug 7, 2024
32d0a5d
address comments
KumoLiu Aug 8, 2024
e5f2cb1
minor fix
KumoLiu Aug 8, 2024
818ba7e
Update tests/test_crossattention.py
KumoLiu Aug 9, 2024
4bef7f0
Update tests/test_selfattention.py
KumoLiu Aug 9, 2024
bfc8f29
minor fix
KumoLiu Aug 9, 2024
3d09b4a
Merge remote-tracking branch 'origin/dev' into proj-atten
KumoLiu Aug 9, 2024
0da115a
address comments
KumoLiu Aug 9, 2024
0d46a6b
Merge branch 'dev' into proj-atten
KumoLiu Aug 9, 2024
1c5599d
fix state dict
KumoLiu Aug 9, 2024
6ed765d
Merge branch 'dev' into proj-atten
KumoLiu Aug 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 11 additions & 22 deletions monai/networks/blocks/crossattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,12 @@ def __init__(
causal (bool, optional): whether to use causal attention.
sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.
rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only
"decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
"decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
parameter size.
parameter size.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use Pytorch's inbuilt
flash attention for a memory efficient attention mechanism (see
https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""

super().__init__()
Expand Down Expand Up @@ -109,7 +108,7 @@ def __init__(
self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)

self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.out_rearrange = Rearrange("b l h d -> b h (l d)")
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
self.dropout_rate = dropout_rate
Expand Down Expand Up @@ -152,31 +151,20 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size)

q = self.to_q(x)
q = self.input_rearrange(self.to_q(x))
kv = context if context is not None else x
_, kv_t, _ = kv.size()
k = self.to_k(kv)
v = self.to_v(kv)
k = self.input_rearrange(self.to_k(kv))
v = self.input_rearrange(self.to_v(kv))

if self.attention_dtype is not None:
q = q.to(self.attention_dtype)
k = k.to(self.attention_dtype)

q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) #
k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)

if self.use_flash_attention:
x = torch.nn.functional.scaled_dot_product_attention(
query=q.transpose(1, 2),
key=k.transpose(1, 2),
value=v.transpose(1, 2),
scale=self.scale,
dropout_p=self.dropout_rate,
is_causal=self.causal,
).transpose(
1, 2
) # Back to (b, nh, t, hs)
query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
# apply relative positional embedding if defined
Expand All @@ -195,6 +183,7 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)

x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
Expand Down
60 changes: 40 additions & 20 deletions monai/networks/blocks/selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from __future__ import annotations

from typing import Optional, Tuple
from typing import Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -40,9 +40,11 @@ def __init__(
hidden_input_size: int | None = None,
causal: bool = False,
sequence_length: int | None = None,
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
attention_dtype: Optional[torch.dtype] = None,
rel_pos_embedding: str | None = None,
input_size: Tuple | None = None,
attention_dtype: torch.dtype | None = None,
include_fc: bool = True,
use_combined_linear: bool = True,
use_flash_attention: bool = False,
) -> None:
"""
Expand All @@ -61,9 +63,10 @@ def __init__(
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use Pytorch's inbuilt
flash attention for a memory efficient attention mechanism (see
https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
include_fc: whether to include the final linear layer. Default to True.
use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""

Expand Down Expand Up @@ -105,9 +108,22 @@ def __init__(
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)

self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.qkv: Union[nn.Linear, nn.Identity]
self.to_q: Union[nn.Linear, nn.Identity]
self.to_k: Union[nn.Linear, nn.Identity]
self.to_v: Union[nn.Linear, nn.Identity]

if use_combined_linear:
self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
self.to_q = self.to_k = self.to_v = nn.Identity() # add to enable torchscript
self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
else:
self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
self.qkv = nn.Identity() # add to enable torchscript
self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)
self.out_rearrange = Rearrange("b l h d -> b h (l d)")
ericspod marked this conversation as resolved.
Show resolved Hide resolved
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
self.dropout_rate = dropout_rate
Expand All @@ -117,6 +133,8 @@ def __init__(
self.attention_dtype = attention_dtype
self.causal = causal
self.sequence_length = sequence_length
self.include_fc = include_fc
self.use_combined_linear = use_combined_linear
self.use_flash_attention = use_flash_attention

if causal and sequence_length is not None:
Expand Down Expand Up @@ -144,22 +162,22 @@ def forward(self, x):
Return:
torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
"""
output = self.input_rearrange(self.qkv(x))
q, k, v = output[0], output[1], output[2]
if self.use_combined_linear:
output = self.input_rearrange(self.qkv(x))
q, k, v = output[0], output[1], output[2]
else:
q = self.input_rearrange(self.to_q(x))
k = self.input_rearrange(self.to_k(x))
v = self.input_rearrange(self.to_v(x))

if self.attention_dtype is not None:
q = q.to(self.attention_dtype)
k = k.to(self.attention_dtype)

if self.use_flash_attention:
x = F.scaled_dot_product_attention(
query=q.transpose(1, 2),
key=k.transpose(1, 2),
value=v.transpose(1, 2),
scale=self.scale,
dropout_p=self.dropout_rate,
is_causal=self.causal,
).transpose(1, 2)
query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

Expand All @@ -179,7 +197,9 @@ def forward(self, x):

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)

x = self.out_rearrange(x)
x = self.out_proj(x)
if self.include_fc:
x = self.out_proj(x)
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
x = self.drop_output(x)
return x
11 changes: 10 additions & 1 deletion monai/networks/blocks/spatialattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,13 @@ class SpatialAttentionBlock(nn.Module):
spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
num_channels: number of input channels. Must be divisible by num_head_channels.
num_head_channels: number of channels per head.
norm_num_groups: Number of groups for the group norm layer.
norm_eps: Epsilon for the normalization.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
include_fc: whether to include the final linear layer. Default to True.
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).

"""

Expand All @@ -45,6 +50,8 @@ def __init__(
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
attention_dtype: Optional[torch.dtype] = None,
include_fc: bool = True,
use_combined_linear: bool = False,
use_flash_attention: bool = False,
) -> None:
super().__init__()
Expand All @@ -60,6 +67,8 @@ def __init__(
num_heads=num_heads,
qkv_bias=True,
attention_dtype=attention_dtype,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)

Expand Down
8 changes: 7 additions & 1 deletion monai/networks/blocks/transformerblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def __init__(
sequence_length: int | None = None,
with_cross_attention: bool = False,
use_flash_attention: bool = False,
include_fc: bool = True,
use_combined_linear: bool = True,
) -> None:
"""
Args:
Expand All @@ -47,7 +49,9 @@ def __init__(
qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
include_fc: whether to include the final linear layer. Default to True.
use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
"""

Expand All @@ -69,6 +73,8 @@ def __init__(
save_attn=save_attn,
causal=causal,
sequence_length=sequence_length,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)
self.norm2 = nn.LayerNorm(hidden_size)
Expand Down
8 changes: 7 additions & 1 deletion monai/networks/nets/diffusion_model_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ class DiffusionUNetTransformerBlock(nn.Module):
cross_attention_dim: size of the context vector for cross attention.
upcast_attention: if True, upcast attention operations to full precision.
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
include_fc: whether to include the final linear layer. Default to True.
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.

"""

Expand All @@ -80,6 +82,8 @@ def __init__(
cross_attention_dim: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
include_fc: bool = True,
use_combined_linear: bool = False,
) -> None:
super().__init__()
self.attn1 = SABlock(
Expand All @@ -89,6 +93,8 @@ def __init__(
dim_head=num_head_channels,
dropout_rate=dropout,
attention_dtype=torch.float if upcast_attention else None,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)
self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
Expand Down
22 changes: 21 additions & 1 deletion tests/test_crossattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from monai.networks.blocks.crossattention import CrossAttentionBlock
from monai.networks.layers.factories import RelPosEmbedding
from monai.utils import optional_import
from tests.utils import SkipIfBeforePyTorchVersion
from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose

einops, has_einops = optional_import("einops")

Expand Down Expand Up @@ -166,6 +166,26 @@ def test_access_attn_matrix(self):
matrix_acess_blk(torch.randn(input_shape))
assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1])

@skipUnless(has_einops, "Requires einops")
@SkipIfBeforePyTorchVersion((2, 0))
def test_flash_attention(self):
for causal in [True, False]:
input_param = {
"hidden_size": 128,
"num_heads": 1,
"causal": causal,
"sequence_length": 16 if causal else None,
}
device = "cuda:0" if torch.cuda.is_available() else "cpu"
block_w_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device)
block_wo_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device)
block_wo_flash_attention.load_state_dict(block_w_flash_attention.state_dict())
test_data = torch.randn(1, 16, 128).to(device)

out_1 = block_w_flash_attention(test_data)
out_2 = block_wo_flash_attention(test_data)
assert_allclose(out_1, out_2, atol=1e-4)
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == "__main__":
unittest.main()
Loading
Loading