Skip to content

Commit

Permalink
add Cumsum
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTrainingG committed Feb 15, 2022
1 parent b37a300 commit 932b7ba
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions paddle/pten/kernels/primitive/compute_primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -443,5 +443,43 @@ __device__ __forceinline__ void ElementwiseRandom(OutT* out,
}
}

// Block-wide inclusive prefix sum over 2 * blockDim.x elements, computed with
// an up-sweep / down-sweep tree in shared memory.
//
// Thread t contributes in[0] as element t and in[1] as element
// blockDim.x + t. On return, out[0] and out[1] hold the inclusive sums ending
// at those same two positions. Both `in` and `out` are read / written at
// indices 0 and 1 only; they are expected to be per-thread (register) buffers.
//
// Preconditions:
//   * every thread of the block must call this function, because it contains
//     __syncthreads();
//   * blockDim.x must be a power of two (the tree sweeps below only produce a
//     full scan in that case);
//   * blockDim.x must not exceed shared_size, which fixes the capacity of the
//     shared scratch buffer. Larger blocks would index past its end.
//
// NOTE(review): `compute` and the NX / NY / BlockSize template parameters are
// currently not used; elements are always combined with operator+.
#define shared_size 64
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void Cumsum(OutT* out,
                                       const InT* in,
                                       OpFunc compute) {
  // Scratch buffer for up to 2 * shared_size elements, plus one spare slot
  // per 32 elements. Logical index i lives at temp[i + i / 32]; the padding
  // is the usual trick to spread strided accesses over shared-memory banks.
  __shared__ InT temp[shared_size * 2 + (shared_size * 2) / 32];

  const int tidx = static_cast<int>(threadIdx.x);
  // Offset of the second element owned by each thread. This has to follow the
  // actual block size: the sweeps below cover exactly 2 * blockDim.x
  // elements, so a fixed offset would only be correct for one block size.
  const int half = static_cast<int>(blockDim.x);
  const int total = half * 2;

  const int first = tidx;
  const int second = half + tidx;
  temp[first + first / 32] = in[0];
  temp[second + second / 32] = in[1];

  // Up-sweep: build partial sums at the right end of every 2 * stride window.
  for (int stride = 1; stride <= half; stride *= 2) {
    // Makes the previous level (or the initial loads) visible block-wide.
    __syncthreads();
    const int index = (tidx + 1) * 2 * stride - 1;
    if (index < total) {
      const int left = index - stride;
      temp[index + index / 32] += temp[left + left / 32];
    }
  }

  // Down-sweep: push the partial sums into the positions the up-sweep
  // skipped, halving the stride each level.
  for (int stride = total / 4; stride > 0; stride /= 2) {
    __syncthreads();
    const int index = (tidx + 1) * 2 * stride - 1;
    const int right = index + stride;
    if (right < total) {
      temp[right + right / 32] += temp[index + index / 32];
    }
  }

  // Wait for the last down-sweep level before reading the results back.
  __syncthreads();
  out[0] = static_cast<OutT>(temp[first + first / 32]);
  out[1] = static_cast<OutT>(temp[second + second / 32]);
}

} // namespace kps
} // namespace pten

1 comment on commit 932b7ba

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.