From 03120a375916b8f2d9ceabbcd76be54314caded7 Mon Sep 17 00:00:00 2001 From: newway <237745+newway@users.noreply.github.com> Date: Wed, 24 Mar 2021 09:28:47 +0800 Subject: [PATCH] [xpu] refine sequence_pool (#5781) --- lite/kernels/xpu/sequence_pool_compute.cc | 65 ++++++++++++------- .../kernels/sequence_pool_compute_test.cc | 10 ++- 2 files changed, 50 insertions(+), 25 deletions(-) diff --git a/lite/kernels/xpu/sequence_pool_compute.cc b/lite/kernels/xpu/sequence_pool_compute.cc index 2186ef65296..1e4cf332d06 100644 --- a/lite/kernels/xpu/sequence_pool_compute.cc +++ b/lite/kernels/xpu/sequence_pool_compute.cc @@ -34,22 +34,9 @@ void XPUSequencePoolCompute::Run() { auto* in = param.X; auto* out = param.Out; + float pad_value = param.pad_value; std::string pool_type_str = param.pool_type; - auto dims = in->dims(); - auto lod = in->lod(); - dims[0] = lod[0].size() - 1; - - xdnn::Pooling_t pool_type = xdnn::Pooling_t::MAX_WITHOUT_INDEX; - if (pool_type_str == "MAX") { - } else if (pool_type_str == "SUM") { - pool_type = xdnn::Pooling_t::SUM; - } else if (pool_type_str == "LAST") { - pool_type = xdnn::Pooling_t::LAST; - } else { - CHECK(false); - } - int num_seq = out->dims()[0]; int dim = out->numel() / num_seq; @@ -62,16 +49,48 @@ void XPUSequencePoolCompute::Run() { lod_cpu.get(), in_lod.size() * sizeof(int), XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + int r = 0; + if (pool_type_str == "MAX") { + r = xdnn::sequence_max_pool( + ctx.GetRawContext(), + in->data(), + lod_xpu, + out->mutable_data(TARGET(kXPU)), + num_seq, + dim, + pad_value, + nullptr); + } else if (pool_type_str == "SUM") { + r = xdnn::sequence_sum_pool( + ctx.GetRawContext(), + in->data(), + lod_xpu, + out->mutable_data(TARGET(kXPU)), + num_seq, + dim, + pad_value); + } else if (pool_type_str == "LAST") { + r = xdnn::sequence_last_pool( + ctx.GetRawContext(), + in->data(), + lod_xpu, + out->mutable_data(TARGET(kXPU)), + num_seq, + dim, + pad_value); + } else if (pool_type_str == "FIRST") { + r = xdnn::sequence_first_pool( + ctx.GetRawContext(), + in->data(), + lod_xpu, + out->mutable_data(TARGET(kXPU)), + num_seq, + dim, + pad_value); + } else { + CHECK(false) << " unsupported pool_type_str: " << pool_type_str; + } - int r = - xdnn::sequence_pooling_forward(ctx.GetRawContext(), - pool_type, - num_seq, - lod_xpu, - dim, - in->data(), - nullptr /* index */, - out->mutable_data(TARGET(kXPU))); CHECK_EQ(r, 0); } diff --git a/lite/tests/kernels/sequence_pool_compute_test.cc b/lite/tests/kernels/sequence_pool_compute_test.cc index f987fb28022..3db7e220fb7 100644 --- a/lite/tests/kernels/sequence_pool_compute_test.cc +++ b/lite/tests/kernels/sequence_pool_compute_test.cc @@ -162,7 +162,11 @@ void test_sequence_pool(Place place) { for (auto h : {1, 3, 4}) { for (auto w : {1, 3, 4}) { for (auto pool_type : +#if defined(LITE_WITH_XPU) && !defined(LITE_WITH_XTCL) + {"SUM", "MAX", "FIRST", "LAST"}) { +#else {"SUM", "AVERAGE", "SQRT", "MAX", "MIN", "FIRST", "LAST"}) { +#endif for (int seq_num : {1, 3, 5}) { std::vector> lod; lod.resize(1); @@ -185,10 +189,12 @@ TEST(SequencePool, precision) { // #ifdef LITE_WITH_X86 // Place place(TARGET(kX86)); // #endif -#ifdef LITE_WITH_ARM +#if defined(LITE_WITH_XPU) && !defined(LITE_WITH_XTCL) + Place place(TARGET(kXPU)); +#elif defined(LITE_WITH_ARM) Place place(TARGET(kARM)); - test_sequence_pool(place); #endif + test_sequence_pool(place); } } // namespace lite