Add nn.functional.sparse_attention and some test cases, test=develop #35757
Conversation
Thanks for your contribution!
… Add_nn_functional_sparse_attention
@@ -1747,3 +1747,126 @@ class centers and the shape of sampled_class_center will be [num_positive_class_
            'seed': seed if seed is not None else 0
        })
    return remapped_label, sampled_class_center


def sparse_attention(query,
Suggest putting this into a separate .py file; it does not feel very common.
            paddle_result.numpy(), numpy_result, atol=1e-5))


class TestSparseAttentionAPITestFloat(TestSparseAttentionAPI):
Shapes that are not powers of 2 can be supported, right?
Yes, they are supported; a non-power-of-2 shape test has been added to the unit tests.
                     sparse_csr_columns,
                     name=None):
    r"""
    Sparse_attention refers to sparse the Attention matrix in Transformer
The grammar does not seem quite right here.
Updated.
    d represents the size of the last dimension of the three parameters.

    Parameters:
        query(Tensor): The query tensor in the Attention module.
Check the grammar and see whether anything can be improved.
The documentation content has been updated.
    not core.is_compiled_with_cuda() or get_suitable_env() == False,
    "core is not compiled with CUDA and cuda version need >= 11.2 in windows")
    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
    "core is not compiled with CUDA and cuda version need larger than 11.2")
Shouldn't this say "larger than or equal to"?
Done
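The diff above settles on comparing an integer-encoded CUDA version (`get_cuda_version() < 11020` for "older than 11.2"). A minimal sketch of how such an encoding can work; the helper name `encode_cuda_version` is an illustration assumed here, not Paddle's actual helper:

```python
def encode_cuda_version(version_str):
    """Encode a CUDA version string like '11.2' as an integer that
    supports plain '<' / '>=' comparisons (e.g. '11.2' -> 11020)."""
    major, minor = (int(x) for x in version_str.split(".")[:2])
    return major * 1000 + minor * 10

# A skip condition equivalent to `get_cuda_version() < 11020`
# rejects anything older than CUDA 11.2:
assert encode_cuda_version("11.2") == 11020
assert encode_cuda_version("10.1") < 11020
assert not encode_cuda_version("11.4") < 11020
```

With this scheme, a `skipIf(... < 11020, ...)` decorator gates the test on CUDA >= 11.2, which is what the reviewer's "larger than or equal to" wording describes.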
                     sparse_csr_columns,
                     name=None):
    r"""
    This operator implements the sparse_attention api. The api sparse the Attention matrix in Transformer module
Line 30 could simply be changed to: "This operator sparse the Attention matrix in Transformer module".
        query(Tensor): The query tensor in the Attention module.
            It's a multidimensional tensor with a shape of
            :math:`[batch\_size, num\_heads, target\_len, head\_dim]`.
            The dtype can be ``float32`` and ``float64``.
If the shape is a fixed 4-D, the description can state the dimensions directly; suggested:
query(Tensor): 4-D query Tensor of shape :math:[batch\_size, num\_heads, target\_len, head\_dim]
in the Attention module.
Done.
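For context on the 4-D shape agreed above, the standard scaled dot-product attention over `[batch_size, num_heads, seq_len, head_dim]` tensors can be sketched in NumPy; this is a plain illustration of the dense formula the docstring describes, not Paddle's CUDA kernel:

```python
import numpy as np

def dense_attention(q, k, v):
    """softmax(Q K^T / sqrt(d)) V over 4-D tensors of shape
    [batch_size, num_heads, seq_len, head_dim]."""
    d = q.shape[-1]
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d)     # [b, h, seq, seq]
    scores = scores - scores.max(axis=-1, keepdims=True)  # numeric stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)        # row-wise softmax
    return weights @ v                                    # [b, h, seq, head_dim]

q = k = v = np.ones((1, 2, 4, 8), dtype=np.float32)
out = dense_attention(q, k, v)
assert out.shape == (1, 2, 4, 8)
```

With identical all-ones inputs the softmax weights are uniform, so the output equals `v`; this kind of closed-form case is handy as a unit-test reference.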
    def test_static_result(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            Q = paddle.static.data(name="Q", shape=self.shape, dtype="float32")
It should be dtype=self.dtype
Done
    Parameters:
        query(Tensor): The query tensor in the Attention module.
            It's a multidimensional tensor with a shape of
            :math:`[batch\_size, num\_heads, target\_len, head\_dim]`.
The meaning of target_len could be explained.
Changed it to seq_len.
            np.allclose(
                fetches_result, expected_result, atol=1e-5))

    def test_dygraph(self):
The function names test_static_result and test_dygraph do not look like a matched pair; if necessary, the naming could be polished.
Done. Updated.
b49797d
…addlePaddle#35757) Add paddle.nn.functional.sparse_attention API. This PR mainly wraps the sparse_attention functionality at the Python layer; the main OP code is in #PR35676. In addition, corresponding unit tests were added for the wrapped Python API.
PR types
New features
PR changes
APIs
Describe
Add paddle.nn.functional.sparse_attention API
This PR mainly wraps the sparse_attention functionality at the Python layer; the main OP code can be found in #PR35676.
In addition, corresponding unit tests were added for the wrapped Python API.
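Conceptually, the wrapped API restricts each query row's attention to the key columns listed in a CSR-style mask (the `sparse_csr_offset` / `sparse_csr_columns` arguments). A hedged NumPy reference of that per-row computation, for illustration only; the function name and the single-head 2-D layout are assumptions for the sketch, not the kernel's actual interface:

```python
import numpy as np

def sparse_attention_ref(q, k, v, offset, columns):
    """Reference semantics: query row i attends only to key indices
    columns[offset[i]:offset[i+1]] (single head, 2-D [seq_len, head_dim])."""
    seq_len, head_dim = q.shape
    out = np.zeros_like(q)
    for i in range(seq_len):
        cols = columns[offset[i]:offset[i + 1]]        # kept columns for row i
        scores = q[i] @ k[cols].T / np.sqrt(head_dim)  # scaled dot products
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()                       # softmax over kept columns
        out[i] = weights @ v[cols]
    return out

# A full (all-columns) CSR mask reduces to dense attention:
q = k = v = np.ones((4, 8), dtype=np.float32)
offset = np.array([0, 4, 8, 12, 16])                   # 4 kept columns per row
columns = np.tile(np.arange(4), 4)
out = sparse_attention_ref(q, k, v, offset, columns)
assert out.shape == (4, 8)
```

A NumPy reference like this is also what the unit tests compare the CUDA output against (`atol=1e-5` in the snippets above).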
Example
Result
Since the CI platform currently has no machines with CUDA 11.2, the local computation results are pasted below: