optimize elementwise_mul_grad using new interfaces #37728
Changes from 29 commits
@@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
@@ -113,6 +114,181 @@ __global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex<double>>(
  }
}
template <typename InT, typename OutT>
struct MulGradXYFunctor {
  inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(const InT& a,
                                                                 const InT& b,
                                                                 const InT& c) {
    paddle::framework::Array<OutT, 2> outs;
    // dx = dout * y
    // dy = dout * x
    outs[0] = a * b;
    outs[1] = a * c;
    return outs;
  }
};
template <typename T>
using complex = paddle::platform::complex<T>;
template <typename InT, typename OutT>
struct MulGradXYFunctor<complex<InT>, complex<OutT>> {
  inline HOSTDEVICE paddle::framework::Array<complex<OutT>, 2> operator()(
      const complex<InT>& a, const complex<InT>& b, const complex<InT>& c) {
    paddle::framework::Array<complex<OutT>, 2> outs;
    // dx = dout * conj(y)
    // dy = dout * conj(x)
    complex<InT> b_conj(b.real, -b.imag);
    complex<InT> c_conj(c.real, -c.imag);
    outs[0] = a * b_conj;
    outs[1] = a * c_conj;
    return outs;
  }
};
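For reference, the two functors above implement the standard elementwise-multiply backward rule; the complex specialization applies conjugates, matching the b_conj/c_conj construction in the code. With out = x ⊙ y and upstream gradient dout:

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \text{out}} \odot y, \qquad \frac{\partial L}{\partial y} = \frac{\partial L}{\partial \text{out}} \odot x \quad \text{(real T)}$$

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \text{out}} \odot \bar{y}, \qquad \frac{\partial L}{\partial y} = \frac{\partial L}{\partial \text{out}} \odot \bar{x} \quad \text{(complex T)}$$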
template <typename T>
void ReduceWrapper(const platform::CUDADeviceContext& dev_ctx, int axis,
                   const framework::Tensor* in, const framework::Tensor* out,
                   framework::Tensor* src, framework::Tensor* dst) {
  std::vector<int> reduce_dims = GetReduceDim(in->dims(), out->dims(), axis);
  TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
      *src, dst, kps::IdentityFunctor<T>(), reduce_dims, dev_ctx.stream());
}
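ReduceWrapper handles the broadcast case: when an input was broadcast in the forward pass, its raw gradient has dout's shape and must be summed back down to the input's own shape along the broadcast dimensions. A minimal host-side sketch of that idea with hypothetical shapes (plain C++, not Paddle's TensorReduceFunctorImpl):

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical example: y has shape (3, 1) and was broadcast against
  // dout of shape (3, 4), so the raw gradient tmp_dy has shape (3, 4).
  const int rows = 3, cols = 4;
  std::vector<float> tmp_dy(rows * cols, 1.0f);  // stand-in for dout * x
  std::vector<float> dy(rows, 0.0f);             // reduced gradient, shape (3, 1)

  // Sum over the broadcast axis (axis 1); the GPU reduction does the same job.
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      dy[r] += tmp_dy[r * cols + c];
    }
  }

  for (int r = 0; r < rows; ++r) {
    std::printf("dy[%d] = %.1f\n", r, dy[r]);  // prints 4.0 for each row
  }
  return 0;
}
```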
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这部分的代码根据Zjq9409的最新合入PR修改一下 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done。 |
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  const auto& dev_ctx =
      ctx.template device_context<platform::CUDADeviceContext>();
  framework::Tensor tmp_dx;
  framework::Tensor tmp_dy;
  tmp_dx.mutable_data<T>(dout->dims(), ctx.GetPlace());
  tmp_dy.mutable_data<T>(dout->dims(), ctx.GetPlace());

  if (dx != nullptr && dy != nullptr) {
    dx->mutable_data<T>(ctx.GetPlace());
    dy->mutable_data<T>(ctx.GetPlace());
    std::vector<const framework::Tensor*> ins = {dout, y, x};
    std::vector<framework::Tensor*> outs;
    if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
      outs = {dx, dy};
    } else if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
      outs = {&tmp_dx, dy};
    } else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
      outs = {dx, &tmp_dy};
    } else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
      outs = {&tmp_dx, &tmp_dy};
    }
    auto functor = MulGradXYFunctor<T, T>();
    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T,
                                decltype(functor), 2>(dev_ctx, ins, &outs, axis,
                                                      functor);
    if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
      ReduceWrapper<T>(dev_ctx, axis, x, out, &tmp_dx, dx);
    } else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
      ReduceWrapper<T>(dev_ctx, axis, y, out, &tmp_dy, dy);
    } else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
      ReduceWrapper<T>(dev_ctx, axis, x, out, &tmp_dx, dx);
      ReduceWrapper<T>(dev_ctx, axis, y, out, &tmp_dy, dy);
    }

  } else if (dx != nullptr && dy == nullptr) {
    dx->mutable_data<T>(ctx.GetPlace());
    std::vector<const framework::Tensor*> ins = {dout, y};
    std::vector<framework::Tensor*> outs;
    if (dx->dims() != dout->dims()) {
      outs = {&tmp_dx};
    } else {
      outs = {dx};
    }

    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
        dev_ctx, ins, &outs, axis, MulGradFunctor<T>());
    if (dx->dims() != dout->dims()) {
      ReduceWrapper<T>(dev_ctx, axis, x, out, &tmp_dx, dx);
    }
  } else if (dx == nullptr && dy != nullptr) {
    dy->mutable_data<T>(ctx.GetPlace());
    std::vector<const framework::Tensor*> ins = {dout, x};
    std::vector<framework::Tensor*> outs;
    if (dy->dims() != dout->dims()) {
      outs = {&tmp_dy};
    } else {
      outs = {dy};
    }

    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
        dev_ctx, ins, &outs, axis, MulGradFunctor<T>());
    if (dy->dims() != dout->dims()) {
      ReduceWrapper<T>(dev_ctx, axis, y, out, &tmp_dy, dy);
    }
  }
}
/*
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_mul_grad(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  // dx
  if (dx != nullptr) {
    if (dx->dims() == dout->dims()) {
      // dx = dout * y
      ElementwiseComputeEx<MulGradFunctor<T>, DeviceContext, T>(
          ctx, dout, y, axis, MulGradFunctor<T>(), dx);

    } else {
      // For inplace strategy, dx will be stored in addr of dout, which makes
      // the result of dy wrong.
      if (dx->IsSharedBufferWith(*dout)) {
        dx->clear();
        dx->mutable_data<T>(x->dims(), ctx.GetPlace());
      }
      std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
      gpuStream_t stream = ctx.cuda_device_context().stream();

      framework::Tensor dx_tmp;
Reviewer: Suggest renaming this variable.
Author: Done. Renamed it to dx_origin_dims, meaning the dx result before the reduce.
      dx_tmp.Resize(dout->dims());
      ElementwiseComputeEx<MulGradFunctor<T>, DeviceContext, T>(
          ctx, dout, y, axis, MulGradFunctor<T>(), &dx_tmp);
      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
          dx_tmp, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
    }
  }
  // dy
  if (dy != nullptr) {
    if (dy->dims() == dout->dims()) {
      // dy = dout * x
      ElementwiseComputeEx<MulGradFunctor<T>, DeviceContext, T>(
          ctx, dout, x, axis, MulGradFunctor<T>(), dy);
    } else {
      std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
      gpuStream_t stream = ctx.cuda_device_context().stream();

      framework::Tensor dy_tmp;
      dy_tmp.Resize(dout->dims());
      ElementwiseComputeEx<MulGradFunctor<T>, DeviceContext, T>(
          ctx, dout, x, axis, MulGradFunctor<T>(), &dy_tmp);
      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
          dy_tmp, dy, kps::IdentityFunctor<T>(), reduce_dims, stream);
    }
  }
}
*/
Reviewer: Remove the dead commented-out code.
Author: Done.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
@@ -170,6 +170,20 @@ struct MulGradDY<paddle::platform::complex<T>> {
  }
};

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
default_elementwise_mul_grad(const framework::ExecutionContext& ctx,
Reviewer: default_elementwise_mul_grad and elementwise_mul_grad contain duplicated code.
Author: Done.
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
      ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>());
}
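In the simplest case, where x, y, and dout all share one shape, ElemwiseGradCompute with MulGradDX/MulGradDY reduces to plain elementwise products. A minimal standalone sketch with hypothetical data (not the Paddle kernel itself):

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical same-shape inputs, flattened to 1-D.
  std::vector<float> x    = {1.0f, 2.0f, 3.0f};
  std::vector<float> y    = {4.0f, 5.0f, 6.0f};
  std::vector<float> dout = {0.1f, 0.2f, 0.3f};
  std::vector<float> dx(x.size()), dy(y.size());

  for (size_t i = 0; i < x.size(); ++i) {
    dx[i] = dout[i] * y[i];  // MulGradDX: dx = dout * y
    dy[i] = dout[i] * x[i];  // MulGradDY: dy = dout * x
  }

  for (size_t i = 0; i < x.size(); ++i) {
    std::printf("dx[%zu]=%.2f dy[%zu]=%.2f\n", i, dx[i], i, dy[i]);
  }
  return 0;
}
```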
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
@@ -178,9 +192,7 @@ elementwise_mul_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
      ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>());
  default_elementwise_mul_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -193,6 +205,16 @@ elementwise_mul_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy);
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_mul_grad(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy);
#endif

template <typename DeviceContext, typename T>
@@ -208,13 +230,12 @@ class ElementwiseMulGradKernel : public ElemwiseGradKernel<T> {
    auto* out = dout;  // out is not necessary
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
    int axis = ctx.Attr<int>("axis");

    if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
      elementwise_mul_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
    } else {
      ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
          ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(),
          MulGradDY<T>());
      default_elementwise_mul_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
                                                     dy);
    }
  }
};
Reviewer: The SimpleElemwiseMulGradCUDAKernel function code can be removed.
Author: Done.