diff --git a/paddle/phi/kernels/primitive/compute_primitives.h b/paddle/phi/kernels/primitive/compute_primitives.h index 7c94e126ff1e1..b3da41976624b 100644 --- a/paddle/phi/kernels/primitive/compute_primitives.h +++ b/paddle/phi/kernels/primitive/compute_primitives.h @@ -91,14 +91,14 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) { __shared__ T shared[2 * kWarpSize]; int block_dim_x = blockDim.x; if (blockDim.x > kWarpSize) { - int lane, tid, wid, bid, n; // Bit operation can be used when kWarpSize is 32 or 64 now - n = kWarpSize == 32 ? 5 : 6; - block_dim_x = blockDim.x >> n; - lane = threadIdx.x & (kWarpSize - 1); - tid = threadIdx.y * blockDim.x + threadIdx.x; - wid = tid >> n; - bid = threadIdx.y; + constexpr int rshift_val = + (kWarpSize != 32) ? ((kWarpSize == 64) ? 6 : 5) : 5; + block_dim_x = blockDim.x >> rshift_val; + int lane = threadIdx.x & (kWarpSize - 1); + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int wid = tid >> rshift_val; + int bid = threadIdx.y; val = WarpReduce(val, reducer); if (lane == 0) { shared[wid] = val;