Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
Fix apache#12039‘s broken cases (apache#12143)
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif authored and xinetzone committed Nov 25, 2022
1 parent 608e676 commit 40cf6c8
Show file tree
Hide file tree
Showing 8 changed files with 281 additions and 163 deletions.
91 changes: 66 additions & 25 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,12 @@ class IterMapRewriter : public ExprMutator {
using Parent = ExprMutator;

explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters,
bool simplify_trivial_iterators, Array<String>* errors)
: analyzer_(analyzer), errors_(*errors), padding_predicate_(const_false()) {
IterMapLevel check_level, bool simplify_trivial_iterators,
Array<String>* errors)
: analyzer_(analyzer),
check_level_(check_level),
errors_(*errors),
padding_predicate_(const_false()) {
for (auto kv : input_iters) {
const Var& var = kv.first;
const Range& vrng = kv.second;
Expand Down Expand Up @@ -419,6 +423,8 @@ class IterMapRewriter : public ExprMutator {

// Internal analyzer
Analyzer* analyzer_;
// Iter map check level
IterMapLevel check_level_;
// Error messages for each unresolved expression.
Array<String>& errors_;
// The var map
Expand Down Expand Up @@ -651,7 +657,7 @@ class IterMapRewriter : public ExprMutator {
if (predicate_induced_max.defined())
predicate_induced_max = predicate_induced_max.value() - base;
}
Optional<IterSumExpr> opt = TryFuseIters(expr);
Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
ICHECK(!opt.defined() || opt.value()->args.size() == 1);
// scale should be 1
if (opt.defined() && is_one(opt.value()->args[0]->scale)) {
Expand Down Expand Up @@ -702,7 +708,7 @@ class IterMapRewriter : public ExprMutator {
IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
// We are normalizing a regular iter
if (expr->args.size() < 1) return expr;
Optional<IterSumExpr> opt = TryFuseIters(expr);
Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
if (opt.defined()) {
return opt.value();
} else {
Expand Down Expand Up @@ -735,9 +741,10 @@ class IterMapRewriter : public ExprMutator {
* return a corresponding IterSumExpr with extra offset if needed.
* Try to normalize IterSum into a fused IterMark
* \param expr The input sum.
* \param check_level The check level of iter mapping.
* \return The sum with the fused IterMark and extra offset if succeed.
*/
Optional<IterSumExpr> TryFuseIters(IterSumExpr expr) {
Optional<IterSumExpr> TryFuseIters(IterSumExpr expr, IterMapLevel check_level) {
// select the iterators in order
std::vector<bool> visited(expr->args.size(), false);
std::vector<IterSplitExpr> flattened_iters, grouped_iters;
Expand All @@ -758,14 +765,42 @@ class IterMapRewriter : public ExprMutator {
}
// check if it can be remapped into a fused pattern.
PrimExpr expected_extra_base = 0;
PrimExpr tail_extent = 0;
PrimExpr expected_scale = base_scale.value();
for (size_t i = 0; i < expr->args.size();) {
// find j such that expr->args[j] has expected scale
size_t j = i == 0 ? base_index : 0;
for (; j < expr->args.size(); ++j) {
if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, expected_scale)) break;
// find position such that expr->args[j] match expected scale
int j = i == 0 ? base_index : expr->args.size() - 1;

size_t matched_pos = expr->args.size();
PrimExpr matched_scale{nullptr};
bool is_exact_match{false};

for (; j >= 0; --j) {
if (visited[j]) {
continue;
}
const PrimExpr& cur_scale = expr->args[j]->scale;

// for bijective mapping, the matched scale must equal to expected scale
if (analyzer_->CanProveEqual(cur_scale, expected_scale)) {
matched_pos = j;
matched_scale = cur_scale;
is_exact_match = true;
break;
}
if (check_level != IterMapLevel::Bijective && base_scale.value()->value == 1) {
// find the closest scale which is less or equal to expected scale
if (analyzer_->CanProveGreaterEqual(expected_scale - cur_scale, 0) &&
analyzer_->CanProveGreaterEqual(cur_scale, 0)) {
if (matched_pos == expr->args.size() ||
analyzer_->CanProveLess(matched_scale - cur_scale, 0)) {
matched_pos = j;
matched_scale = cur_scale;
}
}
}
}
if (j == expr->args.size()) {
if (matched_pos == expr->args.size()) {
return NullOpt;
}
// look for the longest constrained iter started from expr->args[matched_pos]
Expand All @@ -775,8 +810,8 @@ class IterMapRewriter : public ExprMutator {
// otherwise we expect the scale of i to be 2*5=10
Optional<IterSumExpr> constraint_to_match;
for (const IterSumExpr& iter : constrained_iters_flattened_) {
if (IterSplitEqual(expr->args[j], iter->args.back(), false)) {
// find a predicate started from expr->args[j]
if (IterSplitEqual(expr->args[matched_pos], iter->args.back(), false)) {
// find a predicate started from match position
if (!constraint_to_match ||
constraint_to_match.value()->args.size() < iter->args.size()) {
constraint_to_match = iter;
Expand All @@ -793,7 +828,7 @@ class IterMapRewriter : public ExprMutator {
size_t k = 0;
for (; k < expr->args.size(); ++k) {
if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) {
if (analyzer_->CanProveEqual((*it)->scale * expected_scale, expr->args[k]->scale))
if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale))
break;
}
}
Expand All @@ -806,20 +841,25 @@ class IterMapRewriter : public ExprMutator {
auto iter = sum_fuse_map_.find(constraint_to_match.value());
ICHECK(iter != sum_fuse_map_.end());
const IterMarkWithOffset& iter_matched = iter->second;
grouped_iters.emplace_back(iter_matched.mark, expected_scale);
expected_extra_base += iter_matched.offset * expected_scale;
expected_scale *= iter_matched.mark->extent;
grouped_iters.emplace_back(iter_matched.mark, div(matched_scale, base_scale.value()));
expected_extra_base += iter_matched.offset * matched_scale;
if (!is_exact_match) {
tail_extent += expected_scale - matched_scale;
}
expected_scale = matched_scale * iter_matched.mark->extent;
// move forward
i += constraint_to_match.value()->args.size();
} else {
// constraint_to_match not found, skip this iterator
visited[j] = true;
IterSplitExpr arg = expr->args[j];
arg.CopyOnWrite()->scale =
analyzer_->Simplify(div(expr->args[j]->scale, base_scale.value()));
visited[matched_pos] = true;
IterSplitExpr arg = expr->args[matched_pos];
arg.CopyOnWrite()->scale = analyzer_->Simplify(div(arg->scale, base_scale.value()));
flattened_iters.push_back(arg);
grouped_iters.push_back(arg);
expected_scale *= expr->args[j]->extent;
if (!is_exact_match) {
tail_extent += expected_scale - matched_scale;
}
expected_scale = matched_scale * expr->args[matched_pos]->extent;
++i;
}
}
Expand All @@ -843,7 +883,8 @@ class IterMapRewriter : public ExprMutator {
expr->base + expected_extra_base);
} else {
// new iter, form a new mark
IterMark mark = IterMark(structured_form, div(expected_scale, base_scale.value()));
IterMark mark =
IterMark(structured_form, div(expected_scale, base_scale.value()) + tail_extent);
sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0);
flattened_map_[structured_form] = flattened_form;
return IterSumExpr({IterSplitExpr(mark, base_scale.value())},
Expand Down Expand Up @@ -1086,8 +1127,8 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
constraints.begin(), constraints.end(),
[](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });

IterMapRewriter rewriter(analyzer, constrained_input_iters, simplify_trivial_iterators,
&result->errors);
IterMapRewriter rewriter(analyzer, constrained_input_iters, check_level,
simplify_trivial_iterators, &result->errors);
// Step0.0: rewrite constraints in the order from size-small ones to size-big ones
for (const IterConstraint& constraint : constraints) {
auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
Expand Down Expand Up @@ -1281,7 +1322,7 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o
} else if (sum->args.size() == 1) {
return sum;
}
auto opt_fused = TryFuseIters(sum);
auto opt_fused = TryFuseIters(sum, check_level_);
if (!opt_fused) {
ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(original_dividend)
<< ", can't be written as a single fused IterSum";
Expand Down
7 changes: 2 additions & 5 deletions tests/python/unittest/test_arith_intset.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,6 @@ def do_test_point_access(point, predicates, var_dom, expect):


def test_region_lower_bound_unfusable():
# This test is designed to trigger an error in DetectIterMap,
# resulting from a numerator which required multiple input
# variables. The bug resulted in an exception being thrown,
# rather than a return value of None.
var_dom = {
tvm.tir.Var("i", "int32"): tvm.ir.Range(8),
tvm.tir.Var("j", "int32"): tvm.ir.Range(4),
Expand All @@ -336,7 +332,8 @@ def test_region_lower_bound_unfusable():
tvm.ir.Range.from_min_extent((i + j) // 2, 1),
]
result = tvm.arith.estimate_region_lower_bound(region, var_dom, predicate=True)
assert result is None
assert result[0].min_value == 0
assert result[0].max_value == 5


def test_union_lower_bound():
Expand Down
58 changes: 57 additions & 1 deletion tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def assert_iter_sum_pattern(
)
indices = res.indices
assert len(indices) == len(keys), res.errors
print(indices)
for i, input_iter in enumerate(keys):
spec = expect_dict[input_iter]
(
Expand Down Expand Up @@ -446,6 +445,13 @@ def test_predicate():
predicate=xo * 129 + xi < 128,
)

# strided iteration predicate
assert_iter_sum_pattern(
{xo * 16 + xi * 4: (10, 0, 4)},
var_dom([(xo, 3), (xi, 4)]),
predicate=xo * 4 + xi < 10,
)


def convert_division(divisions):
if divisions is None or len(divisions) == 0:
Expand Down Expand Up @@ -1010,5 +1016,55 @@ def test_padding():
assert_iter_sum_failure({flm(x, 16)}, var_dom([(x, 3)]))


def test_overlapped_fuse():
    """Exercise iter-map detection on sums whose terms overlap.

    Each accepted case is checked with ``check_level="surjective"``; the
    same expression is then expected to be rejected with
    ``check_level="bijective"``.  The expected tuples are forwarded
    unchanged to ``assert_iter_sum_pattern``.
    """
    x = tvm.tir.Var("x", "int32")
    y = tvm.tir.Var("y", "int32")
    z = tvm.tir.Var("z", "int32")
    # Give the outer iterators their own name hints so printed expressions
    # and error messages do not show two different vars both called "x"/"y".
    a = tvm.tir.Var("a", "int32")
    b = tvm.tir.Var("b", "int32")

    # non-bijective fuse of two
    assert_iter_sum_pattern(
        {
            x * 7 + y: (22, 0, 1),
        },
        var_dom([(x, 3), (y, 8)]),
        check_level="surjective",
    )
    assert_iter_sum_failure([x * 7 + y], var_dom([(x, 3), (y, 8)]), check_level="bijective")

    # non-bijective fuse of three
    assert_iter_sum_pattern(
        {
            x * 18 + y * 7 + z: (40, 0, 1),
        },
        var_dom([(x, 2), (y, 3), (z, 8)]),
        check_level="surjective",
    )
    # The same three-term expression must be rejected at the bijective level.
    assert_iter_sum_failure(
        [x * 18 + y * 7 + z], var_dom([(x, 2), (y, 3), (z, 8)]), check_level="bijective"
    )

    # negative scale fusion is not allowed
    assert_iter_sum_failure([x * -7 + y], var_dom([(x, 3), (y, 8)]), check_level="surjective")
    assert_iter_sum_failure([x * 7 - y], var_dom([(x, 3), (y, 8)]), check_level="surjective")

    # with predicate
    assert_iter_sum_pattern(
        {
            a * 40 + b * 20 + x * 18 + y * 3 + z: (125, 6, 1),
        },
        var_dom([(a, 3), (b, 2), (x, 2), (y, 6), (z, 8)]),
        predicate=tvm.tir.all(z < 4, 1 < x * 6 + y, x * 6 + y < 10),
        check_level="surjective",
    )

    # stride=1 kernel
    assert_iter_sum_pattern(
        {x + a: (230, 0, 1)}, var_dom([(x, 224), (a, 7)]), check_level="surjective"
    )

    # do not allow both strided and overlapped
    assert_iter_sum_failure([5 * x + 2 * y], var_dom([(x, 4), (y, 3)]), check_level="surjective")


if __name__ == "__main__":
tvm.testing.main()
Loading

0 comments on commit 40cf6c8

Please sign in to comment.