Skip to content

Commit

Permalink
[ARITH] Improve arith simplify to handle symbolic reshape pattern (#1…
Browse files Browse the repository at this point in the history
…5081)

This PR enhances arith simplify to handle symbolic reshape patterns.
Lift the CombineIters to callers of TryFuseIters so they can be used
in early return simplifications. Testcases are added.

Also updates a minor spelling issue in the testcase.
  • Loading branch information
tqchen authored Jun 14, 2023
1 parent 68ac909 commit 02dc191
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 17 deletions.
11 changes: 8 additions & 3 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,10 @@ class IterMapRewriter : public ExprMutator {
IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
// We are normalizing a regular iter
if (expr->args.size() < 1) return expr;
if (auto opt = TryCombineSplitFromSameSource(expr)) {
expr = opt.value();
if (expr->args.size() < 1) return expr;
}
Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
if (opt.defined()) {
return opt.value();
Expand Down Expand Up @@ -995,9 +999,6 @@ class IterMapRewriter : public ExprMutator {
* \return The sum with the fused IterMark and extra offset if succeed.
*/
Optional<IterSumExpr> TryFuseIters(IterSumExpr expr, IterMapLevel check_level) {
if (auto opt = TryCombineSplitFromSameSource(expr)) {
expr = opt.value();
}
// select the iterators in order
std::vector<bool> visited(expr->args.size(), false);
int base_index = FindBaseIter(expr, visited, NullOpt);
Expand Down Expand Up @@ -1553,6 +1554,10 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o
return IterSumExpr();
} else if (sum->args.size() == 1) {
return sum;
} else if (auto opt = TryCombineSplitFromSameSource(sum)) {
if (opt.value()->args.size() == 1) {
return opt.value();
}
}
auto opt_fused = TryFuseIters(sum, check_level_);
if (!opt_fused) {
Expand Down
41 changes: 27 additions & 14 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def assert_iter_sum_pattern(
tvm.ir.assert_structural_equal(sum_expr, expect_iter)


def assert_iter_map_simplfy(
def assert_iter_map_simplify(
expect_dict, dom_map, predicate=True, check_level="surjective", simplify_trivial_iterators=True
):
keys = list(expect_dict.keys())
Expand Down Expand Up @@ -1120,28 +1120,28 @@ def test_iter_map_simplify_symbolic_case():
def simple_fuse0(x):
return (x // n) * n + x % n

assert_iter_map_simplfy({simple_fuse0(x): x}, var_dom([(x, n * 32)]))
assert_iter_map_simplify({simple_fuse0(x): x}, var_dom([(x, n * 32)]))

assert_iter_map_simplfy({simple_fuse0(z): z}, var_dom([(x, n), (y, 32)]))
assert_iter_map_simplify({simple_fuse0(z): z}, var_dom([(x, n), (y, 32)]))

def fsymbolic_fuse0(x):
return ((x // (n * n)) % 32) * (n * n) + ((x // n) % n) * n + x % n

assert_iter_map_simplfy({fsymbolic_fuse0(x): x}, var_dom([(x, n * n * 32)]))
assert_iter_map_simplify({fsymbolic_fuse0(x): x}, var_dom([(x, n * n * 32)]))

assert_iter_map_simplfy({fsymbolic_fuse0(z): z}, var_dom([(x, n * n), (y, 32)]))
assert_iter_map_simplify({fsymbolic_fuse0(z): z}, var_dom([(x, n * n), (y, 32)]))

def fsymbolic_fuse1(x):
return ((x % (n * n * 32)) // (n * n) * n + (x % (n * n) // n)) * n + x % n

assert_iter_map_simplfy({fsymbolic_fuse1(x): x}, var_dom([(x, n * n * 32)]))
assert_iter_map_simplify({fsymbolic_fuse1(x): x}, var_dom([(x, n * n * 32)]))

assert_iter_map_simplfy({fsymbolic_fuse1(z): z}, var_dom([(x, n * n), (y, 32)]))
assert_iter_map_simplify({fsymbolic_fuse1(z): z}, var_dom([(x, n * n), (y, 32)]))

def fsymbolic_fuse2(i):
return (i // (n * n) * n + i % (n * n) // n) * n + i % n

assert_iter_map_simplfy({fsymbolic_fuse2(x): x}, var_dom([(x, n * n * 32)]))
assert_iter_map_simplify({fsymbolic_fuse2(x): x}, var_dom([(x, n * n * 32)]))


def test_iter_map_simplify_symbolic_predicate():
Expand All @@ -1155,21 +1155,34 @@ def simple_fuse0(x):
return (x // n) * n + x % n

z = x * 32 + y
assert_iter_map_simplfy(
assert_iter_map_simplify(
{simple_fuse0(z): z}, var_dom([(x, (n + 1) // 2), (y, 32)]), predicate=(z < n * 16)
)

def fsymbolic_fuse2(i):
return (i // (n * n) * n + i % (n * n) // n) * n + i % n

z = x * 64 + y
assert_iter_map_simplfy(
assert_iter_map_simplify(
{fsymbolic_fuse2(z): z},
var_dom([(x, (n * n + 1) // 2), (y, 64)]),
predicate=(z < n * n * 32),
)


def test_iter_map_simplify_symbolic_reshape():
n = tvm.tir.Var("n", "int64")
fused = tvm.tir.Var("fused", "int64")

ax0 = (fused // 4096) // n
ax1 = (fused // 4096) % n
ax2 = fused % 4096

rhs_index = ((ax2 // 4096 + ax0 * n + ax1) % n) * 4096 + ax2 % 4096

assert_iter_map_simplify({rhs_index: fused}, var_dom([(fused, n * 4096)]))


def test_iter_map_simplify_unit_loop_order():
"""Test itermap simplify"""
x = tvm.tir.Var("x", "int64")
Expand All @@ -1178,26 +1191,26 @@ def test_iter_map_simplify_unit_loop_order():

# trivial iterators can be found at any when comparing via scale
# ensure order unchange
assert_iter_map_simplfy(
assert_iter_map_simplify(
{x + y + z: x + y + z}, var_dom([(x, 1), (y, 1), (z, 1)]), simplify_trivial_iterators=False
)

# Even with simplifcation, it should follow the original order
assert_iter_map_simplfy(
assert_iter_map_simplify(
{x + y + (z // 4) * 4 + z % 4: z + x + y},
var_dom([(x, 1), (y, 1), (z, 32)]),
simplify_trivial_iterators=False,
)

assert_iter_map_simplfy(
assert_iter_map_simplify(
{y + 64 - x % 2 * 64: y + 64 - x % 2 * 64},
var_dom([(x, 6), (y, 64)]),
simplify_trivial_iterators=False,
)

# When we have iterators that have same scale but one of them come
# with unit extent, we should prioritize unit extent
assert_iter_map_simplfy(
assert_iter_map_simplify(
{x // 128 + y + z: y + x // 128 + z},
var_dom([(x, 128), (y, 128), (z, 1)]),
simplify_trivial_iterators=False,
Expand Down

0 comments on commit 02dc191

Please sign in to comment.