Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix reduce_any kernel data race on sharedMem #47233

Merged
merged 6 commits into from
Oct 27, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions paddle/phi/kernels/primitive/compute_primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,13 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
__shared__ T shared[2 * kWarpSize];
int block_dim_x = blockDim.x;
if (blockDim.x > kWarpSize) {
block_dim_x = blockDim.x / kWarpSize;
int lane = threadIdx.x % kWarpSize;
// Bit operation can be used when kWarpSize is 32 or 64 now
constexpr int rshift_val =
(kWarpSize != 32) ? ((kWarpSize == 64) ? 6 : 5) : 5;
block_dim_x = blockDim.x >> rshift_val;
int lane = threadIdx.x & (kWarpSize - 1);
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int wid = tid / kWarpSize;
int wid = tid >> rshift_val;
int bid = threadIdx.y;
val = WarpReduce(val, reducer);
if (lane == 0) {
Expand All @@ -110,6 +113,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
val = reducer(val, temp);
}
__syncthreads();
if (threadIdx.x == 0) {
shared[threadIdx.y] = val;
}
Expand Down Expand Up @@ -385,8 +389,8 @@ __device__ __forceinline__ void CycleBinary(OutT* out,
/**
* @brief The Reduce provides collective methods for computing a parallel
* reduction of items partitioned across a CUDA block and intra thread. When
* ReduceMode == kLocalMode, thread reduce along nx. When ReduceMode ==
* kGlobalMode, use shared memory to reduce between threads.
* ReduceMode == kLocalMode, use shared memory to reduce between threads.When
* ReduceMode == kGlobalMode, thread reduce along nx.
*
* @template paraments
* T: The type of data.
Expand Down