fix concat when axis < 0 (#5917) (#5920)
zhupengyang authored Apr 20, 2021
1 parent deeb04e · commit fdab3f8
Showing 5 changed files with 27 additions and 20 deletions.
lite/kernels/arm/concat_compute.cc (1 addition, 1 deletion)
@@ -68,7 +68,7 @@ void ConcatCompute::Run() {
     axis = axis_tensor_data[0];
   }
   if (axis < 0) {
-    axis += inputs[0]->dims().size();
+    axis += static_cast<int>(inputs[0]->dims().size());
  }
 
   lite_api::PrecisionType type = PRECISION(kUnk);
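Every hunk in this commit applies the same rule: a negative concat axis counts back from the last dimension, so it is shifted by the input's rank before any indexing. A minimal standalone sketch of the rule (the helper name and shapes are illustrative, not Paddle-Lite API):

```cpp
#include <cassert>
#include <cstddef>

// Map a possibly negative axis into [0, rank), Python-style:
// -1 refers to the last dimension, -rank to the first.
int normalize_axis(int axis, size_t rank) {
  if (axis < 0) axis += static_cast<int>(rank);
  return axis;
}

int main() {
  assert(normalize_axis(-1, 3) == 2);  // last dim of a rank-3 tensor
  assert(normalize_axis(1, 3) == 1);   // non-negative axes pass through
  return 0;
}
```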
lite/kernels/x86/concat_compute.h (6 additions, 3 deletions)
@@ -44,14 +44,17 @@ class ConcatCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
       return;
     }
 
-    int64_t axis = static_cast<int64_t>(param.axis);
+    int axis = param.axis;
     auto* axis_tensor = param.axis_tensor;
     if (axis_tensor != nullptr) {
       auto* axis_tensor_data = axis_tensor->template data<int>();
-      axis = static_cast<int64_t>(axis_tensor_data[0]);
+      axis = axis_tensor_data[0];
     }
 
+    const auto& x_dims = param.x[0]->dims();
+    if (axis < 0) {
+      axis += static_cast<int>(x_dims.size());
+    }
 
     auto* out = param.output;
     T* output_data = param.output->template mutable_data<T>();
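The normalization must happen before the kernel flattens its inputs: the downstream loops treat the axis as a non-negative offset, so a raw negative value silently produces a wrong row count rather than an error. A small sketch of that failure mode (helper name and shapes are illustrative):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Number of "rows" below the concat axis: prod(dims[0..axis)).
int64_t rows_before_axis(const std::vector<int64_t>& dims, int axis) {
  int64_t rows = 1;
  for (int i = 0; i < axis; ++i) rows *= dims[i];  // never runs if axis < 0
  return rows;
}

int main() {
  std::vector<int64_t> dims{2, 3, 4};
  // Unnormalized: the loop never executes and rows is silently 1.
  assert(rows_before_axis(dims, -1) == 1);
  // Normalized: -1 + rank == 2, so rows == 2 * 3 == 6 as intended.
  assert(rows_before_axis(dims, -1 + static_cast<int>(dims.size())) == 6);
  return 0;
}
```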
lite/kernels/xpu/concat_compute.cc (4 additions, 2 deletions)
@@ -29,7 +29,9 @@ void ConcatCompute<InType>::Run() {
 
   auto ins = param.x;
   auto out = param.output;
-  int64_t axis = param.axis;
+  int64_t axis = param.axis < 0
+                     ? param.axis + static_cast<int>(ins[0]->dims().size())
+                     : param.axis;
 
   std::vector<const float*> x_list;
   std::vector<std::vector<int>> xdims_list;
@@ -69,7 +71,7 @@ REGISTER_LITE_KERNEL(concat,
                      kFloat,
                      kNCHW,
                      paddle::lite::kernels::xpu::ConcatCompute<float>,
-                     concat_fp32)
+                     def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFloat))})
     .BindInput("AxisTensor",
                {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
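The alias rename from concat_fp32 to def appears to line up with the updated test below, which instantiates its tester with the alias "def" and now also runs on kXPU; under the old alias that lookup would not match this kernel.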
lite/operators/concat_op.cc (1 addition, 1 deletion)
@@ -39,7 +39,7 @@ bool ConcatOpLite::InferShapeImpl() const {
     axis = axis_tensor_val[0];
   }
   if (axis < 0) {
-    axis += inputs[0]->dims().size();
+    axis += static_cast<int>(inputs[0]->dims().size());
   }
 
   auto out_dims = inputs[0]->dims();
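Once the axis is normalized, InferShapeImpl only has to sum the inputs' extents along that axis; every other dimension must agree. A hedged sketch of the shape rule, with plain vectors standing in for Paddle-Lite's DDim:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Output shape of concat: identical to the inputs except along `axis`,
// where the extents add up. `axis` is assumed to be already normalized.
Shape concat_out_shape(const std::vector<Shape>& inputs, int axis) {
  Shape out = inputs[0];
  for (size_t i = 1; i < inputs.size(); ++i) out[axis] += inputs[i][axis];
  return out;
}

int main() {
  // Concatenating (2,3,4) and (2,5,4) along axis 1 gives (2,8,4).
  Shape out = concat_out_shape({{2, 3, 4}, {2, 5, 4}}, 1);
  assert((out == Shape{2, 8, 4}));
  return 0;
}
```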
lite/tests/kernels/concat_compute_test.cc (15 additions, 13 deletions)
@@ -74,23 +74,24 @@ class ConcateComputeTester : public arena::TestCase {
       x_vct.push_back(scope->FindTensor(name));
     }
 
+    int axis = axis_ < 0 ? axis_ + static_cast<int>(x_dims_.size()) : axis_;
     auto* out = scope->NewTensor(out_);
-    DDim output_dims = infer_shape(x_vct, axis_);
+    DDim output_dims = infer_shape(x_vct, axis);
     out->Resize(output_dims);
     auto* output_data = out->mutable_data<float>();
 
     int num = x_vct.size();
     int rows = 1;
     auto dim_0 = x_vct[0]->dims();
-    for (int i = 0; i < axis_; ++i) {
+    for (int i = 0; i < axis; ++i) {
       rows *= dim_0[i];
     }
     int out_rows = rows, out_cols = 0;
 
     std::vector<int> input_cols(x_vct.size());
     for (int i = 0; i < num; ++i) {
       int input_i_numel = x_vct[i]->dims().size() == 0 ? 0 : 1;
-      for (int didx = 0; didx < x_vct[i]->dims().size(); ++didx) {
+      for (size_t didx = 0; didx < x_vct[i]->dims().size(); ++didx) {
         input_i_numel *= x_vct[i]->dims()[didx];
       }
       int t_cols = input_i_numel / rows;
@@ -135,21 +136,27 @@ class ConcateComputeTester : public arena::TestCase {
 
     if (is_use_axis_tensor_) {
       SetCommonTensor(axis_tensor_, DDim({1}), &axis_);
-      LOG(INFO) << "set axis tensor";
     }
   }
 };
 
 TEST(Concat, precision) {
-  LOG(INFO) << "test concat op, kARM";
   Place place;
   float abs_error = 2e-5;
-#if defined(LITE_WITH_NPU)
+  std::vector<int> axes{-1, 1, 2};
+  std::vector<bool> use_axis_tensor{false, true};
+#if defined(LITE_WITH_XPU) && !defined(LITE_WITH_XTCL)
+  place = TARGET(kXPU);
+  use_axis_tensor = std::vector<bool>{false};
+#elif defined(LITE_WITH_NPU)
   place = TARGET(kNPU);
   abs_error = 1e-2;  // use fp16 in npu
+  axes = std::vector<int>{1, 2};
+  use_axis_tensor = std::vector<bool>{false};
 #elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
   place = TARGET(kHuaweiAscendNPU);
   abs_error = 1e-2;  // precision_mode default is force_fp16
+  axes = std::vector<int>{1, 2};
 #elif defined(LITE_WITH_ARM)
   place = TARGET(kARM);
 #elif defined(LITE_WITH_X86)
@@ -158,13 +165,8 @@ TEST(Concat, precision) {
   return;
 #endif
 
-  for (int axis : {1, 2}) {
-    for (bool is_use_axis_tensor : {false, true}) {
-#ifdef LITE_WITH_NPU
-      if (is_use_axis_tensor) continue;
-#endif
-      LOG(INFO) << "axis:" << axis
-                << ", is_use_axis_tensor:" << is_use_axis_tensor;
+  for (int axis : axes) {
+    for (bool is_use_axis_tensor : use_axis_tensor) {
       std::unique_ptr<arena::TestCase> tester(
           new ConcateComputeTester(place, "def", axis, is_use_axis_tensor));
       arena::Arena arena(std::move(tester), place, abs_error);
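The test's reference path views each input as a rows x cols matrix, where rows is the product of the dimensions before the normalized axis, and stitches the inputs' column blocks together row by row. A self-contained sketch of that scheme (shapes and values are illustrative):

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Concatenate 2-D-flattened inputs: input i is viewed as rows x cols[i],
// and each output row is the inputs' rows laid out side by side.
std::vector<float> concat_ref(const std::vector<std::vector<float>>& xs,
                              const std::vector<int>& cols, int rows) {
  int out_cols = std::accumulate(cols.begin(), cols.end(), 0);
  std::vector<float> out(static_cast<size_t>(rows) * out_cols);
  for (int r = 0; r < rows; ++r) {
    int offset = 0;
    for (size_t i = 0; i < xs.size(); ++i) {
      for (int c = 0; c < cols[i]; ++c)
        out[r * out_cols + offset + c] = xs[i][r * cols[i] + c];
      offset += cols[i];
    }
  }
  return out;
}

int main() {
  // Concat shapes (2,2) and (2,1) along axis 1: rows = 2, cols = {2, 1}.
  std::vector<float> a{1, 2, 3, 4};  // [[1,2],[3,4]]
  std::vector<float> b{5, 6};        // [[5],[6]]
  std::vector<float> out = concat_ref({a, b}, {2, 1}, 2);
  std::vector<float> expected{1, 2, 5, 3, 4, 6};  // [[1,2,5],[3,4,6]]
  assert(out == expected);
  return 0;
}
```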
