Skip to content

Commit

Permalink
move paddle/operators/math/compound_functors.h
Browse files Browse the repository at this point in the history
  • Loading branch information
chenfeiyu committed Feb 14, 2022
1 parent 8ae1aa9 commit fc086a0
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 58 deletions.
69 changes: 33 additions & 36 deletions paddle/fluid/operators/fused/fused_elemwise_activation_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/compound_functors.h"
#include "paddle/pten/kernels/funcs/compound_functors.h"
#include "paddle/pten/kernels/funcs/elementwise_functor.h"
#include "paddle/pten/kernels/funcs/functors.h"

Expand Down Expand Up @@ -54,22 +54,22 @@ static void RunBinaryCompoundFunctor(
// intermediate_out = Unary(Y)
// out = Binary(X, Unary(Y))
// In this case, the shape of intermediate_out and out are different.
paddle::operators::math::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
pten::funcs::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
compound_func(binary_functor, unary_functor);
int axis = ctx.Attr<int>("axis");
if (ctx.Attr<bool>("save_intermediate_out")) {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::BinaryCompoundFunctor<
T, BinaryFunctor, UnaryFunctor>,
true /*KeepIntermediateValue*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
FusedElemwiseAndActComputeEx<
DeviceContext, T,
pten::funcs::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>,
true /*KeepIntermediateValue*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
} else {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::BinaryCompoundFunctor<
T, BinaryFunctor, UnaryFunctor>,
false /*KeepIntermediateValue*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
FusedElemwiseAndActComputeEx<
DeviceContext, T,
pten::funcs::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>,
false /*KeepIntermediateValue*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
}
}
Expand All @@ -86,22 +86,22 @@ static void RunUnaryCompoundFunctors(
// In this case, the shape of intermediate_out and out are the same.
int axis = ctx.Attr<int>("axis");

paddle::operators::math::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
pten::funcs::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
compound_func(unary_functor, binary_functor);

if (ctx.Attr<bool>("save_intermediate_out")) {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::UnaryCompoundFunctor<
T, UnaryFunctor, BinaryFunctor>,
true /*KeepIntermediateValue*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
FusedElemwiseAndActComputeEx<
DeviceContext, T,
pten::funcs::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>,
true /*KeepIntermediateValue*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
} else {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::UnaryCompoundFunctor<
T, UnaryFunctor, BinaryFunctor>,
false /*KeepIntermediateValue*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
FusedElemwiseAndActComputeEx<
DeviceContext, T,
pten::funcs::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>,
false /*KeepIntermediateValue*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
}
}
Expand All @@ -121,13 +121,12 @@ static void RunBinaryCompoundGradFunctors(
int axis = ctx.Attr<int>("axis");

using BinaryCompoundDxFunctor =
paddle::operators::math::BinaryCompoundGradDxFunctor<T, BinaryGradFunctor,
UnaryFunctor>;
using BinaryCompoundDyFunctor =
paddle::operators::math::BinaryCompoundGradDyFunctor<
T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor, InPlace>;
pten::funcs::BinaryCompoundGradDxFunctor<T, BinaryGradFunctor,
UnaryFunctor>;
using BinaryCompoundDyFunctor = pten::funcs::BinaryCompoundGradDyFunctor<
T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor, InPlace>;
using BinaryCompoundDIntermedaiteOutFunctor =
paddle::operators::math::BinaryCompoundGradDIntermedaiteOutFunctor<
pten::funcs::BinaryCompoundGradDIntermedaiteOutFunctor<
T, BinaryGradFunctor, UnaryFunctor>;

if (in_intermediate_out) {
Expand Down Expand Up @@ -171,14 +170,12 @@ static void RunUnaryCompoundGradFunctors(
// Z = Unary(Binary(X, Y))
int axis = ctx.Attr<int>("axis");

using UnaryCompoundDxFunctor =
paddle::operators::math::UnaryCompoundGradDxFunctor<
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
using UnaryCompoundDyFunctor =
paddle::operators::math::UnaryCompoundGradDyFunctor<
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
using UnaryCompoundDxFunctor = pten::funcs::UnaryCompoundGradDxFunctor<
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
using UnaryCompoundDyFunctor = pten::funcs::UnaryCompoundGradDyFunctor<
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
using UnaryCompoundDIntermediateFunctor =
paddle::operators::math::UnaryCompoundGradDIntermediateFunctor<
pten::funcs::UnaryCompoundGradDIntermediateFunctor<
T, UnaryGradFunctor, BinaryFunctor, InPlace>;

if (in_intermediate_out) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ limitations under the License. */
#include <unordered_set>
#include <vector>

namespace paddle {
namespace operators {
namespace math {
namespace pten {
namespace funcs {

// Z = BinaryFunctor(X, UnaryFunctor(Y))
template <typename T, typename BinaryFunctor, typename UnaryFunctor>
Expand Down Expand Up @@ -69,8 +68,8 @@ struct BinaryCompoundGradDxFunctor {
return dout * d_binary_fun_.Dx(x, unary_fun_(y));
}

inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
inline HOSTDEVICE T
UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
return dout * d_binary_fun_.Dx(x, intermediate_out);
}

Expand All @@ -82,8 +81,11 @@ struct BinaryCompoundGradDxFunctor {
};

// Z = BinaryFunctor(X, UnaryFunctor(Y))
template <typename T, typename DBinaryFun, typename UnaryFun,
typename DUnaryFun, bool InPlace>
template <typename T,
typename DBinaryFun,
typename UnaryFun,
typename DUnaryFun,
bool InPlace>
struct BinaryCompoundGradDyFunctor {
BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun,
Expand All @@ -96,8 +98,8 @@ struct BinaryCompoundGradDyFunctor {
return dout * d_binary_fun_.Dy(x, unary_fun_(y)) * d_unary_fun_.UseX(y);
}

inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
inline HOSTDEVICE T
UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
if (InPlace) {
return dout * d_binary_fun_.Dy(x, intermediate_out) *
d_unary_fun_.UseOut(intermediate_out);
Expand All @@ -116,8 +118,11 @@ struct BinaryCompoundGradDyFunctor {
};

// Z = UnaryFunctor(BinaryFunctor(X, Y))
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool InPlace>
template <typename T,
typename DUnaryFun,
typename BinaryFun,
typename DBinaryFun,
bool InPlace>
struct UnaryCompoundGradDxFunctor {
UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
Expand All @@ -136,8 +141,8 @@ struct UnaryCompoundGradDxFunctor {
return base * d_binary_fun_.Dx(x, y);
}

inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
inline HOSTDEVICE T
UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
T base;
if (InPlace) {
base = dout * d_unary_fun_.UseOut(out);
Expand All @@ -156,8 +161,11 @@ struct UnaryCompoundGradDxFunctor {
};

// Z = UnaryFunctor(BinaryFunctor(X, Y))
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool InPlace>
template <typename T,
typename DUnaryFun,
typename BinaryFun,
typename DBinaryFun,
bool InPlace>
struct UnaryCompoundGradDyFunctor {
UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
Expand All @@ -176,8 +184,8 @@ struct UnaryCompoundGradDyFunctor {
return base * d_binary_fun_.Dy(x, y);
}

inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
inline HOSTDEVICE T
UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
T base;
if (InPlace) {
base = dout * d_unary_fun_.UseOut(out);
Expand Down Expand Up @@ -206,7 +214,9 @@ struct BinaryCompoundGradDIntermedaiteOutFunctor {
return dout * d_binary_fun_.Dy(x, unary_fun_(y));
}

inline HOSTDEVICE T UseIntermediateOut(T x, T intermediate_out, T out,
inline HOSTDEVICE T UseIntermediateOut(T x,
T intermediate_out,
T out,
T dout) {
return dout * d_binary_fun_.Dy(x, intermediate_out);
}
Expand All @@ -233,7 +243,9 @@ struct UnaryCompoundGradDIntermediateFunctor {
}
}

inline HOSTDEVICE T UseIntermediateOut(T x, T intermediate_out, T out,
inline HOSTDEVICE T UseIntermediateOut(T x,
T intermediate_out,
T out,
T dout) {
if (InPlace) {
return dout * d_unary_fun_.UseOut(out);
Expand All @@ -249,6 +261,5 @@ struct UnaryCompoundGradDIntermediateFunctor {
BinaryFun binary_fun_;
};

} // namespace math
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace pten

1 comment on commit fc086a0

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.