
Wint8 gemm and gemv opt #59291

Merged
merged 28 commits into PaddlePaddle:develop from develop_wint8_gemm_opt on Dec 6, 2023

Conversation

wwbitejotunn
Contributor

@wwbitejotunn wwbitejotunn commented Nov 23, 2023

PR types

Performance optimization

PR changes

OPs

Description

Pcard-71501

This PR optimizes the speed of the weight-only gemm and gemv GPU kernels.

To speed up the weight-only gemm and gemv, the following techniques were adopted:

  • stream-k gemm instead of serial split-k
  • multi-warp batched gemv
  • gemm/gemv dispatch based on the problem size m (see the sketch below)
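
A minimal C++ sketch of the m-based dispatch, for illustration only: the launcher names (`LaunchMultiWarpBatchGemv`, `LaunchStreamKGemm`) and the threshold `kGemvMaxM` are hypothetical and not taken from this PR.

```cpp
// Hypothetical sketch of the gemm/gemv dispatch based on problem size m.
// Kernel launcher names and the threshold are illustrative, not Paddle's API.
#include <cuda_runtime.h>
#include <cstdint>

template <typename T>
void LaunchMultiWarpBatchGemv(const T*, const int8_t*, const T*, T*,
                              int, int, int, cudaStream_t);  // assumed launcher
template <typename T>
void LaunchStreamKGemm(const T*, const int8_t*, const T*, T*,
                       int, int, int, cudaStream_t);         // assumed launcher

template <typename T>
void WeightOnlyLinearDispatch(const T* x,        // activations, [m, k]
                              const int8_t* w,   // int8 weights, [k, n]
                              const T* scale,    // per-channel scales (fp16/bf16)
                              T* out,            // output, [m, n]
                              int m, int n, int k,
                              cudaStream_t stream) {
  constexpr int kGemvMaxM = 4;  // assumed cut-off between the gemv and gemm paths
  if (m <= kGemvMaxM) {
    // Small m (e.g. decoding): memory-bound, so use the multi-warp batched gemv,
    // where several warps cooperate on a row batch and reduce partial sums.
    LaunchMultiWarpBatchGemv<T>(x, w, scale, out, m, n, k, stream);
  } else {
    // Larger m: use the stream-k gemm, which assigns K-dimension work to CTAs
    // dynamically instead of running serial split-k passes.
    LaunchStreamKGemm<T>(x, w, scale, out, m, n, k, stream);
  }
}
```

The intent is that small-m (decode-time) problems favor the bandwidth-oriented batched gemv, while larger-m problems keep the SMs busy through stream-k work distribution rather than serial split-k.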

For the gemm problem sizes in llama13b, we obtain a 1.34x gemm kernel speedup on an A100 80G.

(image: gemm kernel benchmark results)


paddle-bot bot commented Nov 23, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@wwbitejotunn wwbitejotunn force-pushed the develop_wint8_gemm_opt branch from 74589d4 to 1402d7c on November 24, 2023 07:34
@vivienfanghuagood
Contributor

In my opinion, the improvement details should probably be described more thoroughly in the PR~

Contributor

@MARD1NO MARD1NO left a comment

Genius!

zhoutianzi666
zhoutianzi666 previously approved these changes Dec 1, 2023
@@ -5163,7 +5163,7 @@ void WeightQuantizeInferMeta(const MetaTensor& x,
out->set_dtype(DataType::INT8);

scale->set_dims(phi::make_ddim(dim_scale));
scale->set_dtype(DataType::FLOAT32);
Contributor

So scale will always be fp16 from now on?

Contributor Author

@wwbitejotunn wwbitejotunn Dec 1, 2023

Yes, the scale here has been changed to bf16/fp16, which gives better performance, and the accuracy should still hold. The scale computation in the weight_quant op and the scale weight initialization logic in the paddlenlp code have been updated accordingly.
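
For illustration, a minimal sketch (not copied from the PR) of what the InferMeta side of such a change could look like, assuming the scale dtype now follows the input tensor's dtype instead of being hard-coded to FLOAT32:

```cpp
// Hypothetical sketch: the scale dtype follows x (fp16/bf16) rather than
// being fixed to DataType::FLOAT32 as before.
scale->set_dims(phi::make_ddim(dim_scale));
scale->set_dtype(x.dtype());
```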

@wwbitejotunn
Contributor Author

> In my opinion, the improvement details should probably be described more thoroughly in the PR~

The improvement details and test data have been added~

heavengate
heavengate previously approved these changes Dec 4, 2023
raindrops2sea
raindrops2sea previously approved these changes Dec 4, 2023
Xreki
Xreki previously approved these changes Dec 4, 2023
Contributor

@Xreki Xreki left a comment

LGTM for const_cast

Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

LGTM

@carryyu carryyu merged commit a8456dc into PaddlePaddle:develop Dec 6, 2023