Add scaled_dot_product_attention api (#55242)
liuzhenhai93 authored Aug 2, 2023
1 parent ef29468 commit b19dfb8
Showing 3 changed files with 95 additions and 12 deletions.
1 change: 1 addition & 0 deletions python/paddle/nn/functional/__init__.py
@@ -261,4 +261,5 @@
'multi_margin_loss',
'soft_margin_loss',
'gaussian_nll_loss',
'scaled_dot_product_attention',
]
55 changes: 54 additions & 1 deletion python/paddle/nn/functional/flash_attention.py
@@ -407,4 +407,57 @@ def flash_attn_unpadded(
return out, softmax if return_softmax else None


scaled_dot_product_attention = flash_attention
def scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
):
r"""
The equation is:
.. math::
result = softmax\left(\frac{QK^T}{\sqrt{d}}\right)V
where ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
The dimensions of the three parameters are the same.
``d`` represents the size of the last dimension of the three parameters.
Warning:
This API only supports inputs with dtype float16 and bfloat16.
Args:
query(Tensor): The query tensor in the Attention module.
4-D tensor with shape:
[batch_size, seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
key(Tensor): The key tensor in the Attention module.
4-D tensor with shape:
[batch_size, seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
value(Tensor): The value tensor in the Attention module.
4-D tensor with shape:
[batch_size, seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
attn_mask(Tensor, optional): A float mask of the same dtype as query,
key, and value that is added to the attention score.
Not supported yet; must be None.
dropout_p(float): The dropout ratio.
is_causal(bool): Whether to enable causal mode.
Returns:
out(Tensor): The attention tensor.
4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
Examples:
.. code-block:: python
# required: skiptest
>>> # xdoctest: +SKIP()
>>> import paddle
>>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
>>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
>>> print(output)
>>> # xdoctest: -SKIP
"""
assert attn_mask is None, "attn_mask is not supported yet"
out, _ = flash_attention(query, key, value, dropout_p, is_causal)
return out
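For context, the softmax(QK^T / sqrt(d))V formula from the docstring can be cross-checked against a plain dense-attention reference, in the same spirit as the reference computation used in the test file below. The following is a minimal sketch only, not part of this commit: the helper name naive_attention and the causal-mask construction are illustrative assumptions; only the [batch_size, seq_len, num_heads, head_dim] layout is taken from the docstring above.

# Hypothetical reference implementation for cross-checking; not part of this commit.
import paddle

def naive_attention(q, k, v, causal=False):
    # [batch, seq_len, num_heads, head_dim] -> [batch, num_heads, seq_len, head_dim]
    q = paddle.transpose(q, [0, 2, 1, 3])
    k = paddle.transpose(k, [0, 2, 1, 3])
    v = paddle.transpose(v, [0, 2, 1, 3])
    scale = 1.0 / (q.shape[-1] ** 0.5)
    # Attention scores: Q * K^T / sqrt(d)
    scores = paddle.matmul(q, k, transpose_y=True) * scale
    if causal:
        # Mask out future positions with -inf before the softmax.
        seq_len = q.shape[2]
        mask = paddle.triu(
            paddle.full((seq_len, seq_len), float('-inf'), dtype=scores.dtype),
            diagonal=1,
        )
        scores = scores + mask
    probs = paddle.nn.functional.softmax(scores, axis=-1)
    out = paddle.matmul(probs, v)
    # Back to [batch, seq_len, num_heads, head_dim]
    return paddle.transpose(out, [0, 2, 1, 3])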
51 changes: 40 additions & 11 deletions test/legacy_test/test_flash_attention.py
@@ -25,6 +25,7 @@
from paddle.nn.functional.flash_attention import (
flash_attention,
flash_attn_unpadded,
scaled_dot_product_attention,
)


@@ -85,6 +86,7 @@ def setUp(self):
self.causal = False
self.return_softmax = False
self.use_sdp_kernel = False
self.use_sdp_api = False

def test_unpadded(self):
print(
@@ -212,9 +214,15 @@ def test_all(self):
enable_flash=self.enable_flash,
enable_mem_efficient=self.enable_mem_efficient,
):
out, _ = flash_attention(
q, k, v, self.dropout, self.causal, self.return_softmax
)
if self.use_sdp_api:
out = scaled_dot_product_attention(
q, k, v, None, self.dropout, self.causal
)
else:
out, _ = flash_attention(
q, k, v, self.dropout, self.causal, self.return_softmax
)

else:
out, _ = flash_attention(
q, k, v, self.dropout, self.causal, self.return_softmax
@@ -253,14 +261,19 @@ def test_all(self):
enable_flash=self.enable_flash,
enable_mem_efficient=self.enable_mem_efficient,
):
outs, softmax = flash_attention(
qs,
ks,
vs,
self.dropout,
self.causal,
self.return_softmax,
)
if self.use_sdp_api:
outs = scaled_dot_product_attention(
qs, ks, vs, None, self.dropout, self.causal
)
else:
outs, softmax = flash_attention(
qs,
ks,
vs,
self.dropout,
self.causal,
self.return_softmax,
)
else:
outs, softmax = flash_attention(
qs, ks, vs, self.dropout, self.causal, self.return_softmax
@@ -334,6 +347,22 @@ def setUp(self):
self.causal = False
self.return_softmax = False
self.use_sdp_kernel = True
self.use_sdp_api = False
self.enable_math = True
self.enable_flash = False
self.enable_mem_efficient = False


class TestSDPAttentionAPITest(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (8, 1024, 16, 128)
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = False
self.return_softmax = False
self.use_sdp_kernel = True
self.use_sdp_api = True
self.enable_math = True
self.enable_flash = False
self.enable_mem_efficient = False
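In plain API terms, the configuration exercised by TestSDPAttentionAPITest corresponds roughly to the call below. This is a sketch only, assuming a CUDA device with float16 support; the kernel-selection context manager used in the test diff is omitted, and the shapes and flags simply mirror the setUp values above.

# Sketch of the call path the new test exercises; not part of the commit.
import paddle
import paddle.nn.functional as F

# Matches self.shape = (8, 1024, 16, 128) and self.dtype = paddle.float16.
q = paddle.rand((8, 1024, 16, 128), dtype=paddle.float16)
k = paddle.rand((8, 1024, 16, 128), dtype=paddle.float16)
v = paddle.rand((8, 1024, 16, 128), dtype=paddle.float16)

# attn_mask must be None for now; dropout_p=0.0 and is_causal=False mirror the test config.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # [8, 1024, 16, 128]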