Additional mask support on FA2 #19
Conversation
Tested mask addition for one block; the code still needs cleanup. Softmax in bwd is removed for testing.
Output and dv look correct.
Removed some unused code and comments.
@@ -86,8 +86,12 @@ void set_params_fprop(Flash_fwd_params &params,
void * const softmax_lse_d,
float p_dropout,
float softmax_scale,
float softmax_unscale,
Why does this need to be passed in as a parameter? Can't it be computed inside the kernel?
This is mainly to replace division with multiplication, for the reasons below, although I have not compared the performance of the two implementations:
- Someone on the NVIDIA forums benchmarked this on V100 and found division to be much slower than multiplication: https://forums.developer.nvidia.com/t/speed-comparison-of-division-compared-to-other-arithmetic-operations-perhaps-something-like-clock-cycles/168371
- I have not run experiments on A100, but from experience multiplication should be faster than division. Even on CPUs, division instructions implemented by trial-quotient (long-division) methods take many clock cycles.
If it is stored in Params, the whole Params struct has to be transferred from CPU to GPU. What I mean is: after the CUDA kernel is launched, this could be computed from softmax_scale inside the kernel and then used from there?
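A minimal sketch of that alternative (a standalone illustrative kernel, not this PR's code):

```cuda
#include <cuda_runtime.h>

// Illustrative kernel: derive the reciprocal once inside the kernel instead of
// carrying an extra float through the host-side params struct. The reciprocal
// is then reused as a multiplier wherever the code would otherwise divide.
__global__ void unscale_in_kernel(float *x, int n, float softmax_scale) {
    const float softmax_unscale = __frcp_rn(softmax_scale);  // fast reciprocal intrinsic
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        x[i] *= softmax_unscale;  // multiply by 1/scale rather than divide by scale
    }
}
```

Either way the multiply-for-divide substitution discussed above is preserved; only where the reciprocal is computed changes.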
csrc/capi/flash_attn.cu (outdated)
ASSERT_CHECK(head_size <= 256);
ASSERT_CHECK(num_heads == num_heads_k);

if (attn_mask) {
Please also add a check on mask_dims[0].
done
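For reference, a check mirroring the existing mask_dims[1]/mask_dims[2] asserts could look like the line below; whether the batch dimension variable in that scope is actually named batch_size is an assumption.

```cpp
// Sketch: accept either a broadcast mask (dim 0 == 1) or one mask per batch entry.
ASSERT_CHECK(mask_dims[0] == 1 || mask_dims[0] == batch_size);
```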
@@ -295,8 +325,12 @@ bool flash_attn_fwd(const void * const q,
softmax_lse_ptr,
p_dropout,
softmax_scale,
softmax_unscale,
The indentation is misaligned again here.
done
csrc/capi/flash_attn.cu (outdated)
ASSERT_CHECK(mask_dims[1] == 1 || mask_dims[1] == num_heads);
ASSERT_CHECK(mask_dims[2] == 1 || mask_dims[2] == seqlen_q);
#if 0
ASSERT_CHECK(softmax_scale == 1.0f);
This can be deleted.
csrc/capi/flash_attn.cu (outdated)
@@ -370,8 +416,12 @@ bool flash_attn_varlen_fwd(const void * const q,
softmax_lse_ptr,
p_dropout,
softmax_scale,
1, // just hack
What is this for?
Use params.seqlen for mask predication.
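A hedged illustration of seqlen-based predication of the mask access (names and layout here are made up for the example, not taken from the PR's cute-based kernel):

```cuda
// Only touch the additional-mask buffer when the indices fall inside the real
// sequence lengths; out-of-range positions contribute no additive mask.
__device__ float load_mask_elem(const float *mask, int row, int col,
                                int seqlen_q, int seqlen_k) {
    return (row < seqlen_q && col < seqlen_k) ? mask[row * seqlen_k + col] : 0.0f;
}
```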
@@ -64,6 +64,30 @@ const char *flash_attn_error() {
#define FLASHATTNLIB_BEGIN_FUNC try {
#define FLASHATTNLIB_END_FUNC } catch (::std::exception &__e) { flash_attn_set_error(__e.what()); return false; } catch (...) { flash_attn_set_error(nullptr); return false; }

#define CHECK_FWD_EXECTUABLE(__seqlen_q, __seqlen_k) \
I still think it would be better to define this as a function.
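A sketch of the function form being suggested; the real checks live inside CHECK_FWD_EXECTUABLE and are not visible in this hunk, so the body below is only a placeholder:

```cpp
// Placeholder body: shows the shape of the refactor (an inline function instead
// of a multi-line macro), not the actual checks performed by the macro.
static inline void check_fwd_executable(const int seqlen_q, const int seqlen_k) {
    ASSERT_CHECK(seqlen_q > 0);
    ASSERT_CHECK(seqlen_k > 0);
}
```

A function is type-checked and easier to step through in a debugger; the macro form, on the other hand, can early-return from the calling C API function, which may be why it was written as a macro.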
num_splits,
const_cast<void *>(attn_mask),
mask_head_mod_size,
mask_seq_q_mod_size);
I feel it might be better to name these variables mask_num_heads and mask_seqlen_q.
@@ -101,6 +102,11 @@ struct Flash_fwd_params : public Qkv_params {

bool is_bf16;
bool is_causal;

// The attn mask matrix
void * __restrict__ attn_mask_ptr;
There is an int-typed blockmask pointer at L85; what is that used for?
@@ -448,7 +448,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;

int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
// umiswing: residue is for predication of additional mask gmem access.
Could these extra computations be wrapped in an if (Is_attn_mask) check? Also, I think the additional mask gmem access logic should be encapsulated; otherwise it intrudes too much on the kernel, and merging the latest upstream code later will be difficult.
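A minimal sketch of the compile-time guard being suggested; the residue formulas are one plausible reading of the hunk above, not the PR's exact code:

```cuda
// Sketch: compute the mask-predication residues only when the kernel is
// instantiated with the additional mask, so unmasked instantiations pay
// nothing extra for this change.
int m_residue = params.seqlen_q;
int n_residue = params.seqlen_k;
if constexpr (Is_attn_mask) {
    // Rows/columns remaining in the last partial block, used to predicate
    // the additional-mask gmem accesses.
    m_residue = binfo.actual_seqlen_q - (m_block_max - 1) * kBlockM;
    n_residue = binfo.actual_seqlen_k - n_block * kBlockN;
}
```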
m_block == m_block_max - 1 ? m_residue : params.seqlen_q,
n_block == n_block_max - 1 ? n_residue : params.seqlen_k,
params.unscale_softmax);
tPgMask.data() = tPgMask.data() + (-kBlockM * params.seqlen_k);
Is this line doing a pointer transformation? Could a comment be added?
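For reference, a comment along the lines below would answer this; the stated direction (stepping to the next mask tile visited by the backward loop) is inferred from the negative offset and should be confirmed against the loop order:

```cuda
// tPgMask points at the current additional-mask tile in gmem. After finishing
// this row block, move the pointer back by kBlockM rows (kBlockM * seqlen_k
// elements) so it lands on the next mask tile processed by the backward loop,
// mirroring how the other gmem tensors are advanced.
tPgMask.data() = tPgMask.data() + (-kBlockM * params.seqlen_k);
```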
}
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
BOOL_SWITCH(is_attn_mask, Is_attn_mask, [&] {
Won't this introduce a compiled variant where both IsCausal and Is_attn_mask are true?
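One way to avoid that combination, sketched under the assumption that the surrounding launcher follows upstream FA2's BOOL_SWITCH pattern (run_kernel is a hypothetical stand-in for the actual launch function):

```cuda
// Sketch: nested BOOL_SWITCHes would normally instantiate all four
// (Is_causal, Is_attn_mask) combinations. Folding the flags before they are
// used as template arguments keeps the (true, true) kernel from being
// generated, which also limits the growth of the .so size.
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
    BOOL_SWITCH(is_attn_mask, Has_mask, [&] {
        constexpr bool Is_attn_mask = Has_mask && !Is_causal;  // drop the dead combo
        run_kernel<Kernel_traits, Is_causal, Is_attn_mask>(params, stream);
    });
});
```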
// TODO(umiswing): support cu_attn_mask
// This kernel should work after dealing with input cu_seq indicating mask position.
template <typename Engine, typename Layout, typename T>
inline __device__ void apply_cu_attn_mask(Tensor<Engine, Layout> &tensor, const T* const mask, const float unscale_softmax, const uint32_t col_idx_offset_,
Is this function unused?
This is the kernel related to the stair-step mask requirement raised earlier, i.e. deciding whether to mask out based on row_idx and col_idx. The kernel is not fully finished yet; should it be kept in this PR?
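For context, a minimal sketch of the row/column-index style of masking that apply_cu_attn_mask is aiming at (purely illustrative; the real kernel operates on cute tensors and, per the comment above, is unfinished):

```cuda
#include <math.h>

// Illustrative stair-step masking: positions whose column index reaches a
// per-row boundary are set to -INFINITY so they vanish after softmax. The
// col_limit array stands in for the cu_seq input mentioned in the TODO.
__device__ void apply_stair_mask(float *scores, int row_idx, int ncols,
                                 const int *col_limit) {
    for (int col_idx = 0; col_idx < ncols; ++col_idx) {
        if (col_idx >= col_limit[row_idx]) {
            scores[col_idx] = -INFINITY;
        }
    }
}
```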
LGTM! Great work~
Support additional mask on FA2.
.so size: 66M -> 96M