
【Hackathon No.59】Improve the FP16/BF16 unit tests for the addmm operator #53111

Merged (2 commits) on May 5, 2023

Conversation

@co63oc (Contributor) commented Apr 20, 2023

PR types

Others

PR changes

Others

Description

Improve the FP16/BF16 unit tests for the addmm operator:
For float16, use cublasHgemm.
Add a bfloat16 path that calls cublasGemmEx.
Calling blas VCOPY/SCAL with float16 fails to compile, so add mt_blas, which calls VCOPY/SCAL with the MPType type.

In the backward pass, VCOPY/SCAL do not support float16/bfloat16, so add a CopyOrScaleFunctor to update the data (a sketch follows below).

Submitting to CI requires setting the float16 precision to check_output atol=1e-2.

Setting self.check_output(atol=1e-2) for float16 caused a failure in the PR-CI-Model-benchmark CI, so the comparison was changed to verify_output.
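A minimal sketch of what such a CopyOrScaleFunctor can look like (illustrative only: the functor name comes from the description above, but the body, the MPType default, and the launcher convention are assumptions, not the PR's exact code):

#include <cstdint>

// Elementwise functor that copies (scale == 1) or scales the gradient,
// computing in a higher-precision MPType so it also works for
// float16/bfloat16, which VCOPY/SCAL reject.
template <typename T, typename MPType = float>
struct CopyOrScaleFunctor {
  CopyOrScaleFunctor(float scale, const T* x, T* out, int64_t numel)
      : scale_(scale), x_(x), out_(out), numel_(numel) {}

  // Invoked once per element by a for-range/elementwise launcher.
  void operator()(int64_t idx) const {
    const MPType v = static_cast<MPType>(x_[idx]);
    out_[idx] = static_cast<T>(static_cast<MPType>(scale_) * v);
  }

  float scale_;
  const T* x_;
  T* out_;
  int64_t numel_;
};

With scale = 1 this reproduces VCOPY; with scale = alpha it reproduces SCAL, while the accumulation happens in MPType.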

@paddle-bot (bot) commented Apr 20, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@co63oc (Contributor, Author) commented Apr 25, 2023

@luotao1 @ZzSean CI has finished.

@@ -982,6 +982,25 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
});
}

template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(bool transA UNUSED,
Contributor:

Even though this function isn't used, could it still follow the bf16 GEMM implementation at line 729?

Contributor Author:

Fixed.
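For reference, a hedged sketch of what a bf16 GEMM through cublasGemmEx can look like (this follows the documented cuBLAS API with fp32 accumulation; the actual Paddle wrapper at line 729 differs in its parameter plumbing):

#include <cublas_v2.h>
#include <cuda_bf16.h>

// bf16 GEMM via cublasGemmEx with float accumulation (CUDA >= 11).
// Error handling elided; A, B, C are device pointers, column-major.
void Bf16Gemm(cublasHandle_t handle, int m, int n, int k,
              const __nv_bfloat16 *A, int lda,
              const __nv_bfloat16 *B, int ldb,
              __nv_bfloat16 *C, int ldc) {
  const float alpha = 1.0f, beta = 0.0f;
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
               &alpha,
               A, CUDA_R_16BF, lda,
               B, CUDA_R_16BF, ldb,
               &beta,
               C, CUDA_R_16BF, ldc,
               CUBLAS_COMPUTE_32F,  // accumulate in fp32
               CUBLAS_GEMM_DEFAULT);
}

CUBLAS_COMPUTE_32F keeps accumulation in fp32, the usual choice for bf16 since bf16 carries only 7 mantissa bits.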

// VCOPY does not support float16 or bfloat16
if (!is_float16_or_bfloat16) {
mt_blas.VCOPY(
total_elems, out_grad.data<MPType>(), input_grad->data<MPType>());
Contributor:

Since fp16 and bf16 aren't supported here anyway, is there any need to use MPType?

Contributor Author:

Without MPType there is a compile error. The compiler does not follow the runtime if; it instantiates the call for type T, so if out_grad.data is used here, the compiler reports that float16 arguments are unsupported.
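A standalone illustration (not the Paddle source) of the failure mode described in this reply, and of why routing the call through a type the overload set accepts fixes it:

#include <type_traits>

struct float16 { unsigned short bits; };  // stand-in for phi::dtype::float16

struct Blas {
  // No float16 overload, mirroring VCOPY/SCAL.
  void VCOPY(int n, const float *x, float *y) {
    for (int i = 0; i < n; ++i) y[i] = x[i];
  }
};

template <typename T>
void CopyGrad(Blas &blas, const T *src, T *dst, int n) {
  // A runtime `if` does NOT help: both branches are instantiated for T,
  // so for T = float16 the call below would fail to compile even though
  // it could never execute:
  //
  //   if (!std::is_same<T, float16>::value) blas.VCOPY(n, src, dst);
  //
  // `if constexpr` (C++17) discards the untaken branch at instantiation
  // time. The PR takes the other route: it passes MPType (float) pointers
  // so the call type-checks for every T.
  if constexpr (!std::is_same<T, float16>::value) {
    blas.VCOPY(n, src, dst);
  }
}

int main() {
  Blas blas;
  float a[2] = {1.f, 2.f}, b[2];
  CopyGrad(blas, a, b, 2);      // T = float: compiles and runs
  float16 c[2] = {}, d[2] = {};
  CopyGrad(blas, c, d, 2);      // T = float16: compiles only because the
                                // untaken branch is discarded
  return 0;
}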

@@ -78,19 +107,45 @@ void AddmmGradKernel(const Context& dev_ctx,
Array2(input_grad->dims()[0], input_grad->dims()[1]);

if (row_compress && col_compress) {
eigen_dinput.device(place) =
eigen_dout.sum().eval().reshape(eigen_dinput_shape);
eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
@ZzSean (Contributor) commented Apr 26, 2023

For types other than fp16 and bf16, the two extra casts here feel like they would hurt performance; how about adding a branch for this as well?

@co63oc (Contributor, Author) commented Apr 26, 2023

Branch added.
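For illustration, a self-contained sketch of such a branch, with Eigen::half standing in for phi::dtype::float16 (the function name and shape are invented, and the real kernel presumably also gates on bfloat16):

#include <type_traits>
#include <unsupported/Eigen/CXX11/Tensor>

// Only low-precision types pay for the cast-to-float-and-back; float and
// double keep the original cast-free reduction path.
template <typename T>
Eigen::Tensor<T, 2> SumAllToDinput(const Eigen::Tensor<T, 2> &dout) {
  using MPType = float;
  Eigen::array<Eigen::Index, 2> dinput_shape{1, 1};
  if constexpr (std::is_same_v<T, Eigen::half>) {
    // Low precision: accumulate in MPType, then cast back to T.
    return dout.template cast<MPType>()
        .sum()
        .template cast<T>()
        .eval()
        .reshape(dinput_shape);
  } else {
    // Full precision: no extra casts.
    return dout.sum().eval().reshape(dinput_shape);
  }
}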

@co63oc (Contributor, Author) commented Apr 27, 2023

@ZzSean The changes are done and CI has finished.

int ldb,
phi::dtype::bfloat16 beta,
phi::dtype::bfloat16 *C,
int ldc UNUSED) const {
Contributor:

This UNUSED should be removed, right?

Contributor Author:

Fixed.

context_.GetComputeCapability(),
80,
phi::errors::InvalidArgument(
"rocblas fp16 gemm requires GPU compute capability >= 80,"
Contributor:

This should say bf16, shouldn't it?

Contributor Author:

Fixed.

@ZzSean (Contributor) left a comment

LGTM

ZzSean pushed a commit to ZzSean/Paddle that referenced this pull request on May 5, 2023.
@co63oc deleted the addmm branch on May 11, 2023.