
Add attn_mask supported for FlashAttnKernel. #55969

Merged
merged 14 commits into PaddlePaddle:develop on Aug 7, 2023

Conversation

iosmers
Contributor

@iosmers iosmers commented Aug 3, 2023

PR types

Others

PR changes

Others

Description

Pcard-70459

Add support for a Tensor-type attn_mask input to FlashAttnKernel.

#55758 already integrated FlashAttention-2 with support for the regular causal mask into the framework. Since a Tensor-type mask requires kernel-level enhancements to FlashAttention-2, this PR integrates Tensor-mask support based on FlashAttention-1 for now. Performance numbers are as follows:

  • Known issue: for large seqlen, the current FlashAttention-1-based Tensor-mask implementation performs poorly; it will be optimized in a follow-up PR.

| Shape | Pass | Native | Causal FlashAttn-1 | Causal FlashAttn-2 | Masked FlashAttn-1 |
| --- | --- | --- | --- | --- | --- |
| [2, 1024, 40, 128] | Forward | 27.17 | 3.77 | 1.65 | 12.20 |
| [2, 1024, 40, 128] | Forward + backward | 56.52 | 14.23 | 8.41 | 36.53 |
| [1, 8192, 8, 128] | Forward | 168.36 | 20.95 | 8.44 | 671.29 |
| [1, 8192, 8, 128] | Forward + backward | 351.07 | 62.29 | 30.81 | 1950.40 |

@paddle-bot

paddle-bot bot commented Aug 3, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

)

out = scaled_dot_product_attention(
q, k, v, m, self.dropout, self.causal, fixed_seed_offset=None
Contributor

scaled_dot_product_attention is called with the Tensor mask m, but attention_naive does not apply m.

@@ -293,6 +293,48 @@ def test_all(self):
fetches_result[0], out_, rtol=5e-03, atol=1e-03
)

def test_dot_scale_product(self):
Contributor

It would be better to implement a separate unit-test class: several later unit tests use TestFlashAttentionAPI as a base class and change configurations such as shape and return_softmax, so with the current modification they would all end up testing the mask version as well.

Contributor

Add a check like the following to skip unsupported CUDA versions and hardware:

if not core.is_compiled_with_cuda() or get_cuda_version() < 11030 or not is_sm_supported:
    pass

}
} else {
succ =
phi::dynload::flash_attn_fwd(q.data(),
Contributor

Please refactor this code into smaller helper functions; this function is too long.

@@ -818,8 +818,9 @@
inplace : (out_grad -> x_grad)

- backward_op : flash_attn_grad
forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset,Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
Contributor

Add a space after the comma in `,Tensor attn_mask`.

args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset,Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false)
optional : attn_mask
Contributor

Shouldn't fixed_seed_offset also be added to optional?

Contributor Author

fixed_seed_offset is a pre-existing parameter of type const Tensor; it is not optional.

Contributor

The original way of writing this does not feel very reasonable either, but let's keep it as is for now.

Contributor

No problem then: the backward pass takes seed_offset as input, which is an output of the forward pass and is required.

@@ -33,6 +34,24 @@ DECLARE_bool(cudnn_deterministic);

namespace phi {

// template <typename T, typename Context>
Contributor

Remove the unused code.

SimleScaleWithMaskKernel<<<gpu_config.block_per_grid,
gpu_config.thread_per_block,
0,
ctx.stream()>>>(q_size, scale, q_ptr);
Contributor

Line 163 defines a temporary Tensor. Shouldn't SimleScaleWithMaskKernel read the data of the original input q and write the scaled result into q_?

Contributor Author

The purpose of this function is to scale in place.
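For illustration, a minimal sketch of an in-place elementwise scale kernel matching the (q_size, scale, q_ptr) launch shown above; the merged SimleScaleWithMaskKernel is not shown in full in this diff, so the name and body below are assumptions:

```cpp
// Sketch only: a generic in-place elementwise scale. The merged kernel may differ.
template <typename T>
__global__ void InplaceScaleKernel(int64_t numel, float scale, T* data) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < numel) {
    data[idx] = static_cast<T>(static_cast<float>(data[idx]) * scale);
  }
}
```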


auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, q_size, 1);
DenseTensor q_(q);
T* q_ptr = static_cast<T*>(q_.data<T>());
Contributor

A Tensor first needs Resize to set its dimensions and then Alloc to allocate memory; otherwise no device memory is allocated.
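A minimal sketch of the Resize-then-Alloc pattern being described, assuming the usual phi DenseTensor API; the variable name is illustrative and the shape follows the surrounding diff:

```cpp
// Sketch: set the dims first, then allocate device memory via the context.
DenseTensor scaled_q;
scaled_q.Resize(phi::make_ddim({total_q, num_heads, head_size}));
T* scaled_q_ptr = ctx.template Alloc<T>(&scaled_q);
```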

int64_t q_size = total_q * num_heads * head_size;

auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, q_size, 1);
DenseTensor q_(q);
Contributor

Rename q_ to scaled_q.

fixed_seed_offset=None,
return_softmax=False,
training=True,
rng_name="",
Contributor

Only add the new training parameter; do not add the other parameters for now.

bool succ;

if (attn_mask.get_ptr()) {
scale = 1.0f;
Contributor

Add a PADDLE_ENFORCE check so that when attn_mask is passed in, is_causal cannot be true.

Contributor

Pay attention to how the error message is written; use:

PADDLE_ENFORCE_NE(causal, true, phi::errors::InvalidArgument(...));
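Putting the two suggestions together, a minimal sketch of the check; the message wording mirrors the string that appears later in this diff:

```cpp
// Sketch: reject a Tensor attn_mask combined with causal=true.
if (attn_mask.get_ptr()) {
  PADDLE_ENFORCE_NE(causal,
                    true,
                    phi::errors::InvalidArgument(
                        "attn_mask is not nullptr, causal can not be true"));
}
```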

Contributor

scale is passed in as a parameter, so the input scale value should not be modified. Instead, pass 1 as the scale argument of flash_attn, and pass this input scale to the later SimleScaleWithMaskKernel call.
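In other words, roughly the following; the launch configuration mirrors the diff above, while the constant name and comments are illustrative:

```cpp
// Sketch: leave the incoming `scale` parameter untouched.
// Apply the caller's scale in the scale/mask kernel ...
SimleScaleWithMaskKernel<<<gpu_config.block_per_grid,
                           gpu_config.thread_per_block,
                           0,
                           ctx.stream()>>>(q_size, scale, q_ptr);
// ... and pass a constant 1.0f as the scale argument of the flash_attn call,
// so the attention kernel does not scale a second time.
const float flash_attn_softmax_scale = 1.0f;  // illustrative name
```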





int64_t q_size = total_q * num_heads * head_size;
DenseTensor scale_q;
scale_q.ShareDataWith(q).Resize({total_q, num_heads, head_size});
Contributor

ShareDataWith is not appropriate here, because ComputeScaleQ effectively modifies the input values and it is unclear what problems that may cause. scaled_q should allocate its own memory, and ComputeScaleQ should be changed to a non-inplace version.
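A sketch of the non-inplace variant being requested; the kernel name, signature, and launch are assumptions for illustration, not the code that was merged:

```cpp
// Sketch: q stays const; scaled_q owns its own storage and receives the result.
template <typename T>
__global__ void ScaleQKernel(int64_t numel, float scale,
                             const T* src, T* dst) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < numel) {
    dst[idx] = static_cast<T>(static_cast<float>(src[idx]) * scale);
  }
}

// Usage inside the kernel function (illustrative):
//   DenseTensor scaled_q;
//   scaled_q.Resize(phi::make_ddim({total_q, num_heads, head_size}));
//   T* dst = ctx.template Alloc<T>(&scaled_q);
//   ScaleQKernel<T><<<gpu_config.block_per_grid, gpu_config.thread_per_block,
//                     0, ctx.stream()>>>(q_size, scale, q.data<T>(), dst);
```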

// compute scale Q
ComputeScaleQ(ctx, q_size, scale_q.data<T>(), scale);

scale = 1.0f;
Contributor

Do not modify the value of a function input parameter.

}
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
Contributor

Use static_cast<int64_t> for the type conversion.
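That is, roughly the following; only the cast changes, the rest of the line is as in the diff:

```cpp
workspace = Empty<float>(
    ctx, {static_cast<int64_t>(workspace_size / sizeof(float))});
```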

temp_rand_mask_dim.data() ? temp_rand_mask_dim.data() : nullptr,
nullptr);
PADDLE_ENFORCE_EQ(
succ, true, phi::errors::External(phi::dynload::flash_attn_error()));
Contributor

Wrap this in a helper function and make the error message more informative, something like "Error in Flash-Attention, detail information is xxx", where xxx is phi::dynload::flash_attn_error().
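A minimal sketch of such a check; phi's error helpers accept a printf-style format string, and the exact message wording is illustrative:

```cpp
// Sketch: attach the Flash-Attention error string to the enforce message.
PADDLE_ENFORCE_EQ(
    succ,
    true,
    phi::errors::External("Error in Flash-Attention, detail information is %s.",
                          phi::dynload::flash_attn_error()));
```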

"attn_mask is not nullptr, causal can not be true"));

int64_t q_size = total_q * num_heads * head_size;
DenseTensor* scale_q = new DenseTensor;
Contributor

Define the Tensor directly as `DenseTensor scale_q;` rather than with new.

attn_mask=None,
dropout_p=0.0,
is_causal=False,
training=True,
Contributor

Since new parameters are added, the parameter documentation at line 453 needs to be updated as well; you can refer to other APIs as examples.

true,
phi::errors::InvalidArgument(
"attn_mask is not nullptr, causal can not be true"));

Contributor

See the function below: Flash-Attention with Tensor-mask support has a restriction, namely if (head_dim == 32 || head_dim == 64 || head_dim == 128), so a PADDLE_ENFORCE check needs to be added here as well (a sketch follows after the quoted function):

bool CanUseFlashAttn() const {
#ifdef PADDLE_WITH_FLASHATTN
  if (!std::is_same<T, phi::dtype::bfloat16>::value &&
      !std::is_same<T, phi::dtype::float16>::value) {
    return false;
  }
  if (merge_qkv && batch_size == 1) {
    if (head_dim == 32 || head_dim == 64 || head_dim == 128) {
      return use_flash_attn;
    }
  }
#endif
  return false;
}
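A sketch of the kind of check being requested, using the head_size variable that appears elsewhere in this kernel diff; the message wording is illustrative:

```cpp
// Sketch: restrict the Tensor-mask path to the head dims supported by the
// FlashAttention-1 based kernel.
PADDLE_ENFORCE_EQ(
    head_size == 32 || head_size == 64 || head_size == 128,
    true,
    phi::errors::InvalidArgument(
        "attn_mask is only supported when head_size is 32, 64 or 128, "
        "but got %d.",
        head_size));
```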

@@ -19,6 +19,74 @@

namespace phi {

template <typename T, typename Context>
Contributor

This function declaration does not need to be added to this header file.

const int64_t* mask_dims);

template <typename T, typename Context>
void FlashAttnFwd(
Contributor

This function declaration does not need to be added to this header file.


PADDLE_ENFORCE_EQ(succ,
true,
"Error in Flash-Attention, detail information is ",
Contributor

The error message still needs an error type, i.e. phi::errors::External(...).

@Xreki
Contributor

Xreki commented Aug 7, 2023

Local unit-test results (40GB A100, CUDA 11.8) are as follows:

test 1307
    Start 1307: test_flash_attention

1307: Test command: /root/paddlejob/workspace/work/liuyiqun/Paddle/build_paddle/cmake-3.18.0-Linux-x86_64/bin/cmake "-E" "env" "PYTHONPATH=/root/paddlejob/workspace/work/liuyiqun/Paddle/build_paddle/build_cuda11.8_gcc8.2.0_py3.8/python" "/usr/bin/python3.8" "/root/paddlejob/workspace/work/liuyiqun/Paddle/tools/test_runner.py" "test_flash_attention"
1307: Test timeout computed to be: 10000000
1307: W0807 15:55:55.975852 139275 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 8.0, Driver API Version: 11.8, Runtime API Version: 11.8
1307: W0807 15:55:56.008502 139275 gpu_resources.cc:149] device: 0, cuDNN Version: 8.6.
1307: I0807 15:55:59.264534 139275 program_interpreter.cc:173] New Executor is Running.
1307: /root/paddlejob/workspace/work/liuyiqun/Paddle/build_paddle/build_cuda11.8_gcc8.2.0_py3.8/python/paddle/fluid/framework.py:2832: UserWarning: The Attr(force_cpu) of Op(fill_constant) will be deprecated in the future, please use 'device_guard' instead. 'device_guard' has higher priority when they are used at the same time.
1307:   warnings.warn(
1307: /root/paddlejob/workspace/work/liuyiqun/Paddle/build_paddle/build_cuda11.8_gcc8.2.0_py3.8/python/paddle/fluid/data_feeder.py:177: UserWarning: The data type of 'x' in reshape only support float16 in GPU now. 
1307:   warnings.warn(
1307: /root/paddlejob/workspace/work/liuyiqun/Paddle/build_paddle/build_cuda11.8_gcc8.2.0_py3.8/python/paddle/fluid/executor.py:1243: UserWarning: The variable k is not found in program. It is not declared or is pruned.
1307:   warnings.warn(
1307: /root/paddlejob/workspace/work/liuyiqun/Paddle/build_paddle/build_cuda11.8_gcc8.2.0_py3.8/python/paddle/fluid/executor.py:1243: UserWarning: The variable v is not found in program. It is not declared or is pruned.
1307:   warnings.warn(
1307: /root/paddlejob/workspace/work/liuyiqun/Paddle/build_paddle/build_cuda11.8_gcc8.2.0_py3.8/python/paddle/fluid/data_feeder.py:177: UserWarning: The data type of 'x' in transpose only support float16 in GPU now. 
1307:   warnings.warn(
1307: /root/paddlejob/workspace/work/liuyiqun/Paddle/build_paddle/build_cuda11.8_gcc8.2.0_py3.8/python/paddle/fluid/data_feeder.py:177: UserWarning: The data type of 'x' in matmul only support float16 in GPU now. 
1307:   warnings.warn(
1307: /root/paddlejob/workspace/work/liuyiqun/Paddle/build_paddle/build_cuda11.8_gcc8.2.0_py3.8/python/paddle/fluid/data_feeder.py:177: UserWarning: The data type of 'y' in matmul only support float16 in GPU now. 
1307:   warnings.warn(
1307: /root/paddlejob/workspace/work/liuyiqun/Paddle/build_paddle/build_cuda11.8_gcc8.2.0_py3.8/python/paddle/fluid/data_feeder.py:177: UserWarning: The data type of 'x' in softmax only support float16 in GPU now. 
1307:   warnings.warn(
1307: Test case shape (2, 128, 8, 16) dtype float16 causal False
1307: Test unpadded case shape (2, 128, 8, 16) dtype float16 causal False
1307: Test case shape (2, 128, 8, 16) dtype paddle.float16 causal False
1307: Test unpadded case shape (2, 128, 8, 16) dtype paddle.float16 causal False
1307: Test case shape (2, 256, 8, 16) dtype paddle.float16 causal False
1307: Test unpadded case shape (2, 256, 8, 16) dtype paddle.float16 causal False
1307: Test case shape (2, 512, 8, 16) dtype paddle.float16 causal True
1307: Test unpadded case shape (2, 512, 8, 16) dtype paddle.float16 causal True
1307: Test case shape (8, 1024, 16, 128) dtype paddle.float16 causal False
1307: Test unpadded case shape (8, 1024, 16, 128) dtype paddle.float16 causal False
1307: Test case shape (8, 1024, 16, 128) dtype paddle.float16 causal False
1307: Test unpadded case shape (8, 1024, 16, 128) dtype paddle.float16 causal False
1307: Test case shape (8, 1024, 16, 128) dtype paddle.float16 causal False
1307: Test unpadded case shape (8, 1024, 16, 128) dtype paddle.float16 causal False
1/1 Test #1307: test_flash_attention .............   Passed   56.05 sec

The following tests passed:
        test_flash_attention

100% tests passed, 0 tests failed out of 1

Total Test time (real) =  56.15 sec

Contributor

@JamesLim-sy JamesLim-sy left a comment

LGTM

Contributor

@ZzSean ZzSean left a comment

LGTM for skipIf

@Xreki Xreki changed the title add mask Add attn_mask suupprted for FlashAttnKernel. Aug 7, 2023
@Xreki Xreki changed the title Add attn_mask suupprted for FlashAttnKernel. Add attn_mask supported for FlashAttnKernel. Aug 7, 2023
Contributor

@lanxianghit lanxianghit left a comment

LGTM for new args

Contributor

@chenwhql chenwhql left a comment

LGTM for yaml change

Contributor

@sunzhongkai588 sunzhongkai588 left a comment

LGTM

  • The Doc-Preview CI has some issues and keeps failing; the cause has not been identified yet. Approving for now.

@Xreki Xreki merged commit 42e0c6b into PaddlePaddle:develop Aug 7, 2023