Skip to content

Commit

Permalink
[PTen]elementwise_sub kernel refactor (#37260)
Browse files Browse the repository at this point in the history
* elementwise_add kernel refactor

* fix compile bugs in elementwise_add refactor

* fix compile bugs when run in npu/xpu

* fix bugs when run unit test

* fix bugs when run ci-windows

* modify code as recommended

* code format adjust

* fix bugs when run ci

* fix compile bug when run in ci-windwos

* elementwise_sub refactor

* add PD_DLL_DECL for elementwise_sub

* fix bugs when compilei
  • Loading branch information
YuanRisheng authored Nov 18, 2021
1 parent 706a789 commit 36a9565
Show file tree
Hide file tree
Showing 16 changed files with 313 additions and 57 deletions.
6 changes: 6 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ class ElementwiseOp : public framework::OperatorWithKernel {
{"axis"}, {"Out"});
}
}
if (Type() == "elementwise_sub") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
return framework::KernelSignature("elementwise_sub", {"X", "Y"},
{"axis"}, {"Out"});
}
}
return framework::KernelSignature("None", {"X"}, {}, {"Out"});
}
};
Expand Down
27 changes: 0 additions & 27 deletions paddle/fluid/operators/elementwise/elementwise_sub_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,33 +41,6 @@ struct CPUPlace;
namespace paddle {
namespace operators {

template <typename T>
struct SameDimsElemwiseSub<
platform::CPUDeviceContext, T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
void operator()(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z) {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
blas.VSUB(x->numel(), x->data<T>(), y->data<T>(), z->data<T>());
}
};

template <typename T>
struct SameDimsElemwiseSub<
platform::CPUDeviceContext, T,
typename std::enable_if<!std::is_floating_point<T>::value>::type> {
void operator()(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z) {
auto eigen_x = framework::EigenVector<T>::Flatten(*x);
auto eigen_y = framework::EigenVector<T>::Flatten(*y);
auto eigen_z = framework::EigenVector<T>::Flatten(*z);
auto &place = *ctx.template device_context<platform::CPUDeviceContext>()
.eigen_device();
eigen_z.device(place) = eigen_x - eigen_y;
}
};
class ElementwiseSubOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "Sub"; }
Expand Down
16 changes: 0 additions & 16 deletions paddle/fluid/operators/elementwise/elementwise_sub_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,6 @@ namespace plat = paddle::platform;
namespace paddle {
namespace operators {

template <typename T>
class ElementwiseSubKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();

int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, SubFunctor<T>());
}
};

template <typename T>
static __global__ void SimpleElemwiseSubGradCUDAKernel(const T* dout,
int64_t size, T* dx,
Expand Down
26 changes: 12 additions & 14 deletions paddle/fluid/operators/elementwise/elementwise_sub_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@ limitations under the License. */

#pragma once

#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"

// only can include the headers in paddle/pten/include dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/math.h"
namespace paddle {
namespace operators {

Expand All @@ -37,13 +42,6 @@ void default_elementwise_sub(const framework::ExecutionContext& ctx,
}
}

template <typename DeviceContext, typename T, class Enable = void>
struct SameDimsElemwiseSub {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z);
};

template <typename DeviceContext, typename T>
class ElementwiseSubKernel : public framework::OpKernel<T> {
public:
Expand All @@ -53,13 +51,13 @@ class ElementwiseSubKernel : public framework::OpKernel<T> {
auto* z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());

auto dims_equal = x->dims() == y->dims();
if (dims_equal) {
SameDimsElemwiseSub<DeviceContext, T> same_dims_sub;
same_dims_sub(ctx, x, y, z);
} else {
default_elementwise_sub<DeviceContext, T>(ctx, x, y, z);
}
auto& dev_ctx = ctx.device_context<DeviceContext>();
int axis = ctx.Attr<int>("axis");
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
pten::ElementwiseSub<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis,
pt_z.get());
}
};

Expand Down
1 change: 1 addition & 0 deletions paddle/pten/api/include/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,6 @@ PD_DLL_DECL Tensor mean(const Tensor& x);

PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y);

PD_DLL_DECL Tensor subtract(const Tensor& x, const Tensor& y);
} // namespace experimental
} // namespace paddle
34 changes: 34 additions & 0 deletions paddle/pten/api/lib/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,40 @@ PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y) {
return out;
}

PD_DLL_DECL Tensor subtract(const Tensor& x, const Tensor& y) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"elementwise_sub", kernel_key);

// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto kernel_context = pten::KernelContext(dev_ctx);

// 3. Auto data transform
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
kernel_context.EmplaceBackInput(dense_x);
auto dense_y = std::dynamic_pointer_cast<pten::DenseTensor>(y.impl());
kernel_context.EmplaceBackInput(dense_y);
kernel_context.EmplaceBackAttr(-1);

// 4. InferShape
auto out_meta = ElementwiseInferShape(dense_x->meta(), dense_y->meta(), -1);

// 5. Prepare outputs
Tensor out;
const auto allocator = std::make_shared<DefaultAllocator>(
pten::TransToFluidPlace(kernel_key.backend()));
auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
kernel_context.EmplaceBackOutput(dense_out);
out.set_impl(dense_out);

// 6. Call kernel
kernel(&kernel_context);

return out;
}
} // namespace experimental
} // namespace paddle

Expand Down
15 changes: 15 additions & 0 deletions paddle/pten/include/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,19 @@ DenseTensor ElementwiseAdd(const ContextT& dev_ctx,
ElementwiseAdd<T>(dev_ctx, x, y, axis, &dense_out);
return dense_out;
}

template <typename T, typename ContextT>
DenseTensor Subtract(const ContextT& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis) {
auto out_meta = ElementwiseInferShape(x.meta(), y.meta(), axis);
const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace());
pten::DenseTensor dense_out(allocator, out_meta);
ElementwiseSub<T>(dev_ctx, x, y, axis, &dense_out);
return dense_out;
}

} // namespace pten
33 changes: 33 additions & 0 deletions paddle/pten/kernels/cpu/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,29 @@ void ElementwiseAdd(const CPUContext& dev_ctx,
}
}
}

template <typename T>
void ElementwiseSub(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
if (x.dims() == y.dims()) {
SameDimsElementwiseCompute<general::SameDimsSubFunctor<CPUContext, T>>()(
dev_ctx, x, y, out);
} else {
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseCompute<general::SubFunctor<T>, T>(
dev_ctx, x, y, axis, general::SubFunctor<T>(), out);
} else {
ElementwiseCompute<general::InverseSubFunctor<T>, T>(
dev_ctx, x, y, axis, general::InverseSubFunctor<T>(), out);
}
}
}

} // namespace pten

// TODO(chenweihang): replace by better impl
Expand Down Expand Up @@ -135,3 +158,13 @@ PT_REGISTER_KERNEL("elementwise_add",
int64_t,
complex64,
complex128) {}
PT_REGISTER_KERNEL("elementwise_sub",
CPU,
ANY,
pten::ElementwiseSub,
float,
double,
int,
int64_t,
complex64,
complex128) {}
7 changes: 7 additions & 0 deletions paddle/pten/kernels/cpu/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,11 @@ void ElementwiseAdd(const CPUContext& dev_ctx,
int axis,
DenseTensor* out);

template <typename T>
void ElementwiseSub(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);

} // namespace pten
26 changes: 26 additions & 0 deletions paddle/pten/kernels/cuda/math.cu
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,21 @@ void ElementwiseAdd(const CUDAContext& dev_ctx,
dev_ctx, inputs, &outputs, axis, general::AddFunctor<T>());
}

template <typename T>
void ElementwiseSub(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
std::vector<const DenseTensor*> inputs;
std::vector<DenseTensor*> outputs;
inputs.emplace_back(&x);
inputs.emplace_back(&y);
outputs.emplace_back(out);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, inputs, &outputs, axis, general::SubFunctor<T>());
}

} // namespace pten

// TODO(chenweihang): replace by better impl
Expand Down Expand Up @@ -187,3 +202,14 @@ PT_REGISTER_KERNEL("elementwise_add",
float16,
complex64,
complex128) {}
PT_REGISTER_KERNEL("elementwise_sub",
CUDA,
ANY,
pten::ElementwiseSub,
float,
double,
int,
int64_t,
float16,
complex64,
complex128) {}
7 changes: 7 additions & 0 deletions paddle/pten/kernels/cuda/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ void ElementwiseAdd(const CUDAContext& dev_ctx,
int axis,
DenseTensor* out);

template <typename T>
void ElementwiseSub(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);

} // namespace pten

#endif
9 changes: 9 additions & 0 deletions paddle/pten/kernels/functions/blas/elementwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,14 @@ void ElementwiseAdd(const DevCtx& dev_ctx,
blas.VADD(x.numel(), x.data<T>(), y.data<T>(), out->mutable_data<T>());
}

template <typename DevCtx, typename T>
void ElementwiseSub(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
blas.VSUB(x.numel(), x.data<T>(), y.data<T>(), out->mutable_data<T>());
}

} // namespace blas
} // namespace pten
12 changes: 12 additions & 0 deletions paddle/pten/kernels/functions/eigen/elementwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,17 @@ void ElementwiseAdd(const DevCtx& dev_ctx,
eigen_z.device(place) = eigen_x + eigen_y;
}

template <typename DevCtx, typename T>
void ElementwiseSub(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
auto eigen_x = pten::EigenVector<T>::Flatten(x);
auto eigen_y = pten::EigenVector<T>::Flatten(y);
auto eigen_z = pten::EigenVector<T>::Flatten(*out);
auto& place = *dev_ctx.eigen_device();
eigen_z.device(place) = eigen_x - eigen_y;
}

} // namespace eigen
} // namespace pten
44 changes: 44 additions & 0 deletions paddle/pten/kernels/functions/general/elementwise_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,49 @@ struct InverseAddFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b + a; }
};

// Subtract
template <typename DevCtx, typename T, class Enable = void>
struct SameDimsSubFunctor {
void operator()(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* z);
};

template <typename DevCtx, typename T>
struct SameDimsSubFunctor<
DevCtx,
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
void operator()(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* z) {
blas::ElementwiseSub<DevCtx, T>(dev_ctx, x, y, z);
}
};

template <typename DevCtx, typename T>
struct SameDimsSubFunctor<
DevCtx,
T,
typename std::enable_if<!std::is_floating_point<T>::value>::type> {
void operator()(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* z) {
eigen::ElementwiseSub<DevCtx, T>(dev_ctx, x, y, z);
}
};

template <typename T>
struct SubFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a - b; }
};
template <typename T>
struct InverseSubFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b - a; }
};

} // namespace general
} // namespace pten
Loading

0 comments on commit 36a9565

Please sign in to comment.