Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize the call of some elementwise kernel to decrease the static library size. #57838

Merged
merged 6 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 8 additions & 25 deletions paddle/phi/kernels/gpu/lerp_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,11 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"

namespace phi {

template <typename T>
struct BroadcastMinElementWiseDirectCUDAFunctor {
HOSTDEVICE inline T operator()(const T min) const { return min; }
};

template <typename T>
struct LerpElementWiseDirectCUDAFunctor {
HOSTDEVICE inline T operator()(const T x, const T y, const T weight) const {
Expand Down Expand Up @@ -87,36 +83,23 @@ void LerpKernel(const Context &ctx,
DenseTensor b_min = phi::EmptyLike<T>(ctx, *out);
if (x.dims().size() != y.dims().size() &&
weight.dims().size() != y.dims().size()) {
std::vector<const DenseTensor *> broadcast_min_inputs;
broadcast_min_inputs.reserve(1);
std::vector<DenseTensor *> broadcast_min_outputs = {&b_min};
auto broadcast_min_functor =
BroadcastMinElementWiseDirectCUDAFunctor<T>();
if (x.dims().size() < y.dims().size() &&
x.dims().size() < weight.dims().size()) {
broadcast_min_inputs.emplace_back(&x);
phi::funcs::BroadcastKernel<T>(ctx,
broadcast_min_inputs,
&broadcast_min_outputs,
broadcast_min_functor);
// x broadcast to b_min
ExpandKernel<T, Context>(ctx, x, phi::vectorize(b_min.dims()), &b_min);
inputs.emplace_back(&b_min);
inputs.emplace_back(&y);
inputs.emplace_back(&weight);
} else if (y.dims().size() < weight.dims().size()) {
broadcast_min_inputs.emplace_back(&y);
phi::funcs::BroadcastKernel<T>(ctx,
broadcast_min_inputs,
&broadcast_min_outputs,
broadcast_min_functor);
// y broadcast to b_min
ExpandKernel<T, Context>(ctx, y, phi::vectorize(b_min.dims()), &b_min);
inputs.emplace_back(&x);
inputs.emplace_back(&b_min);
inputs.emplace_back(&weight);
} else {
broadcast_min_inputs.emplace_back(&weight);
phi::funcs::BroadcastKernel<T>(ctx,
broadcast_min_inputs,
&broadcast_min_outputs,
broadcast_min_functor);
// weight broadcast to b_min
ExpandKernel<T, Context>(
ctx, weight, phi::vectorize(b_min.dims()), &b_min);
inputs.emplace_back(&x);
inputs.emplace_back(&y);
inputs.emplace_back(&b_min);
Expand Down
85 changes: 40 additions & 45 deletions paddle/phi/kernels/gpu/viterbi_decode_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ namespace cub = hipcub;
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
Expand Down Expand Up @@ -80,21 +82,6 @@ int64_t ComputeBlockSize(int64_t col) {
return 8;
}

template <typename Context,
template <typename T>
typename BinaryFunctor,
typename T>
struct BinaryOperation {
void operator()(const Context& dev_ctx,
const DenseTensor& lhs,
const DenseTensor& rhs,
DenseTensor* output) {
std::vector<const DenseTensor*> ins{&lhs, &rhs};
std::vector<DenseTensor*> outs{output};
phi::funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, BinaryFunctor<T>(), 0);
}
};

template <typename Context,
template <typename InT, typename OutT>
typename CompareFunctor,
Expand Down Expand Up @@ -314,47 +301,46 @@ void ViterbiDecodeKernel(const Context& dev_ctx,
start_trans.Resize({1, n_labels});
auto logit0 = input_exp.Slice(0, 1);
logit0.Resize({batch_size, n_labels});
BinaryOperation<Context, phi::funcs::AddFunctor, T> AddFloat;
BinaryOperation<Context, phi::funcs::AddFunctor, int64_t> AddInt;
BinaryOperation<Context, phi::funcs::MultiplyFunctor, T> MulFloat;
BinaryOperation<Context, phi::funcs::MultiplyFunctor, int64_t> MulInt;
BinaryOperation<Context, phi::funcs::SubtractFunctor, T> SubFloat;
BinaryOperation<Context, phi::funcs::SubtractFunctor, int64_t> SubInt;
if (include_bos_eos_tag) {
AddFloat(dev_ctx, logit0, start_trans, &alpha);
phi::AddKernel<T, Context>(dev_ctx, logit0, start_trans, &alpha);
GetMask<Context, phi::funcs::EqualFunctor, T>()(
dev_ctx, left_length, one, &float_mask);
MulFloat(dev_ctx, stop_trans, float_mask, &alpha_nxt);
AddFloat(dev_ctx, alpha, alpha_nxt, &alpha);
phi::MultiplyKernel<T, Context>(
dev_ctx, stop_trans, float_mask, &alpha_nxt);
phi::AddKernel<T, Context>(dev_ctx, alpha, alpha_nxt, &alpha);
} else {
alpha = logit0;
}
SubInt(dev_ctx, left_length, one, &left_length);
phi::SubtractKernel<int64_t, Context>(
dev_ctx, left_length, one, &left_length);
Argmax<Context, T, int64_t> argmax;
for (int64_t i = 1; i < max_seq_len; ++i) {
DenseTensor logit = input_exp.Slice(i, i + 1);
logit.Resize({batch_size, n_labels});
DenseTensor& alpha_exp = alpha.Resize({batch_size, n_labels, 1});
AddFloat(dev_ctx, alpha_exp, trans_exp, &alpha_trn_sum);
phi::AddKernel<T, Context>(dev_ctx, alpha_exp, trans_exp, &alpha_trn_sum);
auto alpha_argmax_temp = alpha_argmax_unbind[i - 1];
alpha_argmax_temp.Resize({batch_size, n_labels});
argmax(dev_ctx, alpha_trn_sum, &alpha_argmax_temp, &alpha_max, 1);
historys.emplace_back(alpha_argmax_temp);
AddFloat(dev_ctx, alpha_max, logit, &alpha_nxt);
phi::AddKernel<T, Context>(dev_ctx, alpha_max, logit, &alpha_nxt);
alpha.Resize({batch_size, n_labels});
GetMask<Context, phi::funcs::GreaterThanFunctor, T>()(
dev_ctx, left_length, zero, &float_mask);
MulFloat(dev_ctx, alpha_nxt, float_mask, &alpha_nxt);
SubFloat(dev_ctx, float_one, float_mask, &float_mask);
MulFloat(dev_ctx, alpha, float_mask, &alpha);
AddFloat(dev_ctx, alpha, alpha_nxt, &alpha);
phi::MultiplyKernel<T, Context>(dev_ctx, alpha_nxt, float_mask, &alpha_nxt);
phi::SubtractKernel<T, Context>(
dev_ctx, float_one, float_mask, &float_mask);
phi::MultiplyKernel<T, Context>(dev_ctx, alpha, float_mask, &alpha);
phi::AddKernel<T, Context>(dev_ctx, alpha, alpha_nxt, &alpha);
if (include_bos_eos_tag) {
GetMask<Context, phi::funcs::EqualFunctor, T>()(
dev_ctx, left_length, one, &float_mask);
MulFloat(dev_ctx, stop_trans, float_mask, &alpha_nxt);
AddFloat(dev_ctx, alpha, alpha_nxt, &alpha);
phi::MultiplyKernel<T, Context>(
dev_ctx, stop_trans, float_mask, &alpha_nxt);
phi::AddKernel<T, Context>(dev_ctx, alpha, alpha_nxt, &alpha);
}
SubInt(dev_ctx, left_length, one, &left_length);
phi::SubtractKernel<int64_t, Context>(
dev_ctx, left_length, one, &left_length);
}
argmax(dev_ctx, alpha, &last_ids, scores, 1);
left_length.Resize({batch_size});
Expand All @@ -363,32 +349,41 @@ void ViterbiDecodeKernel(const Context& dev_ctx,
// last_ids_update = last_ids * tag_mask
int last_ids_index = 1;
int actual_len = (std::min)(seq_len, static_cast<int>(max_seq_len));
MulInt(dev_ctx, last_ids, int_mask, &batch_path[actual_len - last_ids_index]);
phi::MultiplyKernel<int64_t, Context>(
dev_ctx, last_ids, int_mask, &batch_path[actual_len - last_ids_index]);
// The algorithm below can refer to
// https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/layers/crf.py#L438
ARange<Context> arange;
arange(dev_ctx, batch_offset.data<int64_t>(), batch_size, n_labels);
Gather<Context, int64_t, int64_t> gather;
for (auto hist = historys.rbegin(); hist != historys.rend(); ++hist) {
++last_ids_index;
AddInt(dev_ctx, left_length, one, &left_length);
AddInt(dev_ctx, batch_offset, last_ids, &gather_idx);
phi::AddKernel<int64_t, Context>(dev_ctx, left_length, one, &left_length);
phi::AddKernel<int64_t, Context>(
dev_ctx, batch_offset, last_ids, &gather_idx);
DenseTensor& last_ids_update = batch_path[actual_len - last_ids_index];
hist->Resize({batch_size * n_labels});
gather(dev_ctx, *hist, gather_idx, &last_ids_update);
GetMask<Context, phi::funcs::GreaterThanFunctor, int64_t>()(
dev_ctx, left_length, zero, &int_mask);
MulInt(dev_ctx, last_ids_update, int_mask, &last_ids_update);
phi::MultiplyKernel<int64_t, Context>(
dev_ctx, last_ids_update, int_mask, &last_ids_update);
GetMask<Context, phi::funcs::EqualFunctor, int64_t>()(
dev_ctx, left_length, zero, &zero_len_mask);
MulInt(dev_ctx, last_ids, zero_len_mask, &last_ids_tmp);
SubInt(dev_ctx, one, zero_len_mask, &zero_len_mask);
MulInt(dev_ctx, last_ids_update, zero_len_mask, &last_ids_update);
AddInt(dev_ctx, last_ids_update, last_ids_tmp, &last_ids_update);
phi::MultiplyKernel<int64_t, Context>(
dev_ctx, last_ids, zero_len_mask, &last_ids_tmp);
phi::SubtractKernel<int64_t, Context>(
dev_ctx, one, zero_len_mask, &zero_len_mask);
phi::MultiplyKernel<int64_t, Context>(
dev_ctx, last_ids_update, zero_len_mask, &last_ids_update);
phi::AddKernel<int64_t, Context>(
dev_ctx, last_ids_update, last_ids_tmp, &last_ids_update);
GetMask<Context, phi::funcs::LessThanFunctor, int64_t>()(
dev_ctx, left_length, zero, &int_mask);
MulInt(dev_ctx, last_ids, int_mask, &last_ids);
AddInt(dev_ctx, last_ids_update, last_ids, &last_ids);
phi::MultiplyKernel<int64_t, Context>(
dev_ctx, last_ids, int_mask, &last_ids);
phi::AddKernel<int64_t, Context>(
dev_ctx, last_ids_update, last_ids, &last_ids);
}
TransposeKernel<int64_t, Context>(dev_ctx, tpath, {1, 0}, path);
}
Expand Down
9 changes: 2 additions & 7 deletions paddle/phi/kernels/impl/elementwise_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,8 @@ namespace phi {
const DenseTensor& y, \
int axis, \
DenseTensor* out) { \
std::vector<const DenseTensor*> inputs; \
inputs.reserve(2); \
std::vector<DenseTensor*> outputs; \
outputs.reserve(1); \
inputs.emplace_back(&x); \
inputs.emplace_back(&y); \
outputs.emplace_back(out); \
std::vector<const DenseTensor*> inputs = {&x, &y}; \
std::vector<DenseTensor*> outputs = {out}; \
dev_ctx.template Alloc<T>(out); \
funcs::BroadcastKernel<T>( \
dev_ctx, inputs, &outputs, funcs::name##Functor<T>(), axis); \
Expand Down
61 changes: 30 additions & 31 deletions paddle/phi/kernels/kps/compare_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/impl/compare_kernel_impl.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"

#ifdef PADDLE_WITH_XPU_KP
#include "paddle/phi/backends/xpu/xpu_context.h"
Expand All @@ -27,6 +27,7 @@
#include "paddle/phi/kernels/compare_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/legacy/compare_kernel.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#endif

Expand All @@ -43,37 +44,27 @@ struct BitwiseAdd {
}
};

template <typename T,
typename Context,
typename Functor,
typename InverseFunctor>
inline void CompareKernelImpl(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
ctx.template Alloc<bool>(out);
std::vector<const DenseTensor*> ins{&x, &y};
std::vector<DenseTensor*> outs{out};
funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
}
#define DEFINE_CUDA_COMPARE_KERNEL(name) \
template <typename T, typename Context> \
void name##Kernel(const Context& ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out) { \
if (out->IsSharedWith(x)) { \
auto x_origin = x; \
name##RawKernel<T, Context>(ctx, x_origin, y, -1, out); \
} else { \
name##RawKernel<T, Context>(ctx, x, y, -1, out); \
} \
}

template <typename T,
typename Context,
typename Functor,
typename InverseFunctor>
inline void InplaceCompareKernelImpl(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
auto x_origin = x;
ctx.template Alloc<bool>(out);
out->set_type(phi::DataType::BOOL);
std::vector<const DenseTensor*> ins{&x_origin, &y};
std::vector<DenseTensor*> outs{out};
funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
}
DEFINE_CUDA_COMPARE_KERNEL(LessThan)
DEFINE_CUDA_COMPARE_KERNEL(LessEqual)
DEFINE_CUDA_COMPARE_KERNEL(GreaterThan)
DEFINE_CUDA_COMPARE_KERNEL(GreaterEqual)
DEFINE_CUDA_COMPARE_KERNEL(Equal)
DEFINE_CUDA_COMPARE_KERNEL(NotEqual)
#undef DEFINE_CUDA_COMPARE_KERNEL

#ifndef PADDLE_WITH_XPU_KP
template <typename T, typename Context, typename Functor>
Expand Down Expand Up @@ -106,6 +97,14 @@ inline void CompareAllKernelImpl(const Context& ctx,
funcs::ReduceKernel<bool, bool, BitwiseAdd, kps::IdentityFunctor<bool>>(
ctx, tmp, out, kps::IdentityFunctor<bool>(), reduce_dims);
}

template <typename T, typename Context>
void EqualAllKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
CompareAllKernelImpl<T, Context, funcs::EqualFunctor<T>>(ctx, x, y, out);
}
#endif

} // namespace phi
Expand Down
11 changes: 3 additions & 8 deletions paddle/phi/kernels/kps/elementwise_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
#include "paddle/phi/common/float16.h"
#endif
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
#include "paddle/phi/kernels/legacy/elementwise_add_kernel.h"
#include "paddle/phi/kernels/legacy/elementwise_divide_kernel.h"
#include "paddle/phi/kernels/legacy/elementwise_kernel.h"
#include "paddle/phi/kernels/legacy/elementwise_multipy_kernel.h"
#include "paddle/phi/kernels/legacy/elementwise_subtract_kernel.h"
Expand Down Expand Up @@ -146,13 +146,8 @@ void HeavisideKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
std::vector<const DenseTensor*> inputs;
inputs.reserve(2);
std::vector<DenseTensor*> outputs;
outputs.reserve(1);
inputs.emplace_back(&x);
inputs.emplace_back(&y);
outputs.emplace_back(out);
std::vector<const DenseTensor*> inputs = {&x, &y};
std::vector<DenseTensor*> outputs = {out};
dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, funcs::ElementwiseHeavisideFunctor<T>());
Expand Down
Loading