Skip to content

Commit

Permalink
[TOPI] Treat undefined elements as constants in Array (#7232)
Browse files Browse the repository at this point in the history
* [TOPI] Treat undefined elements as constants in Array

* Add a checker

* fix

* add test case
  • Loading branch information
comaniac authored Jan 9, 2021
1 parent 701bcc2 commit 02ef6e6
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
5 changes: 3 additions & 2 deletions include/tvm/topi/detail/constant_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ using namespace tvm::te;
inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance<tvm::tir::IntImmNode>(); }

/*!
* \brief Test whether the given Array has every element as constant integer
* \brief Test whether the given Array has every element as constant integer.
* Undefined elements are also treat as constants.
*
* \param array the array to query
*
Expand All @@ -57,7 +58,7 @@ inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance<tvm::tir::IntImm
inline bool IsConstIntArray(Array<PrimExpr> array) {
bool is_const_int = true;
for (auto const& elem : array) {
is_const_int &= elem->IsInstance<tvm::tir::IntImmNode>();
is_const_int &= !elem.defined() || elem->IsInstance<tvm::tir::IntImmNode>();
}
return is_const_int;
}
Expand Down
1 change: 1 addition & 0 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,7 @@ inline Tensor strided_slice(const Tensor& x, const Array<PrimExpr>& begin,

Array<PrimExpr> out_shape;
if (!is_static) {
ICHECK_EQ(strides.size(), src_tensor_dim);
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(indexdiv(end[i] - begin[i], strides[i]));
}
Expand Down
1 change: 1 addition & 0 deletions tests/python/topi/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,7 @@ def test_strided_slice():
verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3])
verify_strided_slice((3, 4, 3), [0, 0, 0], [None, None, None])


@tvm.testing.uses_gpu
Expand Down

0 comments on commit 02ef6e6

Please sign in to comment.