
Add nn.functional.sparse_attention and some test cases, test=develop #35757

Merged

Conversation

Member

@Liu-xiandong Liu-xiandong commented Sep 15, 2021

PR types

New features

PR changes

APIs

Describe

Add paddle.nn.functional.sparse_attention API

  • This PR mainly wraps the sparse_attention functionality at the Python layer; the main body of the OP code is in #PR35676.

  • In addition, corresponding unit tests were added for the wrapped Python interface.
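The sparsity pattern is passed to the API in CSR form: a row-offset array plus a column-index array per batch and head. As a quick illustration, here is a small NumPy sketch that expands such a CSR pair into a dense boolean attention mask (the helper `csr_to_dense_mask` is ours for illustration, not part of the Paddle API):

```python
import numpy as np

def csr_to_dense_mask(offset, columns, seq_len):
    """Expand a CSR (offset, columns) pair into a dense boolean mask.

    offset[row] .. offset[row + 1] indexes the slice of `columns`
    holding the key positions that `row` is allowed to attend to.
    """
    mask = np.zeros((seq_len, seq_len), dtype=bool)
    for row in range(seq_len):
        cols = columns[offset[row]:offset[row + 1]]
        mask[row, cols] = True
    return mask

# The CSR data from the example below describes a block-diagonal pattern:
mask = csr_to_dense_mask(np.array([0, 2, 4, 6, 8]),
                         np.array([0, 1, 0, 1, 2, 3, 2, 3]), 4)
print(mask.astype(int))
# [[1 1 0 0]
#  [1 1 0 0]
#  [0 0 1 1]
#  [0 0 1 1]]
```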

Example

import paddle
import numpy as np

paddle.disable_static()

query_data = np.array([[[[0, 1], [2, 3], [0, 1], [2, 3]]]]).astype("float32")
key_data = np.array([[[[0, 1], [2, 3], [0, 1], [2, 3]]]]).astype("float32")
value_data = np.array([[[[0, 1], [2, 3], [0, 1], [2, 3]]]]).astype("float32")
sparse_csr_offset_data = np.array([[[0, 2, 4, 6, 8]]]).astype("int32")
sparse_csr_columns_data = np.array([[[0, 1, 0, 1, 2, 3, 2, 3]]]).astype("int32")
print(query_data.shape)
# (1, 1, 4, 2)
print(sparse_csr_offset_data.shape)
# (1, 1, 5)
print(sparse_csr_columns_data.shape)
# (1, 1, 8)
query = paddle.to_tensor(query_data, stop_gradient=False, place=paddle.CUDAPlace(0))
key = paddle.to_tensor(key_data, stop_gradient=False, place=paddle.CUDAPlace(0))
value = paddle.to_tensor(value_data, stop_gradient=False, place=paddle.CUDAPlace(0))
offset = paddle.to_tensor(sparse_csr_offset_data, stop_gradient=False, place=paddle.CUDAPlace(0))
columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False, place=paddle.CUDAPlace(0))
output = paddle.nn.functional.sparse_attention(query, key, value, offset, columns)
print(output)

# [[[[1.60885942, 2.60885954],
#    [1.99830270, 2.99830270],
#    [1.60885942, 2.60885954],
#    [1.99830270, 2.99830270]]]]
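The output above can be cross-checked with a plain NumPy re-computation of attention restricted to the CSR pattern (a verification sketch only; `dense_reference` is a hypothetical helper, not part of this PR):

```python
import numpy as np

def dense_reference(q, k, v, offset, columns):
    # Recompute softmax(Q K^T / sqrt(d)) @ V row by row, keeping only
    # the columns listed in the CSR (offset, columns) pair.
    b, h, seq_len, d = q.shape
    out = np.zeros_like(q)
    for bi in range(b):
        for hi in range(h):
            off, cols = offset[bi, hi], columns[bi, hi]
            for row in range(seq_len):
                idx = cols[off[row]:off[row + 1]]
                scores = q[bi, hi, row] @ k[bi, hi, idx].T / np.sqrt(d)
                weights = np.exp(scores - scores.max())
                weights /= weights.sum()  # softmax over the kept columns only
                out[bi, hi, row] = weights @ v[bi, hi, idx]
    return out

qkv = np.array([[[[0, 1], [2, 3], [0, 1], [2, 3]]]], dtype="float32")
offset = np.array([[[0, 2, 4, 6, 8]]], dtype="int32")
columns = np.array([[[0, 1, 0, 1, 2, 3, 2, 3]]], dtype="int32")
print(dense_reference(qkv, qkv, qkv, offset, columns))
```

Running this reproduces the tensor printed by `sparse_attention` above to within floating-point tolerance.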

Result

Since the CI platform currently has no machines with CUDA 11.2, the locally computed results are pasted below:

  1. Local unit test results
    (screenshot)
  2. API example output
    (screenshot)

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@@ -1747,3 +1747,126 @@ class centers and the shape of sampled_class_center will be [num_positive_class_
'seed': seed if seed is not None else 0
})
return remapped_label, sampled_class_center


def sparse_attention(query,
Contributor

Suggestion: put this in a separate .py file; it does not feel very common.

paddle_result.numpy(), numpy_result, atol=1e-5))


class TestSparseAttentionAPITestFloat(TestSparseAttentionAPI):
Contributor

Shapes that are not a power of two can be supported, right?

Member Author

Yes, they are supported; a non-power-of-two shape test was added to the unit tests.

sparse_csr_columns,
name=None):
r"""
Sparse_attention refers to sparse the Attention matrix in Transformer
Contributor

The grammar does not seem quite right.

Member Author

Updated.

d represents the size of the last dimension of the three parameters.

Parameters:
query(Tensor): The query tensor in the Attention module.
Contributor

Check the grammar and see whether it can be improved.

Member Author

The documentation content was updated.

AnnaTrainingG
AnnaTrainingG previously approved these changes Oct 9, 2021
zkh2016
zkh2016 previously approved these changes Oct 9, 2021
not core.is_compiled_with_cuda() or get_suitable_env() == False,
"core is not compiled with CUDA and cuda version need >= 11.2 in windows")
not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
"core is not compiled with CUDA and cuda version need larger than 11.2")
Contributor

Shouldn't this be "larger than or equal to"?

Member Author

Done

sparse_csr_columns,
name=None):
r"""
This operator implements the sparse_attention api. The api sparse the Attention matrix in Transformer module
Contributor

Line 30 could simply be changed to: This operator sparse the Attention matrix in Transformer module

query(Tensor): The query tensor in the Attention module.
It's a multidimensional tensor with a shape of
:math:`[batch\_size, num\_heads, target\_len, head\_dim]`.
The dtype can be ``float32`` and ``float64``.
Contributor

If the input is always 4-D, the description can state the dimensionality directly; suggestion:
query(Tensor): 4-D query Tensor of shape :math:`[batch\_size, num\_heads, target\_len, head\_dim]` in the Attention module.

Member Author

Done.

def test_static_result(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
Q = paddle.static.data(name="Q", shape=self.shape, dtype="float32")
Contributor

It should be dtype=self.dtype

Member Author

Done

Parameters:
query(Tensor): The query tensor in the Attention module.
It's a multidimensional tensor with a shape of
:math:`[batch\_size, num\_heads, target\_len, head\_dim]`.
Contributor

The meaning of target_len could be explained.

Member Author

Changed it to seq_len.

np.allclose(
fetches_result, expected_result, atol=1e-5))

def test_dygraph(self):
Contributor

The function names test_static_result and test_dygraph do not look like a matching pair; if necessary, the naming could be polished.

Member Author

Done. Renamed them.

@Liu-xiandong Liu-xiandong dismissed stale reviews from zkh2016 and AnnaTrainingG via b49797d October 9, 2021 09:31
@lanxianghit lanxianghit merged commit 85b7723 into PaddlePaddle:develop Oct 11, 2021
Liu-xiandong added a commit to Liu-xiandong/Paddle that referenced this pull request Oct 14, 2021
…addlePaddle#35757)

Add paddle.nn.functional.sparse_attention API

    This PR mainly wraps the sparse_attention functionality at the Python layer; the main body of the OP code is in #PR35676.

    In addition, corresponding unit tests were added for the wrapped Python interface.
Liu-xiandong added a commit to Liu-xiandong/Paddle that referenced this pull request Oct 19, 2021
…addlePaddle#35757)

Add paddle.nn.functional.sparse_attention API

    This PR mainly wraps the sparse_attention functionality at the Python layer; the main body of the OP code is in #PR35676.

    In addition, corresponding unit tests were added for the wrapped Python interface.
lanxianghit pushed a commit that referenced this pull request Oct 25, 2021
…35757) (#36551)

Add paddle.nn.functional.sparse_attention API

    This PR mainly wraps the sparse_attention functionality at the Python layer; the main body of the OP code is in #PR35676.

    In addition, corresponding unit tests were added for the wrapped Python interface.