Additional mask support on FA2 #19

Merged: 20 commits, Sep 22, 2023
Conversation

@umiswing (Member) commented Sep 13, 2023

Support additional mask on FA2.

.so size: 66M -> 96M

@CLAassistant commented Sep 13, 2023

CLA assistant check
All committers have signed the CLA.

@@ -86,8 +86,12 @@ void set_params_fprop(Flash_fwd_params &params,
void * const softmax_lse_d,
float p_dropout,
float softmax_scale,
float softmax_unscale,
Collaborator:
Why does this need to be passed in as a parameter? Couldn't it be computed inside the kernel?

Member Author:
> Why does this need to be passed in as a parameter? Couldn't it be computed inside the kernel?

The idea is mainly to replace division with multiplication, for the reasons below, though I haven't compared the performance of the two implementations:

  1. Someone on the NVIDIA forums benchmarked this on V100 and found division to be much slower than multiplication. https://forums.developer.nvidia.com/t/speed-comparison-of-division-compared-to-other-arithmetic-operations-perhaps-something-like-clock-cycles/168371
  2. I haven't tested on A100, but from experience multiplication should be faster than division. Even on CPUs, division instructions implemented with the trial-quotient (long-division) method take a large number of clock cycles.

Collaborator:
If it's stored in Params, the whole Params struct has to be transferred from CPU to GPU. What I mean is: after the CUDA kernel has been launched, it could be computed from softmax_scale for later use?
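For illustration, a minimal standalone CUDA sketch of the reviewer's suggestion (the kernel and parameter names here are hypothetical, not the PR's actual code): derive the reciprocal once inside the kernel from softmax_scale and multiply by it afterwards, so nothing extra needs to travel through the params struct.

```cpp
// Hypothetical sketch, not the FA2 kernel: compute the unscale factor inside
// the kernel instead of passing it from the host, then multiply by it.
__global__ void unscale_in_kernel(const float *in, float *out, int n,
                                  float softmax_scale) {
    const float softmax_unscale = 1.f / softmax_scale;   // computed in-kernel
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        out[i] = in[i] * softmax_unscale;  // multiply instead of dividing per element
    }
}
```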

ASSERT_CHECK(head_size <= 256);
ASSERT_CHECK(num_heads == num_heads_k);

if (attn_mask) {
Collaborator:
Please also add a check on mask_dims[0].

Member Author:
done
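Presumably the added check follows the existing ASSERT_CHECK pattern visible elsewhere in this diff; a minimal sketch (the batch_size name and the allowed values are my assumptions):

```cpp
// Hypothetical sketch of the requested batch-dimension check on the mask.
ASSERT_CHECK(mask_dims[0] == 1 || mask_dims[0] == batch_size);
```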

@@ -295,8 +325,12 @@ bool flash_attn_fwd(const void * const q,
softmax_lse_ptr,
p_dropout,
softmax_scale,
softmax_unscale,
Collaborator:
The indentation is a bit off again.

Member Author:
done

ASSERT_CHECK(mask_dims[1] == 1 || mask_dims[1] == num_heads);
ASSERT_CHECK(mask_dims[2] == 1 || mask_dims[2] == seqlen_q);
#if 0
ASSERT_CHECK(softmax_scale == 1.0f);
Collaborator:
This can be removed.

@@ -370,8 +416,12 @@ bool flash_attn_varlen_fwd(const void * const q,
softmax_lse_ptr,
p_dropout,
softmax_scale,
1, // just hack
Collaborator:
What does this do?

@umiswing changed the title from "[WIP] Additional mask support on FA2" to "Additional mask support on FA2" on Sep 21, 2023
@@ -64,6 +64,30 @@ const char *flash_attn_error() {
#define FLASHATTNLIB_BEGIN_FUNC try {
#define FLASHATTNLIB_END_FUNC } catch (::std::exception &__e) { flash_attn_set_error(__e.what()); return false; } catch (...) { flash_attn_set_error(nullptr); return false; }

#define CHECK_FWD_EXECTUABLE(__seqlen_q, __seqlen_k) \
Collaborator:
I still think it would be better to define this as a function.
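A minimal sketch of what the function form could look like (the body below is a placeholder assembled from checks visible elsewhere in this PR, not the macro's actual contents):

```cpp
// Hypothetical sketch: an inline function in place of CHECK_FWD_EXECTUABLE.
// ASSERT_CHECK is the repo's existing assertion macro; the specific checks
// here are only illustrative.
inline void check_fwd_executable(const int seqlen_q, const int seqlen_k,
                                 const int head_size, const int num_heads,
                                 const int num_heads_k) {
    ASSERT_CHECK(seqlen_q > 0);
    ASSERT_CHECK(seqlen_k > 0);
    ASSERT_CHECK(head_size <= 256);
    ASSERT_CHECK(num_heads == num_heads_k);
}
```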

num_splits,
const_cast<void *>(attn_mask),
mask_head_mod_size,
mask_seq_q_mod_size);
Collaborator:
I feel the variable names might be better as mask_num_heads and mask_seqlen_q.

@@ -101,6 +102,11 @@ struct Flash_fwd_params : public Qkv_params {

bool is_bf16;
bool is_causal;

// The attn mask matrix
void * __restrict__ attn_mask_ptr;
Collaborator:
There is an int-typed blockmask pointer at L85; what is that used for?

@@ -448,7 +448,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;

int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
// umiswing: residue is for predication of additional mask gmem access.
Collaborator:
Could these extra computations be wrapped in an if (Is_attn_mask) check? Also, I think the additional-mask gmem access logic would be better encapsulated; otherwise it is too intrusive to the kernel, and merging the latest upstream code later will be difficult.
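A minimal sketch of the kind of guard and encapsulation being suggested (the helper name and the residue formula are my assumptions, not the PR's code): hide the mask-only residue arithmetic behind the compile-time Is_attn_mask flag so the non-mask instantiation pays nothing.

```cpp
// Hypothetical helper: compute the leftover rows/columns of the last tile only
// when the additional-mask path is compiled in; otherwise return the full
// sequence length so the caller's predication becomes a no-op.
template <bool Is_attn_mask>
__forceinline__ __device__ int last_block_residue(const int actual_seqlen,
                                                  const int block_size,
                                                  const int num_blocks) {
    if constexpr (Is_attn_mask) {
        return actual_seqlen - (num_blocks - 1) * block_size;
    } else {
        return actual_seqlen;
    }
}
```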

m_block == m_block_max - 1 ? m_residue : params.seqlen_q,
n_block == n_block_max - 1 ? n_residue : params.seqlen_k,
params.unscale_softmax);
tPgMask.data() = tPgMask.data() + (-kBlockM * params.seqlen_k);
Collaborator:
Is this line doing a pointer transformation? Could you add a comment?
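For what it's worth, my reading of that line written out as a comment (the direction of traversal is an assumption on my part): it rewinds the global-memory mask pointer by one kBlockM row block, the additional mask apparently being row-major with seqlen_k elements per row.

```cpp
// Suggested comment (my interpretation, not the author's): step the gmem mask
// pointer back by kBlockM rows of the row-major mask, i.e. kBlockM *
// params.seqlen_k elements, so the next iteration of the m-loop reads the
// preceding row block of the additional mask.
tPgMask.data() = tPgMask.data() + (-kBlockM * params.seqlen_k);
```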

}
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
BOOL_SWITCH(is_attn_mask, Is_attn_mask, [&] {
Collaborator:
Won't this introduce compiling the combination where Is_causal and Is_attn_mask are both true?
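For context, a self-contained sketch of why nesting two boolean switches instantiates all four (Is_causal, Is_attn_mask) combinations, plus one way to avoid the unwanted one; the BOOL_SWITCH below is written out locally and may differ in detail from the repo's static_switch.h.

```cpp
// Local stand-in for a compile-time boolean switch, written out so the sketch
// is self-contained; the repo's own BOOL_SWITCH may differ.
#define BOOL_SWITCH(COND, CONST_NAME, ...)     \
    [&] {                                      \
        if (COND) {                            \
            constexpr bool CONST_NAME = true;  \
            return __VA_ARGS__();              \
        } else {                               \
            constexpr bool CONST_NAME = false; \
            return __VA_ARGS__();              \
        }                                      \
    }()

template <bool Is_causal, bool Is_attn_mask>
void launch() { /* one kernel instantiation per (Is_causal, Is_attn_mask) pair */ }

// Nesting the switches compiles all four combinations, including the one where
// both Is_causal and Is_attn_mask are true.
void dispatch(bool is_causal, bool is_attn_mask) {
    BOOL_SWITCH(is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_attn_mask, Is_attn_mask, [&] {
            launch<Is_causal, Is_attn_mask>();
        });
    });
}

// If the (true, true) combination is unsupported, an explicit dispatch that
// simply never names it avoids that instantiation altogether.
void dispatch_pruned(bool is_causal, bool is_attn_mask) {
    if (is_causal)          launch</*Is_causal=*/true,  /*Is_attn_mask=*/false>();
    else if (is_attn_mask)  launch</*Is_causal=*/false, /*Is_attn_mask=*/true>();
    else                    launch</*Is_causal=*/false, /*Is_attn_mask=*/false>();
}
```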

// TODO(umiswing): support cu_attn_mask
// This kernel should work after dealing with input cu_seq indicating mask position.
template <typename Engine, typename Layout, typename T>
inline __device__ void apply_cu_attn_mask(Tensor<Engine, Layout> &tensor, const T* const mask, const float unscale_softmax, const uint32_t col_idx_offset_,
Collaborator:
This function isn't used?

Member Author:
> This function isn't used?

This is the kernel for the staircase-shaped mask requirement raised earlier, i.e. deciding whether to mask out based on row_idx and col_idx. The kernel isn't fully finished yet; should it be kept in this PR?

@Xreki (Collaborator) left a comment:
LGTM! Great work~

@Xreki merged commit b74460b into PaddlePaddle:main on Sep 22, 2023
1 check passed