Skip to content

Commit

Permalink
[X86] slice support int32, int64; add elementwise_div (#5898)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang authored Apr 15, 2021
1 parent bf89e00 commit cd98943
Show file tree
Hide file tree
Showing 11 changed files with 281 additions and 733 deletions.
1 change: 0 additions & 1 deletion lite/kernels/x86/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ add_kernel(interpolate_compute_x86 X86 basic SRCS interpolate_compute.cc DEPS ${

lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
lite_cc_test(test_slice_compute_x86 SRCS slice_compute_test.cc DEPS slice_compute_x86)
lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86)
lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86)
lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86)
Expand Down
33 changes: 33 additions & 0 deletions lite/kernels/x86/elementwise_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,39 @@ REGISTER_LITE_KERNEL(elementwise_mul,
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.Finalize();

// Registers the float32 variant of the x86 elementwise_div kernel under the
// alias "def": inputs X and Y and output Out are all float tensors on kX86,
// dispatched to ElementwiseDivCompute<float>.
REGISTER_LITE_KERNEL(elementwise_div,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::ElementwiseDivCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))})
.Finalize();

// int32 variant of elementwise_div, registered under the alias "int32".
// NOTE(review): the kernel is tagged kFloat while X/Y/Out are bound as
// kInt32 tensors — presumably the runtime selects kernels by the bound
// tensor precision rather than the kernel tag; confirm against the
// framework's kernel-registry conventions.
REGISTER_LITE_KERNEL(elementwise_div,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::ElementwiseDivCompute<int>,
int32)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.Finalize();

// int64 variant of elementwise_div, registered under the alias "int64".
// Same pattern as the int32 registration above: kFloat kernel tag with
// kInt64 tensor bindings on all ports.
REGISTER_LITE_KERNEL(elementwise_div,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::ElementwiseDivCompute<int64_t>,
int64)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.Finalize();

REGISTER_LITE_KERNEL(
elementwise_floordiv,
kX86,
Expand Down
39 changes: 31 additions & 8 deletions lite/kernels/x86/elementwise_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ struct MulFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
};

// Binary functor computing a / b, used by ElementwiseDivCompute via
// ElementwiseComputeEx. For integral T this is C++ truncating division,
// and b == 0 is undefined behavior — callers are responsible for
// guaranteeing a nonzero divisor.
template <typename T>
struct DivFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a / b; }
};

template <typename T>
struct FloorDivFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
Expand All @@ -64,6 +69,24 @@ struct MinFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a < b ? a : b; }
};

// x86 kernel computing Out = X + Y element-wise with axis-based
// broadcasting, delegating the actual loop to ElementwiseComputeEx
// parameterized with AddFunctor<T>.
template <typename T>
class ElementwiseAddCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseParam;
// Reads X, Y and axis from the bound ElementwiseParam, allocates Out's
// buffer as T, and runs the broadcast-aware element-wise addition.
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
// Ensure the output tensor owns a T-typed buffer before writing.
param.Out->template mutable_data<T>();
paddle::lite::kernels::x86::ElementwiseComputeEx<AddFunctor<T>,
lite::TargetType::kX86,
T>(
context, param.X, param.Y, param.axis, AddFunctor<T>(), param.Out);
}

virtual ~ElementwiseAddCompute() = default;
};

template <typename T>
class ElementwiseSubCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
Expand All @@ -85,39 +108,39 @@ class ElementwiseSubCompute
};

template <typename T>
class ElementwiseAddCompute
class ElementwiseMulCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
param.Out->template mutable_data<T>();
paddle::lite::kernels::x86::ElementwiseComputeEx<AddFunctor<T>,
paddle::lite::kernels::x86::ElementwiseComputeEx<MulFunctor<T>,
lite::TargetType::kX86,
T>(
context, param.X, param.Y, param.axis, AddFunctor<T>(), param.Out);
context, param.X, param.Y, param.axis, MulFunctor<T>(), param.Out);
}

virtual ~ElementwiseAddCompute() = default;
virtual ~ElementwiseMulCompute() = default;
};

template <typename T>
class ElementwiseMulCompute
class ElementwiseDivCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
param.Out->template mutable_data<T>();
paddle::lite::kernels::x86::ElementwiseComputeEx<MulFunctor<T>,
paddle::lite::kernels::x86::ElementwiseComputeEx<DivFunctor<T>,
lite::TargetType::kX86,
T>(
context, param.X, param.Y, param.axis, MulFunctor<T>(), param.Out);
context, param.X, param.Y, param.axis, DivFunctor<T>(), param.Out);
}

virtual ~ElementwiseMulCompute() = default;
virtual ~ElementwiseDivCompute() = default;
};

template <typename T>
Expand Down
48 changes: 38 additions & 10 deletions lite/kernels/x86/slice_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@ REGISTER_LITE_KERNEL(slice,
kNCHW,
paddle::lite::kernels::x86::SliceCompute<float>,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("StartsTensor", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("EndsTensor", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("StartsTensorList", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))})
.BindInput("StartsTensor",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
.BindInput("EndsTensor",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
.BindInput("StartsTensorList",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
.BindInput("EndsTensorList",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))})
.Finalize();

REGISTER_LITE_KERNEL(slice,
Expand All @@ -36,9 +41,32 @@ REGISTER_LITE_KERNEL(slice,
int32)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.BindInput("StartsTensor", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("EndsTensor", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("StartsTensorList", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("StartsTensor",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
.BindInput("EndsTensor",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
.BindInput("StartsTensorList",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
.BindInput("EndsTensorList",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.Finalize();

// int64 variant of the x86 slice kernel, registered under the alias
// "int64". Input/Output tensors are kInt64; the optional starts/ends
// tensors (and tensor lists) are bound with PRECISION(kAny) because the
// model may supply them as either int32 or int64 index tensors.
REGISTER_LITE_KERNEL(slice,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SliceCompute<int64_t>,
int64)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.BindInput("StartsTensor",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
.BindInput("EndsTensor",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
.BindInput("StartsTensorList",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
.BindInput("EndsTensorList",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.Finalize();
Loading

0 comments on commit cd98943

Please sign in to comment.