add num_splist for flash_attn_bwd and FlashAttnUnpaddedGradKernel #56363
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open-source project!
```python
paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
self.test_dot_scale_product()
paddle.set_flags({'FLAGS_cudnn_deterministic': 0})
```
Please add a unit test that compares the results of two runs: given identical inputs, the two runs must produce exactly the same output. Use np.equal for the comparison.
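A minimal sketch of such a two-run check (the tensor shapes and the `paddle.nn.functional.flash_attention` import path are assumptions for illustration, not code from this PR):

```python
import numpy as np
import paddle
from paddle.nn.functional.flash_attention import flash_attention

paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
shape = [2, 128, 8, 32]  # [batch, seqlen, num_heads, head_dim], arbitrary
q = paddle.randn(shape, dtype='float16')
k = paddle.randn(shape, dtype='float16')
v = paddle.randn(shape, dtype='float16')

# Run the op twice on identical inputs.
out1, _ = flash_attention(q, k, v)
out2, _ = flash_attention(q, k, v)

# Deterministic mode must match bit-for-bit, not merely within a tolerance.
assert np.equal(out1.numpy(), out2.numpy()).all()
paddle.set_flags({'FLAGS_cudnn_deterministic': 0})
```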
done
This PR needs to update the flash-attention submodule so that CI can actually test the changes in the flash-attention repo.
```python
np.testing.assert_allclose(out1.numpy(), out1_, rtol=5e-03, atol=1e-03)

out2, out2_ = self.get_out_data()
np.equal(out1.numpy(), out2.numpy())
```
`np.equal` only produces the element-wise comparison result; you still need to add an assert on it.
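For example (a sketch; the later "Add assertTrue" commit does essentially this):

```python
# np.equal returns an element-wise boolean array; assert on it explicitly.
self.assertTrue(np.equal(out1.numpy(), out2.numpy()).all())
```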
Fixed.
The flash-attention PR has not been merged yet, so the submodule cannot be updated for now.
To verify correctness in this PR, you can first update the submodule to point at your own fork.
done
- Update the flash-attn submodule.
- Add operator-level and model-level performance numbers. The deterministic implementation is bound to be slower than the non-deterministic one, so we need to see roughly how much slower it is (see the timing sketch below).
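One rough way to get the operator-level number (a sketch; the shapes, iteration count, and `flash_attention` import path are illustrative assumptions, not from this PR):

```python
import time
import paddle
from paddle.nn.functional.flash_attention import flash_attention

def bench(deterministic, iters=100):
    paddle.set_flags({'FLAGS_cudnn_deterministic': int(deterministic)})
    shape = [2, 2048, 16, 64]  # [batch, seqlen, num_heads, head_dim]
    q = paddle.randn(shape, dtype='float16')
    k = paddle.randn(shape, dtype='float16')
    v = paddle.randn(shape, dtype='float16')
    for t in (q, k, v):
        t.stop_gradient = False
    # Warm-up run so one-time setup cost is excluded from the timing.
    out, _ = flash_attention(q, k, v)
    out.backward()
    paddle.device.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        out, _ = flash_attention(q, k, v)
        out.backward()  # the num_splits change affects the backward kernel
    paddle.device.cuda.synchronize()
    return (time.time() - start) / iters

t_fast = bench(deterministic=False)
t_det = bench(deterministic=True)
print(f'non-deterministic: {t_fast * 1e3:.2f} ms/iter, '
      f'deterministic: {t_det * 1e3:.2f} ms/iter, '
      f'slowdown: {t_det / t_fast:.2f}x')
```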
```cpp
int num_splits = 0;  // 0 for an internal heuristic, which is optimal
if (FLAGS_cudnn_deterministic) {
  num_splits = 1;  // a single split makes the backward reduction deterministic
}
```
Wrap this in a function.
done
```diff
@@ -306,6 +302,29 @@ def test_all(self):
         np.testing.assert_allclose(
             fetches_result[0], out_, rtol=5e-03, atol=1e-03
         )
+        return out, out_, fetches_result[0]
```
`fetches_result[0]` is the forward output of the static-graph run, right? For the deterministic implementation it is fine to compare only the dynamic-graph runs, but you need to check the forward `out` and the backward `dq`, `dk`, and `dv`, and guarantee that two runs produce exactly identical results.
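In test form, that could look roughly like this (a sketch; the `flash_attention` usage and shapes are illustrative assumptions, not this PR's test code):

```python
import numpy as np
import paddle
from paddle.nn.functional.flash_attention import flash_attention

def run_once(q_np, k_np, v_np):
    q = paddle.to_tensor(q_np, stop_gradient=False)
    k = paddle.to_tensor(k_np, stop_gradient=False)
    v = paddle.to_tensor(v_np, stop_gradient=False)
    out, _ = flash_attention(q, k, v)
    out.backward()
    return out.numpy(), q.grad.numpy(), k.grad.numpy(), v.grad.numpy()

paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
shape = [2, 128, 8, 32]
q_np = np.random.random(shape).astype('float16')
k_np = np.random.random(shape).astype('float16')
v_np = np.random.random(shape).astype('float16')

# Forward out and backward dq/dk/dv must be bit-identical across two runs.
for a, b in zip(run_once(q_np, k_np, v_np), run_once(q_np, k_np, v_np)):
    assert np.equal(a, b).all()
paddle.set_flags({'FLAGS_cudnn_deterministic': 0})
```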
done
```cpp
int get_num_split() {
  // 0 for an internal heuristic, which is optimal
  return FLAGS_cudnn_deterministic ? 1 : 0;
}
```
This function, together with the two `int num_splits = get_num_split();` call sites in the kernels, would be better encapsulated directly inside `FlashAttnBwdParamsV2`.
Will change this in the next PR.
LGTM
LGTM for skipIf
LGTM
```diff
@@ -0,0 +1,208 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
```
Has this unit test been moved out? If so, it needs to be added to the list of unit tests run by GPUPS CI; please add it in the next PR.
…ttnUnpaddedGradKernel (PaddlePaddle#56363)
* add num_splist for flash_attn_bwd and FlashAttnUnpaddedGradKernel
* Add assertTrue
* Update submodule to a specific commit
PR types
Others
PR changes
Others
Description
Pcard-70458
The llama model was run following the README; the comparison between the two CI runs is shown below, with no diff in the loss.
Detailed data 1:
Detailed data 2:
Performance with the deterministic algorithm disabled: