Modified distribution kernel with Kernel Primitive API #39563

Merged

35 changes: 21 additions & 14 deletions paddle/fluid/operators/distribution_helper.h
@@ -28,6 +28,10 @@ limitations under the License. */
 #include "paddle/fluid/platform/for_range.h"
 #include "paddle/pten/core/hostdevice.h"
 
+#if defined(__NVCC__) || defined(__HIPCC__)
+#include "paddle/pten/kernels/primitive/kernel_primitives.h"
+#endif
+
 #if !defined(_WIN32)
 #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
 #else
@@ -91,6 +95,8 @@ struct normal_transform {
 
 #if defined(__NVCC__) || defined(__HIPCC__)
 
+namespace kps = pten::kps;
+
 /*********************** Distribution Function *************************/
 template <typename T>
 struct uniform_distribution;
@@ -176,25 +182,26 @@ template <typename T, typename DistOp, typename TransformOp>
 __global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset,
                                    DistOp dist, TransformOp trans,
                                    T *out_data) {
-  size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
-  int32_t returns_count = DistOp::kReturnsCount;
+  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
+  static constexpr int kCount = DistOp::kReturnsCount;
 #if defined(__NVCC__)
   curandStatePhilox4_32_10_t state;
-  curand_init(seed, idx, offset, &state);
+  curand_init(seed, idx + THREAD_ID_X, offset, &state);
+  using SType = curandStatePhilox4_32_10_t;
 #else
   hiprandStatePhilox4_32_10_t state;
-  hiprand_init(seed, idx, offset, &state);
+  hiprand_init(seed, idx + THREAD_ID_X, offset, &state);
+  using SType = hiprandStatePhilox4_32_10_t;
 #endif
-  size_t total_thread = gridDim.x * blockDim.x;
-  for (size_t i = idx; i < size; i += total_thread * returns_count) {
-    auto random_tuple = dist(&state);
-    for (size_t j = 0; j < returns_count; j++) {
-      size_t index = i + j * total_thread;
-      if (index < size) {
-        auto random = (&random_tuple.x)[j];
-        out_data[index] = static_cast<T>(trans(random));
-      }
-    }
+  size_t total_thread = GRID_NUM_X * BLOCK_NUM_X;
+  T args[kCount];
+  T result[kCount];
+  for (size_t i = idx; i < size; i += total_thread * kCount) {
+    kps::ElementwiseRandom<SType, T, kCount, 1, DistOp>(&args[0], dist, &state);
+    kps::ElementwiseUnary<T, T, kCount, 1, 1, TransformOp>(&result[0], &args[0],
+                                                           trans);
+    kps::WriteData<T, T, kCount, 1, 1, true>(out_data + i, &result[0], size - i,
+                                             1, total_thread, 1);
   }
 }
 
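For context, the refactored kernel keeps the same generation scheme as before: each thread owns one Philox state, draws DistOp::kReturnsCount values per call, and strides through the output by total_thread * kCount. The standalone CUDA sketch below spells that scheme out without the kps primitives; it is illustrative only and not part of this PR (the kernel name, kCount = 4, the launch configuration, and the strided write layout kept from the pre-refactor loop are assumptions).

#include <cstdint>
#include <curand_kernel.h>

// Hypothetical sketch: one Philox state per thread, four draws per call,
// a grid stride of total_thread * kCount, and boundary-guarded stores.
__global__ void UniformSketch(size_t size, uint64_t seed, uint64_t offset,
                              float *out) {
  constexpr int kCount = 4;  // curand_uniform4 returns 4 values per call
  size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, tid, offset, &state);  // one subsequence per thread
  size_t total_thread = static_cast<size_t>(gridDim.x) * blockDim.x;
  for (size_t i = tid; i < size; i += total_thread * kCount) {
    float4 r = curand_uniform4(&state);
    const float vals[kCount] = {r.x, r.y, r.z, r.w};
    for (int j = 0; j < kCount; ++j) {
      size_t index = i + static_cast<size_t>(j) * total_thread;
      if (index < size) out[index] = vals[j];  // guard the ragged tail
    }
  }
}

int main() {
  const size_t n = 1 << 20;
  float *d_out = nullptr;
  cudaMalloc(&d_out, n * sizeof(float));
  UniformSketch<<<64, 256>>>(n, /*seed=*/42, /*offset=*/0, d_out);
  cudaDeviceSynchronize();
  cudaFree(d_out);
  return 0;
}

In the PR itself, the transform and the boundary-guarded store are delegated to kps::ElementwiseUnary and kps::WriteData (hence the size - i bound and the IsBoundary = true template argument).
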
53 changes: 53 additions & 0 deletions paddle/pten/kernels/primitive/compute_primitives.h
@@ -428,5 +428,58 @@ __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
}
}

template <typename StateType,
          typename OutT,
          int ReturnsCount,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseRandom(OutT* out,
                                                  OpFunc compute,
                                                  StateType* state) {
  auto random_tuple = compute(state);
#pragma unroll
  for (int i = 0; i < ReturnsCount; i++) {
    out[i] = static_cast<OutT>((&random_tuple.x)[i]);
  }
}
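A minimal call-site sketch for this primitive, kept separate from the file above: UniformFour is a hypothetical stand-in for the uniform_distribution<float> functor used by DistributionKernel, whose definition is not shown in this diff.

#include <curand_kernel.h>
#include "paddle/pten/kernels/primitive/kernel_primitives.h"

namespace kps = pten::kps;

// Hypothetical distribution functor: one Philox call yields four floats.
struct UniformFour {
  static constexpr int kReturnsCount = 4;
  __device__ float4 operator()(curandStatePhilox4_32_10_t *state) const {
    return curand_uniform4(state);
  }
};

// Fills buf[0..3] with one batch of uniform draws via the new primitive.
__device__ void FillFourUniform(curandStatePhilox4_32_10_t *state, float *buf) {
  kps::ElementwiseRandom<curandStatePhilox4_32_10_t, float,
                         UniformFour::kReturnsCount, 1, UniformFour>(
      buf, UniformFour(), state);
}
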

// Note: blockDim.x must equal shared_size (64) when using this primitive.
// out and in are register pointers; each thread contributes two elements.
#define shared_size 64
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void Cumsum(OutT* out,
                                       const InT* in,
                                       OpFunc compute) {
  __shared__ InT temp[shared_size * 2 + (shared_size * 2) / 32];
  int tidx = threadIdx.x;
  temp[tidx + tidx / 32] = in[0];
  temp[shared_size + tidx + (shared_size + tidx) / 32] = in[1];
  for (int stride = 1; stride <= blockDim.x; stride *= 2) {
    __syncthreads();
    int index = (tidx + 1) * 2 * stride - 1;
    if (index < (blockDim.x * 2)) {
      temp[index + index / 32] += temp[index - stride + (index - stride) / 32];
    }
  }
  for (int stride = (blockDim.x * 2) / 4; stride > 0; stride /= 2) {
    __syncthreads();
    int index = (tidx + 1) * 2 * stride - 1;
    if ((index + stride) < (blockDim.x * 2)) {
      temp[index + stride + (stride + index) / 32] +=
          temp[index + (index) / 32];
    }
  }

  __syncthreads();
  out[0] = static_cast<OutT>(temp[tidx + tidx / 32]);
  out[1] =
      static_cast<OutT>(temp[tidx + shared_size + (tidx + shared_size) / 32]);
}

} // namespace kps
} // namespace pten
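
A usage sketch for the new Cumsum primitive (hypothetical kernel and names; it assumes the block is launched with exactly 64 threads so that blockDim.x matches the hard-coded shared_size, as the comment above asks). Each thread contributes two elements, stored at positions tid and blockDim.x + tid of the scanned row, and gets back the inclusive prefix sums for those same positions.

#include "paddle/pten/kernels/primitive/kernel_primitives.h"

namespace kps = pten::kps;

// Placeholder functor: the template requires an OpFunc parameter even though
// the current implementation accumulates with += internally.
template <typename T>
struct AddOp {
  __device__ T operator()(const T &a, const T &b) const { return a + b; }
};

// Hypothetical kernel: inclusive prefix sum over one row of 2 * blockDim.x
// floats.
__global__ void BlockInclusiveScan(const float *in, float *out) {
  int tid = threadIdx.x;
  int n = blockDim.x;  // must be 64 (== shared_size)
  float vals[2] = {in[tid], in[n + tid]};
  float sums[2];
  kps::Cumsum<float, float, 2, 1, 1, AddOp<float>>(sums, vals, AddOp<float>());
  out[tid] = sums[0];      // sum of in[0 .. tid]
  out[n + tid] = sums[1];  // sum of in[0 .. n + tid]
}

Launched as BlockInclusiveScan<<<1, 64>>>(d_in, d_out) for a single row of 128 floats.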