diff --git a/include/tvm/topi/detail/constant_utils.h b/include/tvm/topi/detail/constant_utils.h index 49ce21b5732e..92ff3a4e3804 100644 --- a/include/tvm/topi/detail/constant_utils.h +++ b/include/tvm/topi/detail/constant_utils.h @@ -48,7 +48,8 @@ using namespace tvm::te; inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance(); } /*! - * \brief Test whether the given Array has every element as constant integer + * \brief Test whether the given Array has every element as constant integer. + * Undefined elements are also treat as constants. * * \param array the array to query * @@ -57,7 +58,7 @@ inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance array) { bool is_const_int = true; for (auto const& elem : array) { - is_const_int &= elem->IsInstance(); + is_const_int &= !elem.defined() || elem->IsInstance(); } return is_const_int; } diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index a04762f28feb..261fdf9970a3 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -612,6 +612,7 @@ inline Tensor strided_slice(const Tensor& x, const Array& begin, Array out_shape; if (!is_static) { + ICHECK_EQ(strides.size(), src_tensor_dim); for (size_t i = 0; i < src_tensor_dim; ++i) { out_shape.push_back(indexdiv(end[i] - begin[i], strides[i])); } diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 30434f6fd266..e0018ba0c0d3 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -817,6 +817,7 @@ def test_strided_slice(): verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1]) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3]) verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3]) + verify_strided_slice((3, 4, 3), [0, 0, 0], [None, None, None]) @tvm.testing.uses_gpu