Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

modified the elementwise_op_broadcast and elementwise_op_impl for xpu2 #37226

Merged
Merged 2 commits on Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions paddle/fluid/operators/kernel_primitives/kernel_primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,45 @@
// limitations under the License.

#pragma once
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#include "paddle/fluid/operators/kernel_primitives/helper_primitives.h"
#ifdef PADDLE_WITH_XPU2
#include "paddle/fluid/operators/kernel_primitives/compute_primitives_xpu2.h"
#include "paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h"
#define THREAD_ID_X core_id()
AnnaTrainingG marked this conversation as resolved.
Show resolved Hide resolved
#define THREAD_ID_Y 0
#define THREAD_ID_Z 0

#define BLOCK_NUM_X core_num()
#define BLOCK_NUM_Y 0
#define BLOCK_NUM_Z 0

#define BLOCK_ID_X cluster_id()
#define BLOCK_ID_Y 0
#define BLOCK_ID_Z 0

#define GRID_NUM_X cluster_num()
#define GRID_NUM_Y 0
#define GRID_NUM_Z 0
#else
#include "paddle/fluid/operators/kernel_primitives/compute_primitives.h"
#include "paddle/fluid/operators/kernel_primitives/datamover_primitives.h"
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#include "paddle/fluid/operators/kernel_primitives/helper_primitives.h"
#define THREAD_ID_X threadIdx.x
#define THREAD_ID_Y threadIdx.y
#define THREAD_ID_Z threadIdx.z

#define BLOCK_NUM_X blockDim.x
#define BLOCK_NUM_Y blockDim.y
#define BLOCK_NUM_Z blockDim.z

#define BLOCK_ID_X blockIdx.x
#define BLOCK_ID_Y blockIdx.y
#define BLOCK_ID_Z blockIdx.z

#define GRID_NUM_X gridDim.x
#define GRID_NUM_Y gridDim.y
#define GRID_NUM_Z gridDim.z
#endif

namespace paddle {
namespace operators {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,20 +196,19 @@ template <typename InT,
int VecSize,
int Rank,
bool IsBoundary = false>
__device__ void DealSegment(
__device__ void ElementwiseBroadcastKernelImpl(
const paddle::framework::Array<const InT *__restrict__, Arity> &ins,
OutT *out,
const paddle::framework::Array<bool, Arity> &use_broadcast,
uint32_t numel,
const paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
&configs,
int num,
int block_offset,
AnnaTrainingG marked this conversation as resolved.
Show resolved Hide resolved
Functor func) {
InT args[Arity][VecSize];
OutT result[VecSize];

int block_offset = blockIdx.x * blockDim.x * VecSize;

#pragma unroll
for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
Expand Down Expand Up @@ -240,27 +239,73 @@ template <typename InT,
int Arity,
int VecSize,
int Rank>
__global__ void BroadcastKernel(
__global__ void ElementwiseBroadcastKernel(
paddle::framework::Array<const InT *__restrict__, Arity> ins,
OutT *out,
paddle::framework::Array<bool, Arity> use_broadcast,
uint32_t numel,
paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
configs,
int main_tid,
int main_offset,
int tail_tid,
Functor func) {
int block_offset = blockIdx.x * blockDim.x * VecSize;
// data offset of this block
if (blockIdx.x < main_tid) {
int num = blockDim.x * VecSize; // blockIdx.x < main_tid
pten::DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, false>(
ins, out, use_broadcast, numel, configs, num, func);
} else { // reminder
int num = tail_tid;
pten::DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, true>(
ins, out, use_broadcast, numel, configs, num, func);
int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
#ifdef PADDLE_WITH_XPU2
for (; block_offset < main_offset; block_offset += stride) {
ElementwiseBroadcastKernelImpl<InT,
OutT,
Functor,
Arity,
VecSize,
Rank,
false>(ins,
out,
use_broadcast,
numel,
configs,
BLOCK_NUM_X * VecSize,
block_offset,
func);
}
if (block_offset < numel) {
ElementwiseBroadcastKernelImpl<InT,
OutT,
Functor,
Arity,
VecSize,
Rank,
true>(
ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
}

#else
if (block_offset < main_offset) {
ElementwiseBroadcastKernelImpl<InT,
OutT,
Functor,
Arity,
VecSize,
Rank,
false>(ins,
out,
use_broadcast,
numel,
configs,
BLOCK_NUM_X * VecSize,
block_offset,
func);
} else {
ElementwiseBroadcastKernelImpl<InT,
OutT,
Functor,
Arity,
VecSize,
Rank,
true>(
ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
}
#endif
}

template <typename InT,
Expand All @@ -278,7 +323,7 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
const int threads = 256;
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;

int main_tid = numel / (VecSize * threads);
int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
int tail_tid = numel % (VecSize * threads);
auto stream = ctx.stream();
OutT *out_data = out->mutable_data<OutT>();
Expand All @@ -298,20 +343,40 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
}
}

BroadcastKernel<InT,
OutT,
Functor,
Arity,
VecSize,
Rank><<<blocks, threads, 0, stream>>>(ins_data,
out_data,
use_broadcast,
numel,
configs,
main_tid,
tail_tid,
func);
#ifdef PADDLE_WITH_XPU2
AnnaTrainingG marked this conversation as resolved.
Show resolved Hide resolved
threads = 128;
blocks = 8;
main_offset = (numel / (VecSize * threads)) * VecSize * threads;
tail_tid = numel % (VecSize * threads);
ElementwiseBroadcastKernel<InT,
OutT,
Functor,
Arity,
VecSize,
Rank><<<blocks, threads, stream>>>(ins_data,
out_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
func);
#else
ElementwiseBroadcastKernel<InT,
OutT,
Functor,
Arity,
VecSize,
Rank><<<blocks, threads, 0, stream>>>(
ins_data,
out_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
func);
#endif
}

template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,15 @@ template <typename InT,
int Arity,
int VecSize,
bool IsBoundary>
__device__ void DealSegment(
__device__ void VectorizedElementwiseKernelImpl(
const paddle::framework::Array<const InT *__restrict__, Arity> &in,
OutT *out,
int num,
int data_offset,
Functor func) {
InT args[Arity][VecSize];
OutT result[VecSize];

int data_offset = VecSize * blockIdx.x * blockDim.x;

#pragma unroll
for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
Expand All @@ -87,18 +86,23 @@ __device__ void DealSegment(
}

template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
__global__ void ElementVectorizeKernel(
__global__ void VectorizedElementwiseKernel(
paddle::framework::Array<const InT *__restrict__, Arity> ins,
OutT *out,
int size,
int main_offset,
Functor func) {
int data_offset = VecSize * blockIdx.x * blockDim.x;
int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
for (; data_offset < main_offset; data_offset += stride) {
VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, VecSize, false>(
ins, out, VecSize * BLOCK_NUM_X, data_offset, func);
}

int num = size - data_offset;
// the num this time have to deal with
if (VecSize * blockDim.x > num) { // reminder segment
DealSegment<InT, OutT, Functor, Arity, VecSize, true>(ins, out, num, func);
} else { // complete segment
DealSegment<InT, OutT, Functor, Arity, VecSize, false>(ins, out, num, func);
if (num > 0) {
VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, VecSize, true>(
ins, out, num, data_offset, func);
}
}

Expand Down Expand Up @@ -132,12 +136,25 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
for (int i = 0; i < Arity; i++) {
ins_data[i] = ins[i]->data<InT>();
}
ElementVectorizeKernel<InT,
OutT,
Functor,
Arity,
VecSize><<<grid_size, block_size, 0, stream>>>(
ins_data, out_data, numel, func);
#ifdef PADDLE_WITH_XPU2
block_size = 128;
AnnaTrainingG marked this conversation as resolved.
Show resolved Hide resolved
grid_size = 8;
int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
VectorizedElementwiseKernel<InT,
OutT,
Functor,
Arity,
VecSize><<<grid_size, block_size, 0, stream>>>(
ins_data, out_data, numel, main_offset, func);
#else
int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
VectorizedElementwiseKernel<InT,
OutT,
Functor,
Arity,
VecSize><<<grid_size, block_size, 0, stream>>>(
ins_data, out_data, numel, main_offset, func);
#endif
}

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
Expand Down