Skip to content

Commit

Permalink
add tests for floordiv simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
Tristan Konolige committed Jan 31, 2022
1 parent 61efed9 commit 7928fc7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
TVM_TRY_REWRITE_IF(floordiv(x * y, y), x, CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv(y * x, y), x, CanProveGreaterEqual(y.Eval(), 0));

if ((floordiv(x, y)).Match(ret) && analyzer_->CanProve(x.Eval() < y.Eval())) {
if ((floordiv(x, y)).Match(ret) && analyzer_->CanProve(x.Eval() < y.Eval()) &&
analyzer_->CanProve(-x.Eval() < y.Eval())) {
return 0;
}

Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ def test_vector_simplify():
fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
) # Example negative case: x = 9; [63, 70, 77, 84] % 64 = [0, 1, 1, 1]
a, b = te.var("a"), te.var("b")
ck.analyzer.update(a, tvm.arith.ConstIntBound(0, 99), override=True)
ck.analyzer.update(b, tvm.arith.ConstIntBound(100, 200), override=True)
ck.verify(fld(a, b), 0)
ck.verify(fld(a, b) + x, x)
ck.analyzer.update(a, tvm.arith.ConstIntBound(-99, 0), override=True)
ck.verify(fld(a, b), 0)
ck.verify(fld(a, b) + x, x)

# floor mod
ck.verify(flm(y.astype("int32x2"), x.astype("int32x2")), flm(y, x).astype("int32x2"))
Expand Down

0 comments on commit 7928fc7

Please sign in to comment.