Fix the race condition in cumsum operator (#42205)
* Fix the race condition in cumsum operator

* Optimize cumsum operator
leo0519 authored and wawltor committed May 5, 2022
1 parent a391762 commit ff455c7
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions paddle/phi/kernels/gpu/cumsum_kernel.cu
@@ -39,14 +39,12 @@ __device__ void BlockReverse(
   int tx = threadIdx.x;
 
   int offset = tx;
-  int in_index = src_base + offset;
-  if (offset >= valid_item) {
-    sh_mem[offset] = 0;
-  } else {
-    int sh_mem_index = BLOCK_SIZE - offset - 1;
-    T data = idata[in_index];
-    sh_mem[sh_mem_index] = data;
+  T src_data = 0;
+  int src_offset = BLOCK_SIZE - offset - 1;
+  if (src_offset < valid_item) {
+    src_data = idata[src_base + src_offset];
   }
+  sh_mem[offset] = src_data;
 
   __syncthreads();
   int out_index = dst_base - offset;
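
One way to read the race fixed by this hunk: in the removed code, a thread with offset < valid_item writes slot BLOCK_SIZE - offset - 1, while a thread with offset >= valid_item zeroes its own slot, so two threads can target the same shared-memory slot and other slots are never written. The added code has every thread write only sh_mem[offset] and gather its source element instead. The host-side C++ sketch below models only that indexing. It is not the kernel itself: kBlockSize, OldWriteCounts, and NewReverse are hypothetical names introduced for illustration, src_base is taken as 0, and the block is assumed to run exactly BLOCK_SIZE threads.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the kernel's BLOCK_SIZE; one simulated thread per slot.
constexpr int kBlockSize = 8;

// Models the removed (scatter-style) code: for each shared-memory slot, count
// how many simulated threads write to it. A count of 2 marks a write-write
// race; a count of 0 marks a slot left uninitialized before __syncthreads().
std::vector<int> OldWriteCounts(int valid_item) {
  std::vector<int> writes(kBlockSize, 0);
  for (int offset = 0; offset < kBlockSize; ++offset) {
    if (offset >= valid_item) {
      ++writes[offset];                   // sh_mem[offset] = 0;
    } else {
      ++writes[kBlockSize - offset - 1];  // sh_mem[sh_mem_index] = data;
    }
  }
  return writes;
}

// Models the added (gather-style) code with src_base assumed to be 0: thread
// `offset` writes only sh_mem[offset], so each slot has exactly one writer.
std::vector<int> NewReverse(const std::vector<int>& idata, int valid_item) {
  assert(valid_item >= 0 && valid_item <= kBlockSize);
  assert(idata.size() >= static_cast<std::size_t>(valid_item));
  std::vector<int> sh_mem(kBlockSize, -1);  // -1 marks "never written"
  for (int offset = 0; offset < kBlockSize; ++offset) {
    int src_data = 0;
    int src_offset = kBlockSize - offset - 1;
    if (src_offset < valid_item) {
      src_data = idata[src_offset];
    }
    sh_mem[offset] = src_data;
  }
  return sh_mem;
}
```

With kBlockSize = 8 and valid_item = 3, OldWriteCounts(3) gives {0, 0, 0, 1, 1, 2, 2, 2}: slots 5 to 7 each have two writers and slots 0 to 2 have none. NewReverse({10, 20, 30}, 3) gives {0, 0, 0, 0, 0, 30, 20, 10}, with every slot written exactly once.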

1 comment on commit ff455c7

@paddle-bot-old paddle-bot-old bot commented on ff455c7 May 5, 2022

🕵️ CI failures summary

🔍 PR: #42500 Commit ID: ff455c7 contains failed CI.

🔹 Failed: PR-CI-iScan-Python

Unknown Failed
Unknown Failed

🔹 Failed: PR-CI-iScan-C

Unknown Failed
Unknown Failed

🔹 Failed: PR-CI-Kunlun-KP-Build

Unknown Failed
Unknown Failed
