Flash attention #7977

Merged · 12 commits · Aug 6, 2024
74 changes: 45 additions & 29 deletions monai/networks/blocks/crossattention.py
@@ -17,7 +17,7 @@
import torch.nn as nn

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import
from monai.utils import optional_import, pytorch_after

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

@@ -44,24 +44,23 @@ def __init__(
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
attention_dtype: Optional[torch.dtype] = None,
use_flash_attention: bool = False,
) -> None:
"""
Args:
hidden_size (int): dimension of hidden layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.
context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size.
dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
causal: whether to use causal attention.
sequence_length: if causal is True, it is necessary to specify the sequence length.
rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
attention_dtype: cast attention operations to this dtype.
Args:
hidden_size (int): dimension of hidden layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.
context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size.
dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make the attention matrix accessible. Defaults to False.
causal: whether to use causal attention.
sequence_length: if causal is True, it is necessary to specify the sequence length.
rel_pos_embedding (str, optional): add relative positional embeddings to the attention map.
For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
input_size (tuple(spatial_dim), optional): input resolution for calculating the relative
positional parameter size.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use PyTorch's built-in flash attention for a memory-efficient attention
mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""

super().__init__()
@@ -81,6 +80,17 @@ def __init__(
if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

if use_flash_attention and not pytorch_after(major=2, minor=0):
raise ValueError(
"use_flash_attention is only supported for PyTorch versions >= 2.0. "
"Upgrade your PyTorch or set the flag to False."
)
if use_flash_attention and save_attn:
raise ValueError(
"save_attn has been set to True, but use_flash_attention is also set "
"to True. save_attn can only be used if use_flash_attention is False."
)

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.context_input_size = context_input_size if context_input_size else hidden_size
@@ -101,6 +111,7 @@ def __init__(

self.causal = causal
self.sequence_length = sequence_length
self.use_flash_attention = use_flash_attention

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
@@ -145,23 +156,28 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs)
k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
if self.use_flash_attention:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v).contiguous()
else:

att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
# apply relative positional embedding if defined
if self.rel_positional_embedding is not None:
att_mat = self.rel_positional_embedding(x, att_mat, q)

if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))

att_mat = att_mat.softmax(dim=-1)
att_mat = att_mat.softmax(dim=-1)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()
if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
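For reviewers, a minimal usage sketch of the updated cross-attention block (not part of this PR; the sizes are illustrative, and PyTorch >= 2.0 is assumed for the flash path):

```python
import torch

from monai.networks.blocks.crossattention import CrossAttentionBlock

# Illustrative sizes only; use_flash_attention routes forward() through
# torch.nn.functional.scaled_dot_product_attention instead of the einsum path.
block = CrossAttentionBlock(
    hidden_size=128,
    num_heads=4,
    dropout_rate=0.0,
    use_flash_attention=True,
)

x = torch.randn(2, 512, 128)       # (batch, tokens, hidden_size)
context = torch.randn(2, 77, 128)  # context sequence for cross attention
out = block(x, context=context)    # same shape as x: (2, 512, 128)
```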
47 changes: 34 additions & 13 deletions monai/networks/blocks/selfattention.py
@@ -15,9 +15,10 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import
from monai.utils import optional_import, pytorch_after

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

@@ -42,6 +43,7 @@ def __init__(
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
attention_dtype: Optional[torch.dtype] = None,
use_flash_attention: bool = False,
) -> None:
"""
Args:
@@ -59,6 +61,9 @@
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use PyTorch's built-in flash attention for a memory-efficient attention
mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).

"""

@@ -82,6 +87,17 @@ def __init__(
if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

if use_flash_attention and not pytorch_after(major=2, minor=0):
raise ValueError(
"use_flash_attention is only supported for PyTorch versions >= 2.0. "
"Upgrade your PyTorch or set the flag to False."
)
if use_flash_attention and save_attn:
raise ValueError(
"save_attn has been set to True, but use_flash_attention is also set "
"to True. save_attn can only be used if use_flash_attention is False."
)

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
@@ -97,6 +113,7 @@ def __init__(
self.attention_dtype = attention_dtype
self.causal = causal
self.sequence_length = sequence_length
self.use_flash_attention = use_flash_attention

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
@@ -130,23 +147,27 @@ def forward(self, x):
q = q.to(self.attention_dtype)
k = k.to(self.attention_dtype)

att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
if self.use_flash_attention:
x = F.scaled_dot_product_attention(q, k, v)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
# apply relative positional embedding if defined
if self.rel_positional_embedding is not None:
att_mat = self.rel_positional_embedding(x, att_mat, q)

if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf"))
if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf"))

att_mat = att_mat.softmax(dim=-1)
att_mat = att_mat.softmax(dim=-1)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()
if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
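A matching sketch for the self-attention block, including the new save_attn guard (hypothetical sizes, not part of this PR):

```python
import torch

from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=96, num_heads=6, use_flash_attention=True)

x = torch.randn(4, 256, 96)  # (batch, tokens, hidden_size)
out = block(x)               # (4, 256, 96)

# Requesting the attention matrix is rejected when flash attention is enabled,
# because scaled_dot_product_attention never materializes it.
try:
    SABlock(hidden_size=96, num_heads=6, use_flash_attention=True, save_attn=True)
except ValueError as err:
    print(err)
```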
8 changes: 7 additions & 1 deletion monai/networks/blocks/spatialattention.py
@@ -33,6 +33,7 @@ class SpatialAttentionBlock(nn.Module):
num_channels: number of input channels. Must be divisible by num_head_channels.
num_head_channels: number of channels per head.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use flash attention for a memory-efficient attention mechanism.

"""

@@ -44,6 +45,7 @@ def __init__(
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
attention_dtype: Optional[torch.dtype] = None,
use_flash_attention: bool = False,
) -> None:
super().__init__()

@@ -54,7 +56,11 @@ def __init__(
raise ValueError("num_channels must be divisible by num_head_channels")
num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
self.attn = SABlock(
hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype
hidden_size=num_channels,
num_heads=num_heads,
qkv_bias=True,
attention_dtype=attention_dtype,
use_flash_attention=use_flash_attention,
)

def forward(self, x: torch.Tensor):
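The spatial block simply forwards the flag to its inner SABlock. A hypothetical construction is sketched below; spatial_dims and the (batch, channels, H, W) layout are assumptions based on the current MONAI signature, since only num_channels and num_head_channels appear in this hunk:

```python
import torch

from monai.networks.blocks.spatialattention import SpatialAttentionBlock

# spatial_dims and the input layout are assumed, not shown in this diff.
attn = SpatialAttentionBlock(
    spatial_dims=2,
    num_channels=32,
    num_head_channels=8,
    use_flash_attention=True,
)

x = torch.randn(1, 32, 16, 16)  # (batch, channels, H, W)
out = attn(x)                   # same shape as x
```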
22 changes: 14 additions & 8 deletions monai/networks/blocks/transformerblock.py
@@ -36,15 +36,15 @@ def __init__(
causal: bool = False,
sequence_length: int | None = None,
with_cross_attention: bool = False,
use_flash_attention: bool = False,
) -> None:
"""
Args:
hidden_size (int): dimension of hidden layer.
mlp_dim (int): dimension of feedforward layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
Args:
hidden_size (int): dimension of hidden layer.
mlp_dim (int): dimension of feedforward layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make the attention matrix accessible. Defaults to False.
use_flash_attention: if True, use PyTorch's built-in flash attention for a memory-efficient attention
mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).

"""

@@ -66,13 +66,19 @@ def __init__(
save_attn=save_attn,
causal=causal,
sequence_length=sequence_length,
use_flash_attention=use_flash_attention,
)
self.norm2 = nn.LayerNorm(hidden_size)
self.with_cross_attention = with_cross_attention

self.norm_cross_attn = nn.LayerNorm(hidden_size)
self.cross_attn = CrossAttentionBlock(
hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
qkv_bias=qkv_bias,
causal=False,
use_flash_attention=use_flash_attention,
)

def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
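Because the transformer block now forwards use_flash_attention to both its SABlock and its CrossAttentionBlock, a single flag switches the whole block over. A sketch with illustrative sizes (not part of this PR):

```python
import torch

from monai.networks.blocks.transformerblock import TransformerBlock

blk = TransformerBlock(
    hidden_size=128,
    mlp_dim=512,
    num_heads=8,
    with_cross_attention=True,
    use_flash_attention=True,  # applied to both self- and cross-attention
)

x = torch.randn(2, 64, 128)    # (batch, tokens, hidden_size)
context = torch.randn(2, 16, 128)
out = blk(x, context=context)  # (2, 64, 128)
```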
5 changes: 5 additions & 0 deletions monai/networks/nets/diffusion_model_unet.py
@@ -66,6 +66,8 @@ class DiffusionUNetTransformerBlock(nn.Module):
dropout: dropout probability to use.
cross_attention_dim: size of the context vector for cross attention.
upcast_attention: if True, upcast attention operations to full precision.
use_flash_attention: if True, use PyTorch's built-in flash attention for a memory-efficient attention
mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).

"""

@@ -77,6 +79,7 @@ def __init__(
dropout: float = 0.0,
cross_attention_dim: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
) -> None:
super().__init__()
self.attn1 = SABlock(
@@ -86,6 +89,7 @@ def __init__(
dim_head=num_head_channels,
dropout_rate=dropout,
attention_dtype=torch.float if upcast_attention else None,
use_flash_attention=use_flash_attention,
)
self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
self.attn2 = CrossAttentionBlock(
@@ -96,6 +100,7 @@ def __init__(
dim_head=num_head_channels,
dropout_rate=dropout,
attention_dtype=torch.float if upcast_attention else None,
use_flash_attention=use_flash_attention,
)
self.norm1 = nn.LayerNorm(num_channels)
self.norm2 = nn.LayerNorm(num_channels)
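The diffusion UNet transformer block follows the same pattern, passing the flag to attn1 (SABlock) and attn2 (CrossAttentionBlock). The sketch below is hypothetical: num_attention_heads is assumed from the current MONAI signature and is not visible in this hunk, so treat the argument names as illustrative:

```python
import torch

from monai.networks.nets.diffusion_model_unet import DiffusionUNetTransformerBlock

# num_attention_heads is an assumed parameter name; num_channels and
# num_head_channels do appear in the hunk above.
blk = DiffusionUNetTransformerBlock(
    num_channels=64,
    num_attention_heads=4,
    num_head_channels=16,
    use_flash_attention=True,
)

hidden = torch.randn(2, 128, 64)  # (batch, tokens, num_channels)
out = blk(hidden)                 # (2, 128, 64)
```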
37 changes: 24 additions & 13 deletions tests/test_crossattention.py
@@ -22,6 +22,7 @@
from monai.networks.blocks.crossattention import CrossAttentionBlock
from monai.networks.layers.factories import RelPosEmbedding
from monai.utils import optional_import
from tests.utils import SkipIfBeforePyTorchVersion

einops, has_einops = optional_import("einops")

@@ -31,25 +31,28 @@
for num_heads in [4, 6, 8, 12]:
for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
for input_size in [(16, 32), (8, 8, 8)]:
test_case = [
{
"hidden_size": hidden_size,
"num_heads": num_heads,
"dropout_rate": dropout_rate,
"rel_pos_embedding": rel_pos_embedding,
"input_size": input_size,
},
(2, 512, hidden_size),
(2, 512, hidden_size),
]
TEST_CASE_CABLOCK.append(test_case)
for flash_attn in [True, False]:
test_case = [
{
"hidden_size": hidden_size,
"num_heads": num_heads,
"dropout_rate": dropout_rate,
"rel_pos_embedding": rel_pos_embedding,
"input_size": input_size,
"use_flash_attention": flash_attn,
},
(2, 512, hidden_size),
(2, 512, hidden_size),
]
TEST_CASE_CABLOCK.append(test_case)


class TestResBlock(unittest.TestCase):

@parameterized.expand(TEST_CASE_CABLOCK)
@skipUnless(has_einops, "Requires einops")
def test_shape(self, input_param, input_shape, expected_shape):
# Without flash attention
net = CrossAttentionBlock(**input_param)
with eval_mode(net):
result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"]))
@@ -62,6 +66,13 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
CrossAttentionBlock(hidden_size=620, num_heads=8, dropout_rate=0.4)

@SkipIfBeforePyTorchVersion((2, 0))
def test_save_attn_with_flash_attention(self):
with self.assertRaises(ValueError):
CrossAttentionBlock(
hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True
)

@skipUnless(has_einops, "Requires einops")
def test_attention_dim_not_multiple_of_heads(self):
with self.assertRaises(ValueError):
@@ -119,7 +130,7 @@ def test_access_attn_matrix(self):
# no of elements is zero
assert no_matrix_acess_blk.att_mat.nelement() == 0

# be able to access the attention matrix
# be able to access the attention matrix.
matrix_acess_blk = CrossAttentionBlock(
hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
)
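Since the flash path skips the explicit einsum/softmax, an equivalence check against the reference path can catch regressions: with no relative positional embedding, no causal mask, and zero dropout, the two paths should agree up to floating-point tolerance because the module's scale (dim_head ** -0.5) matches the scaled_dot_product_attention default. A sketch of such a check, not part of this PR's test suite:

```python
import torch

from monai.networks import eval_mode
from monai.networks.blocks.selfattention import SABlock


def flash_matches_reference(hidden_size=64, num_heads=4, seq_len=32, atol=1e-5):
    """Compare the flash and reference attention paths with shared weights."""
    ref = SABlock(hidden_size=hidden_size, num_heads=num_heads, use_flash_attention=False)
    fla = SABlock(hidden_size=hidden_size, num_heads=num_heads, use_flash_attention=True)
    fla.load_state_dict(ref.state_dict())  # share qkv and out_proj weights

    x = torch.randn(2, seq_len, hidden_size)
    with eval_mode(ref, fla):
        return torch.allclose(ref(x), fla(x), atol=atol)


print(flash_matches_reference())
```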