Skip to content

Commit

Permalink
[xpu] refine sequence_pool (#5781)
Browse files Browse the repository at this point in the history
  • Loading branch information
newway authored Mar 24, 2021
1 parent 9b0436c commit 03120a3
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 25 deletions.
65 changes: 42 additions & 23 deletions lite/kernels/xpu/sequence_pool_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,9 @@ void XPUSequencePoolCompute::Run() {

auto* in = param.X;
auto* out = param.Out;
float pad_value = param.pad_value;
std::string pool_type_str = param.pool_type;

auto dims = in->dims();
auto lod = in->lod();
dims[0] = lod[0].size() - 1;

xdnn::Pooling_t pool_type = xdnn::Pooling_t::MAX_WITHOUT_INDEX;
if (pool_type_str == "MAX") {
} else if (pool_type_str == "SUM") {
pool_type = xdnn::Pooling_t::SUM;
} else if (pool_type_str == "LAST") {
pool_type = xdnn::Pooling_t::LAST;
} else {
CHECK(false);
}

int num_seq = out->dims()[0];
int dim = out->numel() / num_seq;

Expand All @@ -62,16 +49,48 @@ void XPUSequencePoolCompute::Run() {
lod_cpu.get(),
in_lod.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE));
int r = 0;
if (pool_type_str == "MAX") {
r = xdnn::sequence_max_pool<float, int>(
ctx.GetRawContext(),
in->data<float>(),
lod_xpu,
out->mutable_data<float>(TARGET(kXPU)),
num_seq,
dim,
pad_value,
nullptr);
} else if (pool_type_str == "SUM") {
r = xdnn::sequence_sum_pool<float, int>(
ctx.GetRawContext(),
in->data<float>(),
lod_xpu,
out->mutable_data<float>(TARGET(kXPU)),
num_seq,
dim,
pad_value);
} else if (pool_type_str == "LAST") {
r = xdnn::sequence_last_pool<float, int>(
ctx.GetRawContext(),
in->data<float>(),
lod_xpu,
out->mutable_data<float>(TARGET(kXPU)),
num_seq,
dim,
pad_value);
} else if (pool_type_str == "FIRST") {
r = xdnn::sequence_first_pool<float, int>(
ctx.GetRawContext(),
in->data<float>(),
lod_xpu,
out->mutable_data<float>(TARGET(kXPU)),
num_seq,
dim,
pad_value);
} else {
CHECK(false) << " unsupported pool_type_str: " << pool_type_str;
}

int r =
xdnn::sequence_pooling_forward(ctx.GetRawContext(),
pool_type,
num_seq,
lod_xpu,
dim,
in->data<float>(),
nullptr /* index */,
out->mutable_data<float>(TARGET(kXPU)));
CHECK_EQ(r, 0);
}

Expand Down
10 changes: 8 additions & 2 deletions lite/tests/kernels/sequence_pool_compute_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,11 @@ void test_sequence_pool(Place place) {
for (auto h : {1, 3, 4}) {
for (auto w : {1, 3, 4}) {
for (auto pool_type :
#if defined(LITE_WITH_XPU) && !defined(LITE_WITH_XTCL)
{"SUM", "MAX", "FIRST", "LAST"}) {
#else
{"SUM", "AVERAGE", "SQRT", "MAX", "MIN", "FIRST", "LAST"}) {
#endif
for (int seq_num : {1, 3, 5}) {
std::vector<std::vector<uint64_t>> lod;
lod.resize(1);
Expand All @@ -185,10 +189,12 @@ TEST(SequencePool, precision) {
// #ifdef LITE_WITH_X86
// Place place(TARGET(kX86));
// #endif
#ifdef LITE_WITH_ARM
#if defined(LITE_WITH_XPU) && !defined(LITE_WITH_XTCL)
Place place(TARGET(kXPU));
#elif defined(LITE_WITH_ARM)
Place place(TARGET(kARM));
test_sequence_pool(place);
#endif
test_sequence_pool(place);
}

} // namespace lite
Expand Down

0 comments on commit 03120a3

Please sign in to comment.