Integrate rmsnorm kernel #54998
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open source project!
This operator cannot fuse in the bias and residual coming from a preceding GEMM, so I'm wondering whether it would be possible to make the load/store configurable.
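For illustration only, here is a minimal sketch of what configurable load/store policies could look like: the kernel would be templated on a load functor, and a bias/residual variant would fold the GEMM bias and the residual into the read. None of this is code from the PR; the names (DirectLoad, BiasResidualLoad, RmsNormKernel) are hypothetical.

// Hypothetical sketch of configurable load policies; not part of this PR.
template <typename T>
struct DirectLoad {
  const T* src;
  __device__ T Load(int row, int col, int cols) const {
    return src[row * cols + col];
  }
};

template <typename T>
struct BiasResidualLoad {
  const T* src;
  const T* bias;      // per-column bias from the preceding GEMM
  const T* residual;  // residual added before normalization
  __device__ T Load(int row, int col, int cols) const {
    return src[row * cols + col] + bias[col] + residual[row * cols + col];
  }
};

// The RMSNorm kernel could then be parameterized on the policy, e.g.
// template <typename Load, typename Store, typename ComputeType>
// __global__ void RmsNormKernel(Load load, Store store, int rows, int cols, float epsilon);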
LGTM
#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
The selected_rows.h header does not appear to be used and could be removed.
OK, the related review comments will be addressed together in the next refinement PR.
template <typename T, typename Context>
void RmsNormWrapper(const Context& ctx,
                    const T* x,
                    const T* weight,
                    const T* bias,
                    const float epsilon,
                    const int rows,
                    const int cols,
                    T* output);

template <typename T, typename Context>
void ResidualAddRmsNormWrapper(const Context& ctx,
                               const T* x,
                               const T* residual,
                               const T* bias,
                               const T* norm_weight,
                               const T* norm_bias,
                               const float epsilon,
                               const int rows,
                               const int cols,
                               T* residual_output,
                               T* output);

template <typename T, typename Context>
void RmsNormInt8OutWrapper(const Context& ctx,
                           const T* x,
                           const T* weight,
                           const T* bias,
                           const float epsilon,
                           const int rows,
                           const int cols,
                           const float in_scale,
                           const int quant_round_type,
                           const float quant_max_bound,
                           const float quant_min_bound,
                           int8_t* output);

template <typename T, typename Context>
void ResidualAddRmsNormInt8OutWrapper(const Context& ctx,
                                      const T* x,
                                      const T* residual,
                                      // ... (declaration truncated at the reviewed line)
These wrapper function declarations in the header don't seem to serve any purpose; can they be removed?
These wrappers will be used later in a NormHelper (if they turn out to be removable, I will remove them in the next PR).
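For context, a rough sketch of how such a NormHelper might call RmsNormWrapper with the signature declared above. Only the wrapper signature comes from the header; the caller, its name, the tensor layout ([rows, cols] row-major), and the assumption that the wrapper is visible in the calling namespace are all illustrative.

#include "paddle/phi/core/dense_tensor.h"

// Hypothetical caller, not code from this PR.
template <typename T, typename Context>
void NormHelperForward(const Context& ctx,
                       const phi::DenseTensor& x,       // [rows, cols]
                       const phi::DenseTensor& weight,  // [cols]
                       const phi::DenseTensor& bias,    // [cols]
                       float epsilon,
                       phi::DenseTensor* out) {
  const int rows = static_cast<int>(x.dims()[0]);
  const int cols = static_cast<int>(x.dims()[1]);
  RmsNormWrapper<T, Context>(ctx,
                             x.data<T>(),
                             weight.data<T>(),
                             bias.data<T>(),
                             epsilon,
                             rows,
                             cols,
                             ctx.template Alloc<T>(out));
}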
* add rmsnorm kernel
* add static graph test
* fix round type
* use alignas to avoid msvc compile error
* remove redundant header file to avoid rocm compile error
* fix rocm compile not found cub
* Add document
PR types
New features
PR changes
OPs
Description
Integrate RMSNorm CUDA Kernel
Supports residual load and int8 output, and switches to a single-pass implementation for better inference speed.
Pcard-72603
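For reference, RMSNorm normalizes each row by the root mean square of its elements rather than by mean and variance, so a single statistic (the sum of squares) suffices per row, which is what makes a single-pass kernel attractive. Below is a plain host-side sketch of the math, not the fused CUDA kernel in this PR; the optional per-element bias is omitted.

#include <cmath>
#include <vector>

// Reference of y_i = x_i / sqrt(mean(x^2) + epsilon) * weight_i for one row.
// The PR's CUDA kernel fuses the reduction and the scaling; this is only for clarity.
std::vector<float> RmsNormRow(const std::vector<float>& x,
                              const std::vector<float>& weight,
                              float epsilon) {
  float sum_sq = 0.f;
  for (float v : x) sum_sq += v * v;
  const float inv_rms = 1.f / std::sqrt(sum_sq / x.size() + epsilon);
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = x[i] * inv_rms * weight[i];
  }
  return y;
}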