optimize elementwise_mul_grad using new interfaces #37728
Changes from 19 commits
@@ -14,6 +14,8 @@ limitations under the License. */

 #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
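The two added reduce_ops headers provide TensorReduceFunctorImpl and CustomSum, which the new gradient path below uses to sum the broadcast dimensions back down.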
@@ -114,6 +116,54 @@ __global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex<double>>(
 }

Review comment: The SimpleElemwiseMulGradCUDAKernel function code can be removed.
Author reply: done

 }
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_mul_grad(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  // dx
  if (dx != nullptr) {
    if (dx->dims() == dout->dims()) {
      // dx = dout * y
      default_elementwise_mul<DeviceContext, T>(ctx, dout, y, dx);
    } else {
      // For the inplace strategy, dx would be stored at the address of dout,
      // which would make the result of dy wrong.
      if (dx->IsSharedBufferWith(*dout)) {
        dx->clear();
        dx->mutable_data<T>(x->dims(), ctx.GetPlace());
      }
      std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
      gpuStream_t stream = ctx.cuda_device_context().stream();

      framework::Tensor dx_tmp;

Review comment: Suggest a clearer name for this variable.
Author reply: done. Changed to dx_origin_dims, meaning the dx result before the reduce.

      dx_tmp.Resize(dout->dims());
      default_elementwise_mul<DeviceContext, T>(ctx, dout, y, &dx_tmp);
      TensorReduceFunctorImpl<T, T, CustomSum>(dx_tmp, dx, reduce_dims, stream);
    }
  }
  // dy
  if (dy != nullptr) {
    if (dy->dims() == dout->dims()) {
      // dy = dout * x
      default_elementwise_mul<DeviceContext, T>(ctx, dout, x, dy);
    } else {
      std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
      gpuStream_t stream = ctx.cuda_device_context().stream();

      framework::Tensor dy_tmp;
      dy_tmp.Resize(dout->dims());
      default_elementwise_mul<DeviceContext, T>(ctx, dout, x, &dy_tmp);
      TensorReduceFunctorImpl<T, T, CustomSum>(dy_tmp, dy, reduce_dims, stream);
    }
  }
}
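For intuition on the reduce step: when x or y was broadcast in the forward pass, dout * y (resp. dout * x) has out's shape, so the gradient must be summed over the broadcast axes before it fits dx (resp. dy). Below is a minimal standalone sketch of the dimension bookkeeping that GetReduceDim is assumed to perform; the name GetReduceDimSketch and the plain std::vector shapes are illustrative, not Paddle's actual signature.

#include <cassert>
#include <iostream>
#include <vector>

// Hypothetical sketch of the logic GetReduceDim is assumed to implement:
// list every axis of out_dims along which in_dims was broadcast, so the
// gradient can be summed over exactly those axes.
std::vector<int> GetReduceDimSketch(const std::vector<int>& in_dims,
                                    const std::vector<int>& out_dims,
                                    int axis) {
  // axis aligns in_dims inside out_dims when it has fewer dimensions;
  // axis == -1 means "align with the trailing dimensions".
  int offset = (axis == -1)
                   ? static_cast<int>(out_dims.size() - in_dims.size())
                   : axis;
  std::vector<int> reduce_dims;
  for (int i = 0; i < static_cast<int>(out_dims.size()); ++i) {
    int j = i - offset;
    if (j < 0 || j >= static_cast<int>(in_dims.size())) {
      reduce_dims.push_back(i);  // axis absent from in_dims: reduce it
    } else if (in_dims[j] == 1 && out_dims[i] != 1) {
      reduce_dims.push_back(i);  // axis broadcast from size 1: reduce it
    }
  }
  return reduce_dims;
}

int main() {
  // y of shape [4] multiplied into x of shape [2, 3, 4]:
  // dy = reduce_sum(dout * x, dims = {0, 1})
  assert((GetReduceDimSketch({4}, {2, 3, 4}, -1) == std::vector<int>{0, 1}));
  std::cout << "dy is reduced over dims 0 and 1\n";
  return 0;
}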
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
@@ -184,6 +184,20 @@ elementwise_mul_grad(const framework::ExecutionContext& ctx,
       ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>());
 }
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
default_elementwise_mul_grad(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
      ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>());
}

Review comment: default_elementwise_mul_grad and elementwise_mul_grad contain duplicated code.
Author reply: done.
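For reference, the MulGradDX / MulGradDY functors encode the local gradients of z = x * y (dz/dx = y, dz/dy = x). A sketch of the shape such functors take — the Sketch names and the plain (non-HOSTDEVICE) signatures are illustrative, not copied from Paddle:

#include <iostream>

// Per-element gradient functors of z = x * y, in the four-argument
// (x, y, out, dout) form that ElemwiseGradCompute composes over tensors.
template <typename T>
struct MulGradDXSketch {
  T operator()(T /*x*/, T y, T /*out*/, T dout) const { return dout * y; }
};

template <typename T>
struct MulGradDYSketch {
  T operator()(T x, T /*y*/, T /*out*/, T dout) const { return dout * x; }
};

int main() {
  // x = 3, y = 5, dout = 2  =>  dx = dout * y = 10, dy = dout * x = 6
  std::cout << MulGradDXSketch<int>()(3, 5, 15, 2) << " "
            << MulGradDYSketch<int>()(3, 5, 15, 2) << "\n";
  return 0;
}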
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// cuda definition
template <typename DeviceContext, typename T>
@@ -194,6 +208,16 @@ elementwise_mul_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy);

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_mul_grad(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy);
#endif
template <typename DeviceContext, typename T>

@@ -209,13 +233,12 @@ class ElementwiseMulGradKernel : public ElemwiseGradKernel<T> {
     auto* out = dout;  // out is not necessary
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    int axis = ctx.Attr<int>("axis");

     if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
       elementwise_mul_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
     } else {
-      ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
-          ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(),
-          MulGradDY<T>());
+      default_elementwise_mul_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
+                                                     dy);
     }
   }
 };
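A note on the dispatch above: the fused elementwise_mul_grad fast path runs only when both gradients are requested and dx and dy share dout's shape; any broadcast case (e.g. x of shape [2, 3, 4] with y of shape [4]) or a skipped gradient falls through to default_elementwise_mul_grad, which multiplies at dout's shape and then reduces, as sketched earlier.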
Review comment: This function can be removed. For the complex case you can build y_conj(y.real, -y.imag) from the original y(y.real, y.imag) and just pass it to the multiply.
Author reply: Agreed, this function does not belong here; it conflicts with the semantics of MulFunctor. Fixed.
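For context on the conjugate trick the reviewer describes: for complex inputs, the gradients of z = x * y are dx = dout * conj(y) and dy = dout * conj(x), so materializing y_conj once lets the ordinary elementwise multiply serve the complex path too. A minimal standalone sketch with std::complex (illustrative, not Paddle code):

#include <complex>
#include <iostream>
#include <vector>

int main() {
  using C = std::complex<float>;
  std::vector<C> y    = {{1.f, 2.f}, {3.f, -4.f}};
  std::vector<C> dout = {{0.5f, 0.f}, {0.f, 1.f}};

  // Build y_conj(y.real, -y.imag) once; dx = dout * y_conj then reuses
  // the same elementwise multiply as the real-valued path.
  std::vector<C> y_conj(y.size());
  for (size_t i = 0; i < y.size(); ++i) y_conj[i] = std::conj(y[i]);

  for (size_t i = 0; i < y.size(); ++i) {
    std::cout << "dx[" << i << "] = " << dout[i] * y_conj[i] << "\n";
  }
  return 0;
}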