diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 84606bd01e06..4382a2ef8935 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -153,9 +153,13 @@ Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { for (const PrimExpr& low : lowers) { for (const PrimExpr& upp : uppers) { - PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3); // Since diff may depend on some other variables, we compute its overapproximation - PrimExpr diff_over_1 = analyzer.Simplify(EvalSet(diff_1, var_intsets).max(), 3); + Optional diff_over; + PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3); + IntSet diff_set1 = EvalSet(diff_1, var_intsets); + if (diff_set1.HasUpperBound()) { + diff_over = analyzer.Simplify(diff_set1.max(), 3); + } // low is the lower bound for v*coef, but we need the lower bound for v. // We use rounding-up division to compute it. Since we want to use a single formula @@ -163,16 +167,21 @@ Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { // Compute another difference which may be more precise (or not). PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, coef) - low_divided, 3); - PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, var_intsets).max(), 3); - - PrimExpr diff_over = - analyzer.CanProve(diff_over_2 - diff_over_1 < 0) ? diff_over_2 : diff_over_1; + IntSet diff_set2 = EvalSet(diff_2, var_intsets); + if (diff_set2.HasUpperBound()) { + PrimExpr diff_over_2 = analyzer.Simplify(diff_set2.max(), 3); + diff_over = diff_over.defined() ? (analyzer.CanProve(diff_over_2 - diff_over.value() < 0) + ? diff_over_2 + : diff_over.value()) + : diff_over_2; + } // If it is provable that the new one is strictly better than the current best one, // then replace it. Note that we are biased towards earlier pairs which should be simpler. - if (!best_diff_over.defined() || analyzer.CanProve(diff_over - best_diff_over < 0)) { + if (diff_over.defined() && (!best_diff_over.defined() || + analyzer.CanProve(diff_over.value() - best_diff_over < 0))) { best_lower = low_divided; - best_diff_over = diff_over; + best_diff_over = diff_over.value(); } } } diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py index dd2fbdf72b94..5285da12e75d 100644 --- a/tests/python/unittest/test_arith_solve_linear_inequality.py +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -196,5 +196,20 @@ def test_no_solution(): assert not rel +def test_unbound_var_range(): + x = te.var("x0") + free_var = te.var("fv") + vranges = {x: tvm.ir.Range.from_min_extent(0, tvm.tir.Cast("int32", 1 + tvm.tir.log(free_var)))} + problem = [x > 3] + solution = arith.solve_linear_inequalities( + problem, + [x], + vranges, + ) + assert len(solution.variables) == 1 + assert len(solution.ranges) == 0 + assert len(solution.relations) == 3 + + if __name__ == "__main__": tvm.testing.main()