Skip to content

Commit

Permalink
Add include_fc and use_combined_linear argument in the SABlock (P…
Browse files Browse the repository at this point in the history
…roject-MONAI#7996)

Fixes Project-MONAI#7991
Fixes Project-MONAI#7992

### Description
Add `include_fc` and `use_combined_linear` argument in the `SABlock`.


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
  • Loading branch information
3 people authored and rcremese committed Sep 2, 2024
1 parent 7bfb46c commit 0d7744f
Show file tree
Hide file tree
Showing 13 changed files with 426 additions and 137 deletions.
33 changes: 11 additions & 22 deletions monai/networks/blocks/crossattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,12 @@ def __init__(
causal (bool, optional): whether to use causal attention.
sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.
rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only
"decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
"decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
parameter size.
parameter size.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use Pytorch's inbuilt
flash attention for a memory efficient attention mechanism (see
https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""

super().__init__()
Expand Down Expand Up @@ -109,7 +108,7 @@ def __init__(
self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)

self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.out_rearrange = Rearrange("b l h d -> b h (l d)")
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
self.dropout_rate = dropout_rate
Expand Down Expand Up @@ -152,31 +151,20 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size)

q = self.to_q(x)
q = self.input_rearrange(self.to_q(x))
kv = context if context is not None else x
_, kv_t, _ = kv.size()
k = self.to_k(kv)
v = self.to_v(kv)
k = self.input_rearrange(self.to_k(kv))
v = self.input_rearrange(self.to_v(kv))

if self.attention_dtype is not None:
q = q.to(self.attention_dtype)
k = k.to(self.attention_dtype)

q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) #
k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)

if self.use_flash_attention:
x = torch.nn.functional.scaled_dot_product_attention(
query=q.transpose(1, 2),
key=k.transpose(1, 2),
value=v.transpose(1, 2),
scale=self.scale,
dropout_p=self.dropout_rate,
is_causal=self.causal,
).transpose(
1, 2
) # Back to (b, nh, t, hs)
query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
# apply relative positional embedding if defined
Expand All @@ -195,6 +183,7 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)

x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
Expand Down
60 changes: 40 additions & 20 deletions monai/networks/blocks/selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from __future__ import annotations

from typing import Optional, Tuple
from typing import Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -40,9 +40,11 @@ def __init__(
hidden_input_size: int | None = None,
causal: bool = False,
sequence_length: int | None = None,
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
attention_dtype: Optional[torch.dtype] = None,
rel_pos_embedding: str | None = None,
input_size: Tuple | None = None,
attention_dtype: torch.dtype | None = None,
include_fc: bool = True,
use_combined_linear: bool = True,
use_flash_attention: bool = False,
) -> None:
"""
Expand All @@ -61,9 +63,10 @@ def __init__(
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use Pytorch's inbuilt
flash attention for a memory efficient attention mechanism (see
https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
include_fc: whether to include the final linear layer. Default to True.
use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""

Expand Down Expand Up @@ -105,9 +108,22 @@ def __init__(
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)

self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.qkv: Union[nn.Linear, nn.Identity]
self.to_q: Union[nn.Linear, nn.Identity]
self.to_k: Union[nn.Linear, nn.Identity]
self.to_v: Union[nn.Linear, nn.Identity]

if use_combined_linear:
self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
self.to_q = self.to_k = self.to_v = nn.Identity() # add to enable torchscript
self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
else:
self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
self.qkv = nn.Identity() # add to enable torchscript
self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)
self.out_rearrange = Rearrange("b l h d -> b h (l d)")
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
self.dropout_rate = dropout_rate
Expand All @@ -117,6 +133,8 @@ def __init__(
self.attention_dtype = attention_dtype
self.causal = causal
self.sequence_length = sequence_length
self.include_fc = include_fc
self.use_combined_linear = use_combined_linear
self.use_flash_attention = use_flash_attention

if causal and sequence_length is not None:
Expand Down Expand Up @@ -144,22 +162,22 @@ def forward(self, x):
Return:
torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
"""
output = self.input_rearrange(self.qkv(x))
q, k, v = output[0], output[1], output[2]
if self.use_combined_linear:
output = self.input_rearrange(self.qkv(x))
q, k, v = output[0], output[1], output[2]
else:
q = self.input_rearrange(self.to_q(x))
k = self.input_rearrange(self.to_k(x))
v = self.input_rearrange(self.to_v(x))

if self.attention_dtype is not None:
q = q.to(self.attention_dtype)
k = k.to(self.attention_dtype)

if self.use_flash_attention:
x = F.scaled_dot_product_attention(
query=q.transpose(1, 2),
key=k.transpose(1, 2),
value=v.transpose(1, 2),
scale=self.scale,
dropout_p=self.dropout_rate,
is_causal=self.causal,
).transpose(1, 2)
query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

Expand All @@ -179,7 +197,9 @@ def forward(self, x):

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)

x = self.out_rearrange(x)
x = self.out_proj(x)
if self.include_fc:
x = self.out_proj(x)
x = self.drop_output(x)
return x
11 changes: 10 additions & 1 deletion monai/networks/blocks/spatialattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,13 @@ class SpatialAttentionBlock(nn.Module):
spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
num_channels: number of input channels. Must be divisible by num_head_channels.
num_head_channels: number of channels per head.
norm_num_groups: Number of groups for the group norm layer.
norm_eps: Epsilon for the normalization.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
include_fc: whether to include the final linear layer. Default to True.
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""

Expand All @@ -45,6 +50,8 @@ def __init__(
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
attention_dtype: Optional[torch.dtype] = None,
include_fc: bool = True,
use_combined_linear: bool = False,
use_flash_attention: bool = False,
) -> None:
super().__init__()
Expand All @@ -60,6 +67,8 @@ def __init__(
num_heads=num_heads,
qkv_bias=True,
attention_dtype=attention_dtype,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)

Expand Down
8 changes: 7 additions & 1 deletion monai/networks/blocks/transformerblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def __init__(
sequence_length: int | None = None,
with_cross_attention: bool = False,
use_flash_attention: bool = False,
include_fc: bool = True,
use_combined_linear: bool = True,
) -> None:
"""
Args:
Expand All @@ -47,7 +49,9 @@ def __init__(
qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
include_fc: whether to include the final linear layer. Default to True.
use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
"""

Expand All @@ -69,6 +73,8 @@ def __init__(
save_attn=save_attn,
causal=causal,
sequence_length=sequence_length,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)
self.norm2 = nn.LayerNorm(hidden_size)
Expand Down
Loading

0 comments on commit 0d7744f

Please sign in to comment.