Skip to content

Commit

Permalink
[Arith][SVE] Add rewrite rules for indices split by scalable expressi…
Browse files Browse the repository at this point in the history
…ons (#17046)

This commit introduces rewrite rules for indices which can arise from splitting axes by scalable factors (e.g. `xo, xi = sch.split(x, factors = [None, 8 * T.vscale()])`):

```
(v_x_o * T.Cast("int64", T.vscale()) * T.int64(8) + v_x_i) // (T.Cast("int64", T.vscale()) * T.int64(8)) == v_x_o
(v_x_o * T.Cast("int64", T.vscale()) * T.int64(8) + v_x_i) % (T.Cast("int64", T.vscale()) * T.int64(8)) == v_x_i
```

The rewrites help prove checks needed by `sch.tensorize()` (e.g. CompareBufferRegion).
  • Loading branch information
Anndrey24 authored Jun 7, 2024
1 parent 1d761da commit 5d077c5
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1136,8 +1136,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
x + floordiv(y, z), CanProveGreaterEqual(z.Eval(), 0));
TVM_TRY_REWRITE_IF(matches_one_of(floordiv(y + x * z, z), floordiv(y + z * x, z)),
floordiv(y, z) + x, CanProveGreaterEqual(z.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv(x * z * c1 + y, z * c1), x + floordiv(y, z * c1),
CanProveGreaterEqual(z.Eval() * c1.Eval(), 0));

TVM_TRY_REWRITE_IF(floordiv(x - floormod(x, c1), c1), floordiv(x, c1), c1.Eval()->value != 0);

// Scalable divisor
TVM_TRY_REWRITE_IF(floordiv(x, y), ZeroWithTypeLike(x),
ContainsVscaleCall(y.Eval()) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0) && CanProve(x.Eval() < y.Eval()));
}
return ret;
}
Expand Down Expand Up @@ -1230,6 +1237,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
ZeroWithTypeLike(x),
CanProveEqual(y.Eval() - z.Eval(), 0) || CanProveEqual(y.Eval() + z.Eval(), 0));

TVM_TRY_REWRITE_IF(floormod(x * z * c1 + y, z * c1), floormod(y, z * c1),
CanProveGreaterEqual(z.Eval() * c1.Eval(), 0));

// Scalable divisor
TVM_TRY_REWRITE_IF(floormod(x, y), x,
ContainsVscaleCall(y.Eval()) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0) && CanProve(x.Eval() < y.Eval()));

if (floormod(x, c1).Match(ret)) {
int64_t c1val = c1.Eval()->value;
if (c1val > 0) {
Expand Down
2 changes: 2 additions & 0 deletions src/arith/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
// TODO(tqchen) refer back to super-analyzer.
return TryCompare(x, val) == CompareResult::kEQ;
}
// Returns true when the analyzer referenced by analyzer_ is able to prove
// that the boolean expression x holds; the query is forwarded unchanged.
bool CanProve(const PrimExpr& x) {
  const bool proven = analyzer_->CanProve(x);
  return proven;
}

// Recursive rewrite x
// we limit maximum depth of recursive rewrite allowed to
Expand Down
8 changes: 8 additions & 0 deletions tests/python/arith/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ class TestFloordivIndex(BaseCompare):
TestCase(fld(x * y, y), x, y >= 0),
TestCase(fld(y * x, y), x, y >= 0),
TestCase(fld(x * z + y, z), x + fld(y, z), z >= 0),
TestCase(fld(x * z * 2 + y, z * 2), x + fld(y, z * 2), z * 2 >= 0),
TestCase(fld(z * x + y, z), x + fld(y, z), z >= 0),
TestCase(fld(y + x * z, z), fld(y, z) + x, z >= 0),
TestCase(fld(y + z * x, z), fld(y, z) + x, z >= 0),
Expand Down Expand Up @@ -616,6 +617,7 @@ class TestFloormodIndex(BaseCompare):
TestCase(flm(x + y * (-10), 2), flm(x, 2)),
TestCase(flm(x * 32 + y, 64), flm(x, 2) * 32 + y, [y >= 0, y < 32]),
TestCase(flm(x * 32 - y, 64), flm(x * 32 - y, 64), [y >= 0, y < 32]),
TestCase(flm(x * z * 2 + y, z * 2), flm(y, z * 2), z * 2 >= 0),
# NOTE: the following case is covered by canonical simplify
# long range simplification in general can be covered by canonical simplify
# TestCase(flm(x * 10 + 1 + y * 2 + 2, 2), 1),
Expand Down Expand Up @@ -832,6 +834,12 @@ class TestScalableIndex(BaseCompare):
x + tir.vscale() * 4 - flm(4, tir.vscale() * 4),
),
TestCase(tvm.te.max(tir.vscale() * x, tir.vscale() * y), tir.vscale() * x, x > y),
# FloorDiv
TestCase(fld(x * tir.vscale() * 4 + y, tir.vscale() * 4), x + fld(y, tir.vscale() * 4)),
TestCase(fld(x, tir.vscale() * 4), 0, [x >= 0, x < tir.vscale() * 4]),
# FloorMod
TestCase(flm(x * tir.vscale() * 4 + y, tir.vscale() * 4), flm(y, tir.vscale() * 4)),
TestCase(flm(x, tir.vscale() * 4), x, [x >= 0, x < tir.vscale() * 4]),
)

def test_simplify(self, test_case):
Expand Down

0 comments on commit 5d077c5

Please sign in to comment.