diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index f1838f5a9099..6418f28cf87c 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1096,24 +1096,26 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, max(x, c2)), max(x, c2), c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value && - CanProveGreaterEqual(x.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 1)); TVM_TRY_REWRITE_IF(min(x, truncdiv(x + c1, c2) * c2), x, c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(max(x, c2), truncdiv(x + c1, c2) * c2), max(x, c2), c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value && - CanProveGreaterEqual(x.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 1)); // Divide up rounding: floor div TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, x), x, c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, max(x, c2)), max(x, c2), - c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value && + CanProveGreaterEqual(x.Eval(), 1)); TVM_TRY_REWRITE_IF(min(x, floordiv(x + c1, c2) * c2), x, c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(max(x, c2), floordiv(x + c1, c2) * c2), max(x, c2), - c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value && + CanProveGreaterEqual(x.Eval(), 1)); TVM_TRY_REWRITE_IF(min(x, floordiv(x, c2) * c2), floordiv(x, c2) * c2, c2.Eval()->value > 0); TVM_TRY_REWRITE_IF(min(floordiv(x, c2) * c2, x), floordiv(x, c2) * c2, c2.Eval()->value > 0); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 975af097c030..84340ec031c7 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -649,20 +649,23 @@ def test_min_index_simplify(): # truc div ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000)) ck.verify(tvm.te.min(tdiv(x + 3, 4) * 4, x), x) - ck.verify(tvm.te.min(tdiv(x + 3, 4) * 4, tvm.te.max(x, 4)), tvm.te.max(x, 4)) ck.verify(tvm.te.min(x, tdiv(x + 3, 4) * 4), x) - ck.verify(tvm.te.min(tvm.te.max(x, 4), tdiv(x + 3, 4) * 4), tvm.te.max(x, 4)) ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True) ck.verify(tvm.te.min(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.te.min(x, y), 10)) ck.verify(tvm.te.min(tdiv(x, (-10)), tdiv(y, (-10))), tdiv(tvm.te.max(x, y), (-10))) + ck.analyzer.update(x, tvm.arith.ConstIntBound(1, 1000), True) + ck.verify(tvm.te.min(tdiv(x + 3, 4) * 4, tvm.te.max(x, 4)), tvm.te.max(x, 4)) + ck.verify(tvm.te.min(tvm.te.max(x, 4), tdiv(x + 3, 4) * 4), tvm.te.max(x, 4)) # floor div ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True) ck.verify(tvm.te.min(fld(x + 3, 4) * 4, x), x) - ck.verify(tvm.te.min(fld(x + 3, 4) * 4, tvm.te.max(x, 4)), tvm.te.max(x, 4)) ck.verify(tvm.te.min(x, fld(x + 3, 4) * 4), x) ck.verify(tvm.te.min(x, fld(x, 4) * 4), fld(x, 4) * 4) + ck.analyzer.update(x, tvm.arith.ConstIntBound(1, 1000), True) + ck.verify(tvm.te.min(fld(x + 3, 4) * 4, tvm.te.max(x, 4)), tvm.te.max(x, 4)) ck.verify(tvm.te.min(tvm.te.max(x, 4), fld(x + 3, 4) * 4), tvm.te.max(x, 4)) + ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True) ck.verify(tvm.te.min(fld(x, 10), fld(y, 10)), fld(tvm.te.min(x, y), 10)) ck.verify(tvm.te.min(fld(x, (-10)), fld(y, (-10))), fld(tvm.te.max(x, y), (-10)))