Skip to content

Commit

Permalink
Fix for GRP sketch gen (apache#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
spectrometerHBH authored Jan 23, 2022
1 parent 5793c98 commit f748df4
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
10 changes: 10 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), x * floordiv(c1, c2) + floordiv(y, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), floordiv(x, floordiv(c2, c1)),
c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
c2.Eval()->value % c1.Eval()->value == 0 &&
CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0));

TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

Expand All @@ -782,6 +787,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(y, c2) + x * floordiv(c1, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(x, floordiv(c2, c1)),
c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
c2.Eval()->value % c1.Eval()->value == 0 &&
CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0));

TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

Expand Down
9 changes: 8 additions & 1 deletion tests/python/unittest/test_arith_intset.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,14 @@ def test_mod():
floordiv = tvm.te.floordiv
z = te.var("z")
ck.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 3))
ck.verify(flm(y, 8), {y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)}, (0, 7))
ck.verify(
flm(y, 8),
{y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)},
(
z * 8 + x * 4 - 8 * floordiv(z * 8 + x * 4, 8),
z * 8 + x * 4 + 3 - 8 * floordiv(z * 8 + x * 4, 8),
),
)
ck1 = IntSetChecker()
ck1.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 2))
ck1.verify(
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_tir_schedule_split_fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def elementwise_split_with_predicate(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [128, 128, 128])
for i0, i1, i2, j0, j1, k0, k1 in T.grid(1000, 2, 3, 1, 129, 3, 43):
with T.block("B"):
T.where((i0 * 2 + i1) * 3 + i2 < 128 and j0 * 129 + j1 < 128 and k0 * 43 + k1 < 128)
T.where((i0 * 2 + i1) * 3 + i2 < 128 and j1 < 128 and k0 * 43 + k1 < 128)
vi = T.axis.S(128, i0 * 6 + i1 * 3 + i2)
vj = T.axis.S(128, j1)
vk = T.axis.S(128, k0 * 43 + k1)
Expand Down

0 comments on commit f748df4

Please sign in to comment.