use elementwise to optimize gelu backward implementation on GPU #38263
Conversation
Thanks for your contribution!
Please make the PR title a bit more detailed, matching the previous PR.
Done.
paddle/fluid/operators/gelu_op.cu
Outdated
    tanh(kAlpha * x * (one + static_cast<MPType>(0.044715) * x * x));
auto ans =
    half * x * ((one - tanh_out * tanh_out) *
                (kAlpha + static_cast<MPType>(0.1070322243) * x * x)) +
Avoid magic numbers that are not part of the formula; replace them all with expressions.
Done.
paddle/fluid/operators/gelu_op.cu
Outdated
auto tanh_out =
    tanh(kAlpha * x * (one + static_cast<MPType>(0.044715) * x * x));
auto ans =
    half * x * ((one - tanh_out * tanh_out) *
Please simplify the formula further.
Done.
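For reference, the simplified tanh-approximation gradient that this thread converges on can be sketched as standalone host C++ (names such as gelu_grad are assumed for illustration; this is not the Paddle kernel itself):

```cpp
#include <cassert>
#include <cmath>

// Standalone sketch of the simplified tanh-approximation GELU gradient.
// Folding the derivative of the cubic term into kBeta = kAlpha * 0.044715 * 3
// removes the magic number 0.1070322243 from the expression.
double gelu_grad(double x, double dout) {
  const double kAlpha = std::sqrt(2.0 / M_PI);
  const double kBeta = kAlpha * 0.044715 * 3.0;
  const double tanh_out = std::tanh(kAlpha * x * (1.0 + 0.044715 * x * x));
  const double temp = (1.0 - tanh_out * tanh_out) * (kAlpha + kBeta * x * x);
  return dout * (0.5 * x * temp + 0.5 * (1.0 + tanh_out));
}
```

A quick sanity check is to compare this against a central finite difference of the forward formula gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).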
paddle/fluid/operators/gelu_op.cu
Outdated
@@ -12,9 +12,76 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/gelu_op.h"
#include "paddle/fluid/platform/float16.h"
This header is already included by paddle/fluid/operators/amp/fp16_type_traits.h, so it can be removed.
Done.
paddle/fluid/operators/gelu_op.cu
Outdated
    (one + tanh(static_cast<MPType>(0.79788456) * x *
                (one + static_cast<MPType>(0.044715) * x * x)));
MPType half = static_cast<MPType>(0.5);
MPType decimal = static_cast<MPType>(0.044715);
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
This is the GELU formula, in which "0.044715" is a fixed constant.
0.044715 can be expressed as a macro constant, for example:
#define GELU_CONSTANT 0.044715
Put the macro in a shared file, such as gelu_op.h, and update all the code that uses 0.044715 accordingly.
Done.
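The reviewer's macro suggestion can be sketched as follows (GELU_CONSTANT matches the suggestion; gelu_forward and its placement are assumed for illustration, not Paddle's actual code):

```cpp
#include <cassert>
#include <cmath>

// Define the constant once in a shared header (e.g. gelu_op.h) and
// reference it everywhere, instead of repeating the literal 0.044715.
#define GELU_CONSTANT 0.044715

// Forward GELU using the shared constant; MPType-style casts kept to
// mirror the kernel's mixed-precision pattern.
template <typename T>
T gelu_forward(T x) {
  const T kAlpha = static_cast<T>(std::sqrt(2.0 / M_PI));
  return static_cast<T>(0.5) * x *
         (static_cast<T>(1) +
          std::tanh(kAlpha * x *
                    (static_cast<T>(1) + static_cast<T>(GELU_CONSTANT) * x * x)));
}
```

With a single definition, changing the precision or value of the constant later touches one line instead of every kernel that repeats the literal.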
paddle/fluid/operators/gelu_op.cu
Outdated
MPType dout = static_cast<MPType>(arg_dout);
MPType one = static_cast<MPType>(1);
MPType half = static_cast<MPType>(0.5);
MPType decimal = static_cast<MPType>(0.044715);
The magic numbers still remain in this part.
Done.
paddle/fluid/operators/gelu_op.cu
Outdated
MPType kBeta = kAlpha * decimal * static_cast<MPType>(3);
auto tanh_out = tanh(kAlpha * x * (one + decimal * x * x));
auto temp = (one - tanh_out * tanh_out) * (kAlpha + kBeta * x * x);
auto ans = half * x * temp + half * (one + tanh_out);
This computation contains values that are computed repeatedly, such as x^3; factor them out to reduce the amount of computation.
Done.
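The common-subexpression point can be illustrated with a small host-side sketch (gelu_grad_hoisted is an assumed name; the real kernel operates on MPType, not double): x * x is computed once and reused in both the tanh argument and the derivative term.

```cpp
#include <cassert>
#include <cmath>

// Hoist the repeated subexpression x * x into x_sq so it is computed
// once instead of twice (and x^3 never materializes separately at all:
// the formula is refactored around x * x_sq-style reuse).
double gelu_grad_hoisted(double x, double dout) {
  const double kAlpha = std::sqrt(2.0 / M_PI);
  const double kBeta = kAlpha * 0.044715 * 3.0;
  const double x_sq = x * x;  // reused below in both terms
  const double tanh_out = std::tanh(kAlpha * x * (1.0 + 0.044715 * x_sq));
  const double temp = (1.0 - tanh_out * tanh_out) * (kAlpha + kBeta * x_sq);
  return dout * (0.5 * x * temp + 0.5 * (1.0 + tanh_out));
}
```

On a GPU this mostly helps the compiler keep the value in a register and shortens the dependent multiply chain; the numerical result is unchanged.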
LGTM
…lePaddle#38263)
* optimize gelu backward
* optimize gelu backward
* optimize code
* Number to expression
* Replacement number
PR types
Performance optimization
PR changes
OPs
Describe
Use elementwise to optimize the GPU backward computation of the gelu operator. Performance data after optimizing forward + backward computation is as follows: