Skip to content

Commit

Permalink
[X86][XPU] add reduce_max; fix xpu fill_any_like; test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang committed Mar 5, 2021
1 parent acd40c8 commit c09a01b
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 9 deletions.
10 changes: 10 additions & 0 deletions lite/kernels/x86/reduce_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,13 @@ REGISTER_LITE_KERNEL(reduce_mean,
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();

REGISTER_LITE_KERNEL(reduce_max,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::ReduceMaxCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
45 changes: 45 additions & 0 deletions lite/kernels/x86/reduce_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ struct MeanFunctor {
}
};

struct MaxFunctor {
template <typename X, typename Y, typename Dim>
void operator()(X* x, Y* y, const Dim& dim) {
y->device(lite::fluid::EigenDeviceType<TARGET(kX86)>()) = x->maximum(dim);
}
};

#define HANDLE_DIM(NDIM, RDIM, FUNCTOR) \
if (ndim == NDIM && rdim == RDIM) { \
paddle::lite::kernels::x86:: \
Expand Down Expand Up @@ -120,6 +127,44 @@ class ReduceMeanCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
virtual ~ReduceMeanCompute() = default;
};

template <typename T>
class ReduceMaxCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ReduceParam;

void Run() override {
auto& param = *param_.get_mutable<operators::ReduceParam>();
auto* input = param.X;
auto* Out = param.Out;
param.Out->template mutable_data<T>();

const auto& dims = param.dim;
bool keep_dim = param.keep_dim;

if (dims.size() == 0) {
// Flatten and reduce 1-D tensor
auto x = lite::fluid::EigenVector<T>::Flatten(*input);
auto out = lite::fluid::EigenScalar<T>::From(Out);
auto reduce_dim = Eigen::array<int, 1>({{0}});
MaxFunctor functor;
functor(&x, &out, reduce_dim);
} else {
int ndim = input->dims().size();
int rdim = dims.size();
HANDLE_DIM(4, 3, MaxFunctor);
HANDLE_DIM(4, 2, MaxFunctor);
HANDLE_DIM(4, 1, MaxFunctor);
HANDLE_DIM(3, 2, MaxFunctor);
HANDLE_DIM(3, 1, MaxFunctor);
HANDLE_DIM(2, 2, MaxFunctor);
HANDLE_DIM(2, 1, MaxFunctor);
HANDLE_DIM(1, 1, MaxFunctor);
}
}

virtual ~ReduceMaxCompute() = default;
};

} // namespace x86
} // namespace kernels
} // namespace lite
Expand Down
1 change: 1 addition & 0 deletions lite/kernels/xpu/fill_any_like_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ void FillAnyLikeCompute::Run() {
static_cast<int64_t>(param.value));
break;
}
case -1:
case 5: {
auto data = param.Out->mutable_data<float>(TARGET(kXPU));
r = xdnn::constant<float>(ctx.GetRawContext(),
Expand Down
19 changes: 10 additions & 9 deletions lite/tests/kernels/reduce_max_compute_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ class ReduceMaxComputeTester : public arena::TestCase {
}
};

void test_reduce_max(Place place) {
void test_reduce_max_4d(Place place) {
std::vector<std::vector<int>> reduce_dim{
{0}, {1}, {2}, {3}, {0, 1}, {1, 2}, {2, 3}, {-2, -1}};
for (auto n : {1, 3}) {
Expand All @@ -421,7 +421,7 @@ void test_reduce_max(Place place) {
}
}

void test_reduce_max_for_three(Place place) {
void test_reduce_max_3d(Place place) {
std::vector<std::vector<int>> reduce_dim{{0}, {1}, {2}};
for (bool keep_dim : {false, true}) {
for (auto dim : reduce_dim) {
Expand All @@ -435,14 +435,15 @@ void test_reduce_max_for_three(Place place) {
}

TEST(ReduceMax, precision) {
// #ifdef LITE_WITH_X86
// Place place(TARGET(kX86));
// #endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_reduce_max(place);
test_reduce_max_for_three(place);
Place place;
#if defined(LITE_WITH_ARM)
place = TARGET(kARM);
#elif defined(LITE_WITH_X86)
place = TARGET(kX86);
#endif

test_reduce_max_4d(place);
test_reduce_max_3d(place);
}

} // namespace lite
Expand Down

0 comments on commit c09a01b

Please sign in to comment.