Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] Fix solve inequality of unbound var ranges #14582

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions src/arith/int_constraints.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,26 +153,35 @@ Range IntGroupBounds::FindBestRange(const Map<Var, Range>& vranges_addl) const {

for (const PrimExpr& low : lowers) {
for (const PrimExpr& upp : uppers) {
PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3);
// Since diff may depend on some other variables, we compute its overapproximation
PrimExpr diff_over_1 = analyzer.Simplify(EvalSet(diff_1, var_intsets).max(), 3);
Optional<PrimExpr> diff_over;
PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3);
IntSet diff_set1 = EvalSet(diff_1, var_intsets);
if (diff_set1.HasUpperBound()) {
diff_over = analyzer.Simplify(diff_set1.max(), 3);
}

// low is the lower bound for v*coef, but we need the lower bound for v.
// We use rounding-up division to compute it. Since we want to use a single formula
PrimExpr low_divided = analyzer.Simplify(floordiv(low + coef - 1, coef), 3);

// Compute another difference which may be more precise (or not).
PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, coef) - low_divided, 3);
PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, var_intsets).max(), 3);

PrimExpr diff_over =
analyzer.CanProve(diff_over_2 - diff_over_1 < 0) ? diff_over_2 : diff_over_1;
IntSet diff_set2 = EvalSet(diff_2, var_intsets);
if (diff_set2.HasUpperBound()) {
PrimExpr diff_over_2 = analyzer.Simplify(diff_set2.max(), 3);
diff_over = diff_over.defined() ? (analyzer.CanProve(diff_over_2 - diff_over.value() < 0)
? diff_over_2
: diff_over.value())
: diff_over_2;
}

// If it is provable that the new one is strictly better than the current best one,
// then replace it. Note that we are biased towards earlier pairs which should be simpler.
if (!best_diff_over.defined() || analyzer.CanProve(diff_over - best_diff_over < 0)) {
if (diff_over.defined() && (!best_diff_over.defined() ||
analyzer.CanProve(diff_over.value() - best_diff_over < 0))) {
best_lower = low_divided;
best_diff_over = diff_over;
best_diff_over = diff_over.value();
}
}
}
Expand Down
15 changes: 15 additions & 0 deletions tests/python/unittest/test_arith_solve_linear_inequality.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,5 +196,20 @@ def test_no_solution():
assert not rel


def test_unbound_var_range():
x = te.var("x0")
free_var = te.var("fv")
vranges = {x: tvm.ir.Range.from_min_extent(0, tvm.tir.Cast("int32", 1 + tvm.tir.log(free_var)))}
problem = [x > 3]
solution = arith.solve_linear_inequalities(
problem,
[x],
vranges,
)
assert len(solution.variables) == 1
assert len(solution.ranges) == 0
assert len(solution.relations) == 3


if __name__ == "__main__":
tvm.testing.main()