[PaddlePaddle Hackathon 4 No.34] Optimize the performance of the Lerp OP on GPU for Paddle #53154
@@ -15,8 +15,115 @@
#include "paddle/phi/kernels/lerp_kernel.h" | ||
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/common/amp_type_traits.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/impl/lerp_kernel_impl.h" | ||
#include "paddle/phi/kernels/empty_kernel.h" | ||
#include "paddle/phi/kernels/funcs/broadcast_function.h" | ||
#include "paddle/phi/kernels/funcs/common_shape.h" | ||
#include "paddle/phi/kernels/funcs/math_function.h" | ||
|
||
namespace phi { | ||
|
||
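// Identity functor: used with BroadcastKernel to expand the lowest-rank input
// to the output shape (see the b_min pre-broadcast below).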
template <typename T>
struct BroadcastMinElementWiseDirectCUDAFunctor {
  HOSTDEVICE inline T operator()(const T min) const { return min; }
};

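// Ternary element-wise lerp over same-shaped operands: out = x + weight * (y - x).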
template <typename T>
struct LerpElementWiseDirectCUDAFunctor {
  HOSTDEVICE inline T operator()(const T x, const T y, const T weight) const {
    return x + weight * (y - x);
  }
};

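// Lerp with a single-element weight tensor; the scalar weight is read from
// device memory through weight_[0].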
template <typename T>
struct LerpScalarDirectCUDAFunctor {
  const T *weight_;

  HOSTDEVICE inline LerpScalarDirectCUDAFunctor(const T *weight)
      : weight_(weight) {}

  HOSTDEVICE inline T operator()(const T x, const T y) const {
    return x + weight_[0] * (y - x);
  }
};

template <typename T, typename Context>
void LerpKernel(const Context &ctx,
                const DenseTensor &x,
                const DenseTensor &y,
                const DenseTensor &weight,
                DenseTensor *out) {
  int rank = out->dims().size();
  PADDLE_ENFORCE_GE(
      rank,
      0,
      phi::errors::InvalidArgument(
          "The number of dimensions for LerpOp must be "
          "greater than or equal to 0, but the value received is %d.",
          rank));

  ctx.template Alloc<T>(out);
  std::vector<DenseTensor *> outputs = {out};

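  // Fast path: a single-element weight uses a binary (x, y) broadcast with the
  // scalar functor; otherwise a ternary (x, y, weight) broadcast is launched,
  // pre-expanding the lowest-rank input when the ranks differ.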
  std::vector<const DenseTensor *> inputs;
  if (weight.numel() == 1) {
    const T *weight_ptr = weight.data<T>();
    inputs.reserve(2);
    inputs.emplace_back(&x);
    inputs.emplace_back(&y);
    auto functor = LerpScalarDirectCUDAFunctor<T>(weight_ptr);
    phi::funcs::BroadcastKernel<T>(ctx, inputs, &outputs, functor);
  } else {
    inputs.reserve(3);
    auto functor = LerpElementWiseDirectCUDAFunctor<T>();
    DenseTensor b_min = phi::EmptyLike<T>(ctx, *out);
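    // When neither x's nor weight's rank matches y's, expand the lowest-rank
    // input into b_min (shaped like out) before the ternary broadcast.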
    if (x.dims().size() != y.dims().size() &&
        weight.dims().size() != y.dims().size()) {
      std::vector<const DenseTensor *> broadcast_min_inputs;
      broadcast_min_inputs.reserve(1);
      std::vector<DenseTensor *> broadcast_min_outputs = {&b_min};
      auto broadcast_min_functor =
          BroadcastMinElementWiseDirectCUDAFunctor<T>();
      if (x.dims().size() < y.dims().size() &&
          x.dims().size() < weight.dims().size()) {
        broadcast_min_inputs.emplace_back(&x);
        phi::funcs::BroadcastKernel<T>(ctx,
                                       broadcast_min_inputs,
                                       &broadcast_min_outputs,
                                       broadcast_min_functor);
        inputs.emplace_back(&b_min);
        inputs.emplace_back(&y);
        inputs.emplace_back(&weight);
      } else if (y.dims().size() < weight.dims().size()) {
        broadcast_min_inputs.emplace_back(&y);
        phi::funcs::BroadcastKernel<T>(ctx,
                                       broadcast_min_inputs,
                                       &broadcast_min_outputs,
                                       broadcast_min_functor);
        inputs.emplace_back(&x);
        inputs.emplace_back(&b_min);
        inputs.emplace_back(&weight);
Review thread on this change:

- My understanding of the computation logic here is that the input data first has its dimensions aligned according to …
- Hello! Thank you very much for the guidance!
- I heard from the PM that you are opposed to my suggested changes for this part. May I ask what the reason is? If the reason holds up, I will go ahead and merge it.
- @JamesLim-sy Hello! That is a misunderstanding; I am not opposed to it. I ran into a difficulty I could not resolve myself and was asking you for further guidance.
- What I mean is that Paddle's broadcast computation supports the pattern (input_0.broadcast + input_1.broadcast + input_2.broadcast) = (output_0, output_1), so there is no need to run a separate Broadcast::kUnary first and then do the computation. You could test locally whether a single general Broadcast::kTernary call can complete the computation.
- @JamesLim-sy Do you mean replacing one single-output Broadcast::kUnary plus one single-output Broadcast::kTernary with a single multi-output Broadcast::kTernary? If so, after reading the source code I found that this is not feasible.
- @JamesLim-sy Hello, could you please take another look?
- @JamesLim-sy It has been a long wait; please find some time to review it again.
- The changes are done.
- @JamesLim-sy Does mingshu have time to review this?
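For readers following the discussion above, the sketch below is a minimal, standalone C++ illustration (not Paddle's phi API; the names and shapes are made up for the example) of the two-step approach the kernel currently takes: the lowest-rank input is first expanded to the output shape, and the element-wise formula out = x + w * (y - x) is then applied over same-shaped operands.

#include <cstdio>
#include <vector>

int main() {
  const int rows = 2, cols = 3;
  std::vector<float> x = {0, 1, 2, 3, 4, 5};        // shape [2, 3]
  std::vector<float> y = {10, 11, 12, 13, 14, 15};  // shape [2, 3]
  std::vector<float> weight = {0.0f, 0.5f, 1.0f};   // shape [3], the lowest rank

  // Step 1: broadcast the lowest-rank input up to the output shape [2, 3],
  // which is the role BroadcastMinElementWiseDirectCUDAFunctor plays for b_min.
  std::vector<float> w_full(rows * cols);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      w_full[r * cols + c] = weight[c];
    }
  }

  // Step 2: ternary element-wise lerp over same-shaped operands, matching
  // LerpElementWiseDirectCUDAFunctor.
  for (int i = 0; i < rows * cols; ++i) {
    float out = x[i] + w_full[i] * (y[i] - x[i]);
    std::printf("out[%d] = %g\n", i, out);
  }
  return 0;
}

In the actual kernel, step 2's broadcast and compute are already fused into a single phi::funcs::BroadcastKernel launch; only the rank-alignment step is a separate launch, which is what the Broadcast::kTernary discussion above is about.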
      } else {
        broadcast_min_inputs.emplace_back(&weight);
        phi::funcs::BroadcastKernel<T>(ctx,
                                       broadcast_min_inputs,
                                       &broadcast_min_outputs,
                                       broadcast_min_functor);
        inputs.emplace_back(&x);
        inputs.emplace_back(&y);
        inputs.emplace_back(&b_min);
      }
    } else {
      inputs.emplace_back(&x);
      inputs.emplace_back(&y);
      inputs.emplace_back(&weight);
    }
    phi::funcs::BroadcastKernel<T>(ctx, inputs, &outputs, functor);
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(lerp,
                   GPU,
Review thread on the includes:

- These header files are all already wrapped inside #include "paddle/phi/kernels/funcs/broadcast_function.h"; I hope a follow-up PR can be submitted later to clean this up.
- Done