【Hackathon No.59】Improve FP16/BF16 unit tests for the addmm operator #53111
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open-source project!
Force-pushed from 60d7c61 to 11e09f1
@@ -982,6 +982,25 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
  });
}

template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(bool transA UNUSED,
Even though this function is unused, could it still follow the bf16 GEMM implementation at line 729?
Done, updated.
// VCOPY does not support float16 or bfloat16
if (!is_float16_or_bfloat16) {
  mt_blas.VCOPY(
      total_elems, out_grad.data<MPType>(), input_grad->data<MPType>());
Since fp16 and bf16 are not supported here anyway, is there any need to use MPType?
Without MPType there is a compile error: the compiler does not follow the runtime if, it instantiates the code for the type T, so if out_grad.data is used here the compiler reports that the float16 parameter type is unsupported.
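The compile error described above can be reproduced in isolation: a plain runtime `if` still instantiates both branches for the template parameter `T`, so a call that is invalid for `T` fails to compile even in the branch that is never taken, whereas `if constexpr` discards the untaken branch before instantiation. A minimal sketch, where `vcopy` is a hypothetical stand-in for blas VCOPY, not Paddle's actual API:

```cpp
#include <cstddef>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for blas VCOPY, which only accepts
// full-precision floating-point element types.
template <typename T>
void vcopy(const std::vector<T>& src, std::vector<T>& dst) {
  static_assert(std::is_floating_point<T>::value,
                "vcopy supports float/double only");
  dst = src;
}

template <typename T>
std::vector<T> copy_grad(const std::vector<T>& src) {
  std::vector<T> dst(src.size());
  // With a plain `if (...)`, vcopy<T> would still be instantiated for a
  // low-precision type and the static_assert would fire at compile time,
  // even though the branch is never taken at runtime. `if constexpr`
  // discards the untaken branch before instantiation.
  if constexpr (std::is_floating_point<T>::value) {
    vcopy(src, dst);
  } else {
    for (std::size_t i = 0; i < src.size(); ++i) dst[i] = src[i];
  }
  return dst;
}
```

This is why the kernel must route fp16/bf16 away from VCOPY at compile time (or via differently-typed calls such as `mt_blas` with MPType), not just with a runtime check.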
@@ -78,19 +107,45 @@ void AddmmGradKernel(const Context& dev_ctx,
      Array2(input_grad->dims()[0], input_grad->dims()[1]);

  if (row_compress && col_compress) {
    eigen_dinput.device(place) =
        eigen_dout.sum().eval().reshape(eigen_dinput_shape);
    eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
For non-fp16/bf16 types, the two extra casts here could hurt performance; how about adding a branch for that case as well?
Branch added.
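The branching idea the reviewer asked for can be sketched as follows: full-precision types accumulate directly with no casts on the hot path, while low-precision types cast each element up to the MPType, accumulate, and cast back once at the end. A hedged illustration of the pattern, not Paddle's actual Eigen code; the `Half` type is a toy stand-in for `phi::dtype::float16` and does not model its precision loss:

```cpp
#include <type_traits>
#include <vector>

// Toy stand-in for a low-precision element type such as phi::dtype::float16
// (assumption: real fp16 also loses precision, which this toy does not model).
struct Half {
  float v;
  explicit Half(float f) : v(f) {}
  explicit operator float() const { return v; }
};

template <typename T>
T reduce_sum(const std::vector<T>& x) {
  if constexpr (std::is_floating_point<T>::value) {
    // float/double: accumulate directly, no extra casts.
    T acc = T(0);
    for (const T& e : x) acc += e;
    return acc;
  } else {
    // fp16/bf16-style types: accumulate in float (the MPType idea)
    // and cast back to T once at the end.
    float acc = 0.0f;
    for (const T& e : x) acc += static_cast<float>(e);
    return static_cast<T>(acc);
  }
}
```

In the kernel the same split avoids `cast<MPType>()` round-trips for float/double while keeping fp16/bf16 numerically safe.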
@ZzSean The CI fixes are complete.
int ldb,
phi::dtype::bfloat16 beta,
phi::dtype::bfloat16 *C,
int ldc UNUSED) const {
This UNUSED should probably be removed.
Done, removed.
context_.GetComputeCapability(),
80,
phi::errors::InvalidArgument(
    "rocblas fp16 gemm requires GPU compute capability >= 80,"
This should say bf16, right?
Fixed.
LGTM
* Add addmm tests
* Fix code
PR types
Others
PR changes
Others
Description
Improve the FP16/BF16 unit tests for the addmm operator:
* float16 uses cublasHgemm.
* Add a bfloat16 path that calls cublasGemmEx.
* blas VCOPY and SCAL fail to compile for float16, so mt_blas is added to call VCOPY and SCAL with MPType.
* In the backward pass VCOPY and SCAL do not support float16/bfloat16, so a CopyOrScaleFunctor is added to copy or scale the data instead.
* The CI tests require a float16 tolerance of atol=1e-2 for check_output.
* Setting self.check_output(atol=1e-2) for float16 caused PR-CI-Model-benchmark failures, so the test was changed to compare via verify_output.
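The CopyOrScaleFunctor mentioned above can be sketched roughly as below. This is a hedged illustration reconstructed from the PR description, not Paddle's exact implementation; the `T`/`MPType` names follow the kernel's convention, and the per-index call style assumes a driver such as `phi::funcs::ForRange`:

```cpp
#include <cstdint>

// Sketch of a copy-or-scale functor: because blas VCOPY/SCAL reject
// float16/bfloat16, each gradient element is instead copied (scale == 1)
// or scaled individually, with the arithmetic done in a higher-precision
// MPType before casting back to the element type T.
template <typename T, typename MPType = float>
struct CopyOrScaleFunctor {
  CopyOrScaleFunctor(MPType scale, const T* x, T* out)
      : scale_(scale), x_(x), out_(out) {}

  // Invoked once per element index by the range driver.
  void operator()(int64_t i) const {
    out_[i] = static_cast<T>(scale_ * static_cast<MPType>(x_[i]));
  }

  MPType scale_;
  const T* x_;
  T* out_;
};
```

Applied with scale 1.0 it reproduces the VCOPY behavior; with scale beta it reproduces SCAL-style scaling, which is how one functor can replace both unsupported blas calls.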