optimize elementwise_mul_grad using new interfaces #37728

Merged
merged 34 commits on Jan 5, 2022
Changes from 29 commits
34 commits
8f532b0
Merge pull request #1 from PaddlePaddle/develop
AshburnLee Sep 8, 2020
5b5804d
Merge pull request #2 from PaddlePaddle/develop
AshburnLee Sep 17, 2020
cee2470
Merge pull request #3 from PaddlePaddle/develop
AshburnLee Sep 30, 2020
5be3a45
Merge pull request #4 from PaddlePaddle/develop
AshburnLee Oct 13, 2020
a1d92b7
Merge pull request #5 from PaddlePaddle/develop
AshburnLee Oct 20, 2020
e674a5d
Merge pull request #6 from PaddlePaddle/develop
AshburnLee Nov 15, 2020
855d00b
Merge pull request #7 from PaddlePaddle/develop
AshburnLee Nov 18, 2020
7cb2c97
Merge pull request #8 from PaddlePaddle/develop
AshburnLee Mar 31, 2021
db9fc91
Merge pull request #9 from PaddlePaddle/develop
AshburnLee Apr 7, 2021
c7b68c8
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Apr 26, 2021
0fd630e
Merge branch 'PaddlePaddle:develop' into develop
AshburnLee Aug 16, 2021
4bbb33b
Merge branch 'PaddlePaddle:develop' into develop
AshburnLee Sep 28, 2021
30a1a89
Merge branch 'PaddlePaddle:develop' into develop
AshburnLee Nov 22, 2021
ae1d4ba
init commit: new elem_mul_grad
AshburnLee Nov 30, 2021
e06dd3c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Nov 30, 2021
7bee4f5
add template specialization for complex in multiply
AshburnLee Dec 1, 2021
30c53ff
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Dec 1, 2021
363daf2
reply to review comments
AshburnLee Dec 3, 2021
56b3ec6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Dec 3, 2021
e1e6ef4
correct dx and dy computation when T is complex
AshburnLee Dec 6, 2021
0965326
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Dec 6, 2021
2c084e8
reply to review comments
AshburnLee Dec 13, 2021
e41f374
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Dec 13, 2021
25f91d3
update to new ReduceFunctor
AshburnLee Dec 20, 2021
fec5b3e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Dec 20, 2021
9b1507e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Dec 20, 2021
5cf6b5d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Dec 29, 2021
ef592c1
mul-output broadcast
AshburnLee Dec 29, 2021
c7a6037
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Dec 29, 2021
99a9324
call functions
AshburnLee Jan 5, 2022
2951afc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 5, 2022
8fdd6d1
call functions with comments
AshburnLee Jan 5, 2022
7b663f4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 5, 2022
8b19aaf
remove comments
AshburnLee Jan 5, 2022
14 changes: 14 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_functor.h
@@ -14,6 +14,7 @@ limitations under the License. */

#pragma once

#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"
@@ -113,6 +114,19 @@ struct MinFunctor {
}
};

template <typename T>
struct MulGradFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
};
template <typename T>
struct MulGradFunctor<paddle::platform::complex<T>> {
inline HOSTDEVICE paddle::platform::complex<T> operator()(
const paddle::platform::complex<T>& a,
const paddle::platform::complex<T>& b) const {
paddle::platform::complex<T> b_conj(b.real, -b.imag);
return a * b_conj;
}
};
// Fmax
template <typename T>
struct FMaxFunctor {
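Why the conjugate? For an elementwise product out = x * y with complex entries, reverse-mode autograd under the conjugate (Wirtinger) convention multiplies the upstream gradient by the conjugate of the other operand, which is exactly why MulGradFunctor<complex<T>> builds b_conj instead of using b. As a worked identity, stated here for context rather than quoted from the PR:

\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \mathrm{out}} \odot \bar{y},
\qquad
\frac{\partial L}{\partial y} = \frac{\partial L}{\partial \mathrm{out}} \odot \bar{x}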
176 changes: 176 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_mul_op.cu
@@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"

@@ -113,6 +114,181 @@ __global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex<double>>(
}
Contributor:
The SimpleElemwiseMulGradCUDAKernel function code can be removed.

Contributor Author:
done

}

template <typename InT, typename OutT>
struct MulGradXYFunctor {
inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(const InT& a,
const InT& b,
const InT& c) {
paddle::framework::Array<OutT, 2> outs;
// dx = dout * y
// dy = dout * x
outs[0] = a * b;
outs[1] = a * c;
return outs;
}
};

template <typename T>
using complex = paddle::platform::complex<T>;

template <typename InT, typename OutT>
struct MulGradXYFunctor<complex<InT>, complex<OutT>> {
inline HOSTDEVICE paddle::framework::Array<complex<OutT>, 2> operator()(
const complex<InT>& a, const complex<InT>& b, const complex<InT>& c) {
paddle::framework::Array<complex<OutT>, 2> outs;
// dx = dout * y
// dy = dout * x
complex<InT> b_conj(b.real, -b.imag);
complex<InT> c_conj(c.real, -c.imag);
outs[0] = a * b_conj;
outs[1] = a * c_conj;
return outs;
}
};
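The fused functor above is the point of the optimization: both gradients share a single load of dout, so one kernel launch and one pass over memory produce dx and dy together. A minimal host-side sketch of the per-element contract (plain C++ with a hypothetical MulGradXY name, no Paddle dependencies):

#include <array>
#include <cstdio>
#include <vector>

// Per element: one dout value in, both gradients out.
template <typename T>
std::array<T, 2> MulGradXY(const T& dout, const T& y, const T& x) {
  return {dout * y /* dx */, dout * x /* dy */};
}

int main() {
  std::vector<float> dout = {1.f, 2.f}, x = {3.f, 4.f}, y = {5.f, 6.f};
  for (size_t i = 0; i < dout.size(); ++i) {
    std::array<float, 2> g = MulGradXY(dout[i], y[i], x[i]);
    std::printf("dx=%g dy=%g\n", g[0], g[1]);  // dx = dout*y, dy = dout*x
  }
  return 0;
}

Note that the argument order mirrors ins = {dout, y, x} in the caller below.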

template <typename T>
void ReduceWrapper(const platform::CUDADeviceContext& dev_ctx, int axis,
const framework::Tensor* in, const framework::Tensor* out,
framework::Tensor* src, framework::Tensor* dst) {
std::vector<int> reduce_dims = GetReduceDim(in->dims(), out->dims(), axis);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*src, dst, kps::IdentityFunctor<T>(), reduce_dims, dev_ctx.stream());
}
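ReduceWrapper exists for the broadcast case: if an input had extent 1 along some axis in the forward multiply, its gradient must be summed over that axis to shrink back to the input's shape. The real axis selection is GetReduceDim in reduce_op.cu.h; the sketch below is only an illustrative approximation, assuming both shapes are already aligned to a common rank (axis resolved):

#include <cstdio>
#include <vector>

// Collect the axes along which `in` was broadcast against `out`.
std::vector<int> ReduceDims(const std::vector<int>& in,
                            const std::vector<int>& out) {
  std::vector<int> dims;
  for (size_t i = 0; i < out.size(); ++i) {
    if (in[i] == 1 && out[i] > 1) dims.push_back(static_cast<int>(i));
  }
  return dims;
}

int main() {
  // x: [3, 1, 4] broadcast against out: [3, 2, 4] -> reduce over axis 1.
  for (int d : ReduceDims({3, 1, 4}, {3, 2, 4})) std::printf("axis %d\n", d);
  return 0;
}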

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_mul_grad(const framework::ExecutionContext& ctx,
Contributor:
Please update this part of the code to match Zjq9409's latest merged PR.

Contributor Author:
done.

const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
framework::Tensor tmp_dx;
framework::Tensor tmp_dy;
tmp_dx.mutable_data<T>(dout->dims(), ctx.GetPlace());
tmp_dy.mutable_data<T>(dout->dims(), ctx.GetPlace());

if (dx != nullptr && dy != nullptr) {
dx->mutable_data<T>(ctx.GetPlace());
dy->mutable_data<T>(ctx.GetPlace());
std::vector<const framework::Tensor*> ins = {dout, y, x};
std::vector<framework::Tensor*> outs;
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
outs = {dx, dy};
} else if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
outs = {&tmp_dx, dy};
} else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
outs = {dx, &tmp_dy};
} else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
outs = {&tmp_dx, &tmp_dy};
}
auto functor = MulGradXYFunctor<T, T>();
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T,
decltype(functor), 2>(dev_ctx, ins, &outs, axis,
functor);
if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
ReduceWrapper<T>(dev_ctx, axis, x, out, &tmp_dx, dx);
} else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
ReduceWrapper<T>(dev_ctx, axis, y, out, &tmp_dy, dy);
} else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
ReduceWrapper<T>(dev_ctx, axis, x, out, &tmp_dx, dx);
ReduceWrapper<T>(dev_ctx, axis, y, out, &tmp_dy, dy);
}

} else if (dx != nullptr && dy == nullptr) {
dx->mutable_data<T>(ctx.GetPlace());
std::vector<const framework::Tensor*> ins = {dout, y};
std::vector<framework::Tensor*> outs;
if (dx->dims() != dout->dims()) {
outs = {&tmp_dx};
} else {
outs = {dx};
}

LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, axis, MulGradFunctor<T>());
if (dx->dims() != dout->dims()) {
ReduceWrapper<T>(dev_ctx, axis, x, out, &tmp_dx, dx);
}
} else if (dx == nullptr && dy != nullptr) {
dy->mutable_data<T>(ctx.GetPlace());
std::vector<const framework::Tensor*> ins = {dout, x};
std::vector<framework::Tensor*> outs;
if (dy->dims() != dout->dims()) {
outs = {&tmp_dy};
} else {
outs = {dy};
}

LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, axis, MulGradFunctor<T>());
if (dy->dims() != dout->dims()) {
ReduceWrapper<T>(dev_ctx, axis, y, out, &tmp_dy, dy);
}
}
}
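The four-way branching above collapses to one rule per gradient, summarized as a comment sketch (a condensed reading, not the PR's literal code):

// For each requested gradient g (dx or dy) with partner input p
// (p = y for dx, p = x for dy):
//   1. The elementwise pass writes dout * p straight into g when
//      g->dims() == dout->dims(); otherwise it writes into a tmp
//      buffer shaped like dout.
//   2. If a tmp buffer was used, ReduceWrapper then sums it over the
//      broadcast axes to produce g at its own shape.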

/*
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_mul_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
// dx
if (dx != nullptr) {
if (dx->dims() == dout->dims()) {
// dx = dout * y
ElementwiseComputeEx<MulGradFunctor<T>, DeviceContext, T>(
ctx, dout, y, axis, MulGradFunctor<T>(), dx);

} else {
// For inplace strategy, dx will be stored in addr of dout, which makes
// the result of dy wrong.
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), ctx.GetPlace());
}
std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();

framework::Tensor dx_tmp;
Contributor:
Suggest renaming this variable.

Contributor Author:
done. Renamed to dx_origin_dims, which denotes the dx result before the reduce.

dx_tmp.Resize(dout->dims());
ElementwiseComputeEx<MulGradFunctor<T>, DeviceContext, T>(
ctx, dout, y, axis, MulGradFunctor<T>(), &dx_tmp);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dx_tmp, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
// dy
if (dy != nullptr) {
if (dy->dims() == dout->dims()) {
// dy = dout * x
ElementwiseComputeEx<MulGradFunctor<T>, DeviceContext, T>(
ctx, dout, x, axis, MulGradFunctor<T>(), dy);
} else {
std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();

framework::Tensor dy_tmp;
dy_tmp.Resize(dout->dims());
ElementwiseComputeEx<MulGradFunctor<T>, DeviceContext, T>(
ctx, dout, x, axis, MulGradFunctor<T>(), &dy_tmp);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dy_tmp, dy, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
}
*/
Contributor:
Remove the dead commented-out code.

Contributor Author:
done.


template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
35 changes: 28 additions & 7 deletions paddle/fluid/operators/elementwise/elementwise_mul_op.h
@@ -170,6 +170,20 @@ struct MulGradDY<paddle::platform::complex<T>> {
}
};

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
default_elementwise_mul_grad(const framework::ExecutionContext& ctx,
Contributor:
default_elementwise_mul_grad and elementwise_mul_grad duplicate each other's code.

Contributor Author:
done.

const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>());
}
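For reference, the MulGradDX/MulGradDY functors consumed by ElemwiseGradCompute are defined earlier in this header; modulo the complex specializations (which conjugate the partner operand, as in elementwise_functor.h above), they amount to this sketch:

// HOSTDEVICE comes from paddle/fluid/platform/hostdevice.h.
template <typename T>
struct MulGradDX {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; }
};

template <typename T>
struct MulGradDY {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; }
};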

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
@@ -178,9 +192,7 @@ elementwise_mul_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>());
default_elementwise_mul_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -193,6 +205,16 @@ elementwise_mul_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy);

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_mul_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy);
#endif

template <typename DeviceContext, typename T>
@@ -208,13 +230,12 @@ class ElementwiseMulGradKernel : public ElemwiseGradKernel<T> {
auto* out = dout; // out is not necessary
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");

if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_mul_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} else {
ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(),
MulGradDY<T>());
default_elementwise_mul_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
dy);
}
}
};
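To make the dispatch concrete, a worked shape example (hypothetical shapes, not taken from the PR):

// x: [2, 3, 4], y: [3, 4], broadcast multiply -> out and dout: [2, 3, 4]
//   dx->dims() == [2, 3, 4] != dy->dims() == [3, 4], so Compute takes the
//   default_elementwise_mul_grad path:
//     dx: dims match dout -> the fused kernel writes dx directly;
//     dy: dims differ     -> written to tmp_dy at [2, 3, 4], then
//                            ReduceWrapper sums over axis 0 down to [3, 4].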