Skip to content

Commit

Permalink
[Arith] Simplifications for floormod(x, 2) (#13936)
Browse files Browse the repository at this point in the history
* [Arith] Simplifications for floormod(x, 2)

Because `floormod(x,2)` has only two possible values, it can be
simplified more aggressively than most FloorMod expressions.  The
additional simplifications are derived from `floormod(x,2) +
floormod(x+1,2) == 1`, which is true for denominator `2`, along with
the usual `floordiv(x,2)*2 + floormod(x,2) == x`, which is true for all
denominators.

This initially arose from an index expression `floormod(x + 1, 2) * 8192`,
for `x ∈ [0, 2)`.  This commit allows the expression to be re-written as
`x * (-8192) + 8192` and recognized as a strided access.
  • Loading branch information
Lunderberg authored Apr 4, 2023
1 parent f5db8b7 commit dba987c
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 52 deletions.
5 changes: 5 additions & 0 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,11 @@ class IterMapRewriter : public ExprMutator {
PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);

static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
if (sign < 0 && is_const_int(rhs->extent, 2)) {
lhs->base -= rhs->scale;
sign = 1;
}

tir::ExprDeepEqual equal;
for (size_t i = 0; i < lhs->args.size(); ++i) {
IterSplitExpr lvalue = lhs->args[i];
Expand Down
38 changes: 33 additions & 5 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2),
c2.Eval()->value > 0);

TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) + floormod(x, 2), floordiv(x + 1, 2));

// canonicalization rule
// will try rewrite again after canonicalization.

TVM_TRY_RECURSIVE_REWRITE(matches_one_of(x + (c1 - y), (c1 - y) + x), (x - y) + c1);
TVM_TRY_RECURSIVE_REWRITE(matches_one_of(x + c1 + y, x + (c1 + y)), (x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE(matches_one_of((x + c1) + y, x + (c1 + y), x + (y + c1)),
(x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x + max(y, z), max(y, z) + x);
TVM_TRY_RECURSIVE_REWRITE(x + min(y, z), min(y, z) + x);

Expand Down Expand Up @@ -456,6 +460,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c1 - x, 0 - floormod(x - y, c1) - y,
c1.Eval()->value != 0);

TVM_TRY_RECURSIVE_REWRITE(
floordiv(x + c1, 2) - floordiv(x + c2, 2),
floormod(x, 2) * (floormod(c1, 2) - floormod(c2, 2)) + (floordiv(c1, 2) - floordiv(c2, 2)));
TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) - floordiv(x + c2, 2),
floormod(x, 2) * (0 - floormod(c2, 2)) - floordiv(c2, 2));
TVM_TRY_RECURSIVE_REWRITE(floordiv(x + c1, 2) - floordiv(x, 2),
floormod(x, 2) * floormod(c1, 2) + floordiv(c1, 2));

TVM_TRY_REWRITE_IF(
x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * c2,
c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
Expand All @@ -475,6 +487,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
floordiv(x - y, c1) * c3 - x * c2, (0 - floormod(x - y, c1) - y) * c2,
c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);

TVM_TRY_RECURSIVE_REWRITE(floordiv(x + 1, 2) - floormod(x, 2), floordiv(x, 2));

TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x + c2, c3),
floordiv(floormod(x + floormod(c2, c3), c3) + (c1 - c2), c3),
c3.Eval()->value > 0);
Expand All @@ -485,6 +499,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
// will try rewrite again after canonicalization.
TVM_TRY_REWRITE(x - c1, x + (0 - c1));
TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x - (y + c1), (x - y) + (0 - c1));
TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y);
TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1));
} else if (op->dtype.is_float()) {
Expand Down Expand Up @@ -864,6 +879,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
TVM_TRY_REWRITE(floordiv(x, x), OneWithTypeLike(x));
TVM_TRY_REWRITE(matches_one_of(floordiv(x * c1, x), floordiv(c1 * x, x)), c1);

TVM_TRY_REWRITE(floordiv(floormod(x, 2) + 1, 2), floormod(x, 2));

// Rules involving 2-operands.
TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
Expand Down Expand Up @@ -975,6 +992,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2),
c2.Eval()->value > 0);

TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x + c1, 2), floormod(x, 2) * (-1) + 1,
floormod(c1.Eval()->value, 2) == 1);
TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

Expand All @@ -985,12 +1004,21 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {

TVM_TRY_REWRITE(matches_one_of(floormod(x * y, y), floormod(y * x, y)), ZeroWithTypeLike(y));

// try modular analysis
if (floormod(x, c1).Match(ret)) {
ModularSet mod = analyzer_->modular_set(x.Eval());
int64_t c1val = c1.Eval()->value;
if (mod->coeff % c1val == 0 && c1val > 0) {
return floormod(mod->base, c1).Eval();
if (c1val > 0) {
// try modular analysis
ModularSet mod = analyzer_->modular_set(x.Eval());
if (mod->coeff % c1val == 0) {
return floormod(mod->base, c1).Eval();
}

// floormod(x,c1) is a no-op when x is already in the
// appropriate range.
ConstIntBound bound = analyzer_->const_int_bound(x.Eval());
if (bound->min_value >= 0 && bound->max_value < c1val) {
return x.Eval();
}
}
}
}
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,18 @@ def test_compound():
assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)]))


def test_compound_floormod_two():
x = tvm.tir.Var("x", "int32")
fld = tvm.tir.floordiv
flm = tvm.tir.floormod

# extent of 2 are normalized to positive scale
assert_iter_sum_pattern(
expect_dict={fld(x, 2) * 2 - flm(x, 2) + 1: (8, 0, 1)},
dom_map=var_dom([(x, 8)]),
)


def test_predicate():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
Expand Down
33 changes: 33 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,39 @@ class TestFloormodIndex(BaseCompare):
)


class TestFloorModTwo(BaseCompare):
"""Special-case simplifications for FloorMod(expr,2)
Because FloorMod(expr,2) has only two possible values, it can be
simplified more aggressively than most FloorMod expressions. Some
of these have analogues for other denominators (e.g. x%3 + (x+1)%3
+ (x+2)%3 == 0 + 1 + 2), but they don't appear as often and
require identifying more related terms in order to apply.
(x + c1)//2 - (x+c2)//2 => (x%2)*( c1%2 - c1%2 ) + (c1//2 - c2//2)
"""

x, y, z = te.var("x"), te.var("y"), te.var("z")
test_case = tvm.testing.parameter(
# Removing offsets from floormod
TestCase(flm(x + 1, 2), flm(x, 2) * (-1) + 1),
TestCase(flm(x + 5, 2), flm(x, 2) * (-1) + 1),
TestCase(flm(x, 2) + flm(x + 1, 2), 1),
TestCase(flm(x + 1, 2) + flm(x, 2), 1),
# Difference of floordiv yields floormod
TestCase(fld(x + 1, 2) - fld(x, 2), flm(x, 2)),
TestCase(fld(x, 2) - fld(x - 1, 2), flm(x, 2) * -1 + 1),
TestCase(fld(x + 5, 2) - fld(x - 2, 2), flm(x, 2) + 3),
TestCase(fld(x + 5, 2) - fld(x - 3, 2), 4),
TestCase(fld(flm(x, 2) + 1, 2), flm(x, 2)),
# Sum of floordiv and floormod to yield floordiv
TestCase(fld(x + 1, 2) - flm(x, 2), fld(x, 2)),
TestCase(fld(x, 2) + flm(x, 2), fld(x + 1, 2)),
# Removal of floormod where possible
TestCase(flm(x + 1, 2) * 8192, x * (-8192) + 8192, [x >= 0, x < 2]),
)


class TestMinIndex(BaseCompare):
x, y, z = te.var("x"), te.var("y"), te.var("z")
test_case = tvm.testing.parameter(
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_meta_schedule_space_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5
for ax0_ax1_ax2_ax3_fused in T.serial((i4_0 % 2 + 1) // 2 * 96 + 96):
with T.block("PadInput_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 64 + i4_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * ((i4_0 % 2 + 1) // 2 + 1)) // 96)
v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 64 + i4_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * (i4_0 % 2 + 1)) // 96)
v2 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused % 64 // 16 + ax0_ax1_ax2_ax3_fused % 96 // 32)
v3 = T.axis.spatial(512, i6_0 * 32 + ax0_ax1_ax2_ax3_fused % 32)
T.reads(inputs[v0, v1 - 1, v2 - 1, v3])
Expand Down
Loading

0 comments on commit dba987c

Please sign in to comment.