use elementwise to optimize gelu backward implementation on GPU (#38263)
* optimize gelu backward

* optimize gelu backward

* optimize code

* Number to expression

* Replacement number
Zjq9409 authored Dec 22, 2021
1 parent d48d712 commit 858e435
Showing 2 changed files with 79 additions and 12 deletions.
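For reference, the two gradient functors added to gelu_op.cu implement the analytic derivatives of the two GELU variants. A sketch of the math (not part of the diff), with GELU_CONSTANT = 0.044715 and kAlpha = M_2_SQRTPI * M_SQRT1_2 = sqrt(2/pi):

\mathrm{GELU}(x) = \frac{x}{2}\Big(1 + \operatorname{erf}\big(\tfrac{x}{\sqrt{2}}\big)\Big),
\qquad
\frac{d}{dx}\,\mathrm{GELU}(x) = \frac{1}{2}\Big(1 + \operatorname{erf}\big(\tfrac{x}{\sqrt{2}}\big)\Big) + \frac{x}{\sqrt{2\pi}}\, e^{-x^{2}/2}

With u = \sqrt{2/\pi}\,\big(x + 0.044715\,x^{3}\big), the tanh approximation and its derivative are

\mathrm{GELU}_{\tanh}(x) = \frac{x}{2}\,\big(1 + \tanh u\big),
\qquad
\frac{d}{dx}\,\mathrm{GELU}_{\tanh}(x) = \frac{1}{2}\Big(1 + \tanh u + \big(1 - \tanh^{2} u\big)\,\sqrt{2/\pi}\,\big(x + 3 \cdot 0.044715\,x^{3}\big)\Big)
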
76 changes: 70 additions & 6 deletions paddle/fluid/operators/gelu_op.cu
@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/gelu_op.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
@@ -27,9 +26,11 @@ struct GeluWithApproximateFunctor {
// this function is tanh approximation of gelu
MPType x = static_cast<MPType>(arg_x);
MPType one = static_cast<MPType>(1);
MPType out = x * static_cast<MPType>(0.5) *
(one + tanh(static_cast<MPType>(0.79788456) * x *
(one + static_cast<MPType>(0.044715) * x * x)));
MPType half = static_cast<MPType>(0.5);
MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
auto tanh_out =
tanh(kAlpha * x * (one + static_cast<MPType>(GELU_CONSTANT) * x * x));
MPType out = x * half * (one + tanh_out);
return static_cast<T>(out);
}
};
@@ -40,9 +41,10 @@ struct GeluWithoutApproximateFunctor {
inline HOSTDEVICE T operator()(T arg_x) {
// actual gelu with approximation = false
MPType x = static_cast<MPType>(arg_x);
MPType one = static_cast<MPType>(1);
MPType half = static_cast<MPType>(0.5);
MPType erf_out = erf(x * static_cast<MPType>(M_SQRT1_2));
MPType out =
x * static_cast<MPType>(0.5) * (static_cast<MPType>(1) + erf_out);
MPType out = x * half * (one + erf_out);
return static_cast<T>(out);
}
};
@@ -71,6 +73,68 @@ class GeluKernel<platform::CUDADeviceContext, T>
}
};

template <typename T>
struct GeluWithApproximateGradFunctor {
using MPType = typename details::MPTypeTrait<T>::Type;
inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
MPType x = static_cast<MPType>(arg_x);
MPType dout = static_cast<MPType>(arg_dout);
MPType one = static_cast<MPType>(1);
MPType half = static_cast<MPType>(0.5);
MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
MPType kBeta =
kAlpha * static_cast<MPType>(GELU_CONSTANT) * static_cast<MPType>(3);
auto cube_x = x * x * x;
auto tanh_out =
tanh(kAlpha * ((static_cast<MPType>(GELU_CONSTANT) * cube_x) + x));
auto ans =
half * (one + tanh_out +
(one - tanh_out * tanh_out) * (x * kAlpha + kBeta * cube_x));
return static_cast<T>(ans * dout);
}
};

template <typename T>
struct GeluWithoutApproximateGradFunctor {
using MPType = typename details::MPTypeTrait<T>::Type;
inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
MPType x = static_cast<MPType>(arg_x);
MPType dout = static_cast<MPType>(arg_dout);
MPType one = static_cast<MPType>(1);
MPType half = static_cast<MPType>(0.5);
MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
auto ans = half * (one + erf(x * static_cast<MPType>(M_SQRT1_2))) +
half * kAlpha * x * exp(-half * x * x);
return static_cast<T>(ans * dout);
}
};
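
A quick way to sanity-check the two closed-form gradients above is to compare them against a central finite difference on the host. The snippet below is an illustrative sketch only, not part of the commit; it mirrors the functors in double precision and, like the .cu file, assumes the M_2_SQRTPI and M_SQRT1_2 constants from <cmath>:

#include <cmath>
#include <cstdio>

namespace {
constexpr double kGeluConstant = 0.044715;     // GELU_CONSTANT in gelu_op.h
const double kAlpha = M_2_SQRTPI * M_SQRT1_2;  // sqrt(2 / pi)

double GeluTanh(double x) {  // forward, tanh approximation
  return 0.5 * x * (1.0 + std::tanh(kAlpha * x * (1.0 + kGeluConstant * x * x)));
}
double GeluTanhGrad(double x) {  // mirrors GeluWithApproximateGradFunctor
  const double kBeta = kAlpha * kGeluConstant * 3.0;
  const double t = std::tanh(kAlpha * (kGeluConstant * x * x * x + x));
  return 0.5 * (1.0 + t + (1.0 - t * t) * (kAlpha * x + kBeta * x * x * x));
}
double GeluErf(double x) {  // forward, exact erf form
  return 0.5 * x * (1.0 + std::erf(x * M_SQRT1_2));
}
double GeluErfGrad(double x) {  // mirrors GeluWithoutApproximateGradFunctor
  return 0.5 * (1.0 + std::erf(x * M_SQRT1_2)) +
         0.5 * kAlpha * x * std::exp(-0.5 * x * x);
}
}  // namespace

int main() {
  const double h = 1e-6;
  const double xs[] = {-3.0, -0.7, 0.0, 0.5, 2.5};
  for (double x : xs) {
    // Central finite differences of the forward functions.
    const double fd_tanh = (GeluTanh(x + h) - GeluTanh(x - h)) / (2.0 * h);
    const double fd_erf = (GeluErf(x + h) - GeluErf(x - h)) / (2.0 * h);
    std::printf("x=%5.2f  approx: %.8f vs %.8f | exact: %.8f vs %.8f\n",
                x, GeluTanhGrad(x), fd_tanh, GeluErfGrad(x), fd_erf);
  }
  return 0;
}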

template <typename T>
class GeluGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* dout =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto approximate = context.Attr<bool>("approximate");
dx->mutable_data<T>(dout->place());

std::vector<const framework::Tensor*> ins = {x, dout};
std::vector<framework::Tensor*> outs = {dx};
const auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
if (approximate) {
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor<T>());
} else {
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, 0, GeluWithoutApproximateGradFunctor<T>());
}
}
};
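
LaunchElementwiseCudaKernel is Paddle's internal launcher for pointwise kernels; here it applies the binary functor to each (x, dout) pair and writes dx. As a rough mental model only (the real launcher in the elementwise framework is more involved), a hypothetical grid-stride kernel doing the same thing could look like:

template <typename T, typename Functor>
__global__ void BinaryElementwiseKernel(const T* x, const T* dout, T* dx,
                                        int64_t n, Functor functor) {
  // Grid-stride loop: each thread processes several elements.
  int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (; i < n; i += stride) {
    dx[i] = functor(x[i], dout[i]);
  }
}

// Hypothetical launch for the approximate-GELU backward pass:
//   const int threads = 256;
//   const int blocks = static_cast<int>((n + threads - 1) / threads);
//   BinaryElementwiseKernel<<<blocks, threads, 0, stream>>>(
//       x_data, dout_data, dx_data, n, GeluWithApproximateGradFunctor<T>());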

} // namespace operators
} // namespace paddle

15 changes: 9 additions & 6 deletions paddle/fluid/operators/gelu_op.h
@@ -30,6 +30,8 @@ limitations under the License. */
namespace paddle {
namespace operators {

#define GELU_CONSTANT 0.044715

template <typename T>
struct GeluFunctor {
template <typename Device, typename X, typename Out>
@@ -41,14 +43,14 @@ struct GeluFunctor {
auto casted_x = x.template cast<float>();
auto temp =
(static_cast<float>(M_2_SQRTPI * M_SQRT1_2) *
(casted_x + static_cast<float>(0.044715) * casted_x.cube()))
(casted_x + static_cast<float>(GELU_CONSTANT) * casted_x.cube()))
.tanh();
out.device(d) = (casted_x * static_cast<float>(0.5) *
(static_cast<float>(1) + temp))
.template cast<T>();
} else {
auto temp = (static_cast<T>(M_2_SQRTPI * M_SQRT1_2) *
(x + static_cast<T>(0.044715) * x.cube()))
(x + static_cast<T>(GELU_CONSTANT) * x.cube()))
.tanh();
out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
}
@@ -101,10 +103,10 @@ struct GeluGradFunctor {

const float kAlpha = static_cast<float>(M_2_SQRTPI * M_SQRT1_2);
const float kBeta =
kAlpha * static_cast<float>(0.044715) * static_cast<float>(3);
kAlpha * static_cast<float>(GELU_CONSTANT) * static_cast<float>(3);
const auto y =
(kAlpha *
((static_cast<float>(0.044715) * casted_x.cube()) + casted_x))
((static_cast<float>(GELU_CONSTANT) * casted_x.cube()) + casted_x))
.tanh();
dx.device(d) = (static_cast<float>(0.5) * casted_dout *
(static_cast<float>(1) + y +
@@ -113,9 +115,10 @@
.template cast<T>();
} else {
const T kAlpha = static_cast<T>(M_2_SQRTPI * M_SQRT1_2);
const T kBeta = kAlpha * static_cast<T>(0.044715) * static_cast<T>(3);
const T kBeta =
kAlpha * static_cast<T>(GELU_CONSTANT) * static_cast<T>(3);
const auto y =
(kAlpha * ((static_cast<T>(0.044715) * x.cube()) + x)).tanh();
(kAlpha * ((static_cast<T>(GELU_CONSTANT) * x.cube()) + x)).tanh();
dx.device(d) = static_cast<T>(0.5) * dout *
(static_cast<T>(1) + y +
(x - x * y.square()) * (kAlpha + kBeta * x.square()));
