From 92e199eedecc7b57b8fd87a137aa5e05f9e2816b Mon Sep 17 00:00:00 2001 From: ZzSean <18818272991@163.com> Date: Wed, 16 Mar 2022 03:54:07 +0000 Subject: [PATCH 1/2] Optimize the computation of log_softmax --- paddle/phi/kernels/gpudnn/softmax_gpudnn.h | 30 ++++++++++------------ 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h index 2b2dd5118969cf..e14e13cb9252d5 100644 --- a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h +++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h @@ -121,17 +121,12 @@ struct ReduceMaxFunctor { }; template -struct ExpSubFunctor { - HOSTDEVICE inline ExpSubFunctor() { y = static_cast(0.0f); } - - HOSTDEVICE explicit inline ExpSubFunctor(Tx y) : y((Tx)(y)) {} +struct ExpFunctor { + HOSTDEVICE explicit inline ExpFunctor() {} HOSTDEVICE inline Ty operator()(const Tx& x) const { - return static_cast(std::exp(x - y)); + return static_cast(std::exp(x)); } - - private: - Tx y; }; template @@ -293,9 +288,10 @@ __global__ void WarpSoftmaxForward(T* softmax, } // data src + AccT srcdata_raw[kBatchSize][kLoopsV][kVSize]; AccT srcdata[kBatchSize][kLoopsV][kVSize]; T src_tmp[kBatchSize][kLoopsV][kVSize]; - kps::Init(&srcdata[0][0][0], kLowInf); + kps::Init(&srcdata_raw[0][0][0], kLowInf); kps::Init(&src_tmp[0][0][0], -std::numeric_limits::infinity()); // data dst @@ -317,7 +313,7 @@ __global__ void WarpSoftmaxForward(T* softmax, kps::ReadData( ®_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1); kps::ElementwiseUnary>( - &srcdata[i][0][0], &src_tmp[i][0][0], DataTransFunctor()); + &srcdata_raw[i][0][0], &src_tmp[i][0][0], DataTransFunctor()); } // compute max @@ -327,14 +323,18 @@ __global__ void WarpSoftmaxForward(T* softmax, 1, ReduceMaxFunctor, kMode::kLocalMode>( - &max[0], &srcdata[0][0][0], ReduceMaxFunctor(), true); + &max[0], &srcdata_raw[0][0][0], ReduceMaxFunctor(), true); WarpReduceMax(max); // compute sum #pragma unroll for (int i = 0; i < kBatchSize; ++i) { - kps::ElementwiseUnary>( - &srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor(max[i])); + kps::ElementwiseUnary>( + &srcdata_raw[i][0][0], + &srcdata_raw[i][0][0], + UnarySubFunctor(max[i])); + kps::ElementwiseUnary>( + &srcdata[i][0][0], &srcdata_raw[i][0][0], ExpFunctor()); } kps::Reduce(&softmax[(first_batch + i) * stride]); VecT* reg_v = reinterpret_cast(&out_tmp[i][0][0]); if (LogMode) { - kps::ElementwiseUnary>( - &srcdata[i][0][0], &srcdata[i][0][0], UnaryLogFunctor()); kps::ElementwiseUnary>( &out_tmp[i][0][0], - &srcdata[i][0][0], + &srcdata_raw[i][0][0], UnarySubFunctor(std::log(sum[i]))); } else { kps::ElementwiseUnary>( From bde814b46337790e9ffaa68615cae43c63d04f47 Mon Sep 17 00:00:00 2001 From: ZzSean <18818272991@163.com> Date: Wed, 16 Mar 2022 07:02:47 +0000 Subject: [PATCH 2/2] modify the var name --- paddle/phi/kernels/gpudnn/softmax_gpudnn.h | 33 +++++++++++----------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h index e14e13cb9252d5..77159bfc876da6 100644 --- a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h +++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h @@ -122,8 +122,6 @@ struct ReduceMaxFunctor { template struct ExpFunctor { - HOSTDEVICE explicit inline ExpFunctor() {} - HOSTDEVICE inline Ty operator()(const Tx& x) const { return static_cast(std::exp(x)); } @@ -288,11 +286,14 @@ __global__ void WarpSoftmaxForward(T* softmax, } // data src - AccT srcdata_raw[kBatchSize][kLoopsV][kVSize]; - AccT srcdata[kBatchSize][kLoopsV][kVSize]; - T src_tmp[kBatchSize][kLoopsV][kVSize]; - kps::Init(&srcdata_raw[0][0][0], kLowInf); - kps::Init(&src_tmp[0][0][0], -std::numeric_limits::infinity()); + // src_data: the raw data form global memory + // sub_data: store the data obtained by (src_data - max), used by log_softmax + // exp_data: store the data obtained by (exp(sub_data)), used by softmax + T src_data[kBatchSize][kLoopsV][kVSize]; + AccT sub_data[kBatchSize][kLoopsV][kVSize]; + AccT exp_data[kBatchSize][kLoopsV][kVSize]; + kps::Init(&sub_data[0][0][0], kLowInf); + kps::Init(&src_data[0][0][0], -std::numeric_limits::infinity()); // data dst T out_tmp[kBatchSize][kLoopsV][kVSize]; @@ -309,11 +310,11 @@ __global__ void WarpSoftmaxForward(T* softmax, for (int i = 0; i < kBatchSize; ++i) { const VecT* src_v = reinterpret_cast(&src[(first_batch + i) * stride]); - VecT* reg_v = reinterpret_cast(&src_tmp[i][0][0]); + VecT* reg_v = reinterpret_cast(&src_data[i][0][0]); kps::ReadData( ®_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1); kps::ElementwiseUnary>( - &srcdata_raw[i][0][0], &src_tmp[i][0][0], DataTransFunctor()); + &sub_data[i][0][0], &src_data[i][0][0], DataTransFunctor()); } // compute max @@ -323,18 +324,16 @@ __global__ void WarpSoftmaxForward(T* softmax, 1, ReduceMaxFunctor, kMode::kLocalMode>( - &max[0], &srcdata_raw[0][0][0], ReduceMaxFunctor(), true); + &max[0], &sub_data[0][0][0], ReduceMaxFunctor(), true); WarpReduceMax(max); // compute sum #pragma unroll for (int i = 0; i < kBatchSize; ++i) { kps::ElementwiseUnary>( - &srcdata_raw[i][0][0], - &srcdata_raw[i][0][0], - UnarySubFunctor(max[i])); + &sub_data[i][0][0], &sub_data[i][0][0], UnarySubFunctor(max[i])); kps::ElementwiseUnary>( - &srcdata[i][0][0], &srcdata_raw[i][0][0], ExpFunctor()); + &exp_data[i][0][0], &sub_data[i][0][0], ExpFunctor()); } kps::Reduce, kMode::kLocalMode>( - &sum[0], &srcdata[0][0][0], kps::AddFunctor(), true); + &sum[0], &exp_data[0][0][0], kps::AddFunctor(), true); WarpReduceSum(sum); // write data to global memory @@ -354,11 +353,11 @@ __global__ void WarpSoftmaxForward(T* softmax, if (LogMode) { kps::ElementwiseUnary>( &out_tmp[i][0][0], - &srcdata_raw[i][0][0], + &sub_data[i][0][0], UnarySubFunctor(std::log(sum[i]))); } else { kps::ElementwiseUnary>( - &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor(sum[i])); + &out_tmp[i][0][0], &exp_data[i][0][0], UnaryDivFunctor(sum[i])); } kps::WriteData( &softmax_v[0], ®_v[0], idx_max_v[i], 0, kWarpSize, 1);