Skip to content

Commit

Permalink
modified according to PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbopd committed Oct 25, 2022
1 parent 1873948 commit cd14eb8
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions paddle/phi/kernels/primitive/compute_primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
__shared__ T shared[2 * kWarpSize];
int block_dim_x = blockDim.x;
if (blockDim.x > kWarpSize) {
int lane, tid, wid, bid, n;
// Bit operation can be used when kWarpSize is 32 or 64 now
n = kWarpSize == 32 ? 5 : 6;
block_dim_x = blockDim.x >> n;
lane = threadIdx.x & (kWarpSize - 1);
tid = threadIdx.y * blockDim.x + threadIdx.x;
wid = tid >> n;
bid = threadIdx.y;
constexpr int rshift_val =
(kWarpSize != 32) ? ((kWarpSize == 64) ? 6 : 5) : 5;
block_dim_x = blockDim.x >> rshift_val;
int lane = threadIdx.x & (kWarpSize - 1);
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int wid = tid >> rshift_val;
int bid = threadIdx.y;
val = WarpReduce(val, reducer);
if (lane == 0) {
shared[wid] = val;
Expand Down

0 comments on commit cd14eb8

Please sign in to comment.