From 376be9a708ab25ce4b2175a243c721ae5e24a3e7 Mon Sep 17 00:00:00 2001 From: pariksheet Date: Mon, 13 Aug 2018 15:07:07 +0530 Subject: [PATCH] Split_indices negative axis added --- topi/include/topi/transform.h | 5 +++++ topi/tests/python_cpp/test_topi_transform.py | 1 + 2 files changed, 6 insertions(+) diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 09af612b957b..245b38cfb63d 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -475,6 +475,11 @@ inline Array split_sections(const Tensor& x, int axis, std::string name = "tensor", std::string tag = kInjective) { + if (axis < 0) { + axis += static_cast(x->shape.size()); + } + CHECK_LT(axis, x->shape.size()) << "axis out of bounds"; + auto src_axis_size = static_cast(GetConstInt(x->shape[axis])); CHECK_GT(num_sections, 0) << "Slice count must be > 0"; diff --git a/topi/tests/python_cpp/test_topi_transform.py b/topi/tests/python_cpp/test_topi_transform.py index c8b7c3906caa..3f7bdbfdd499 100644 --- a/topi/tests/python_cpp/test_topi_transform.py +++ b/topi/tests/python_cpp/test_topi_transform.py @@ -340,6 +340,7 @@ def test_concatenate(): def test_split(): verify_split((2, 12, 3), 3, 1) + verify_split((2, 12, 3), 3, -1) verify_split((2, 12, 3), [2, 4], 1) verify_split((10, 12, 24), [5, 7, 9], -1)